├── README.md ├── models ├── __init__.py └── networks.py ├── utils ├── __init__.py ├── data_utils.py ├── training_utils.py └── visualization_utils.py ├── environments ├── __init__.py ├── lqr_simple.py ├── cartpole.py └── pendulum.py ├── learning ├── configs │ ├── __init__.py │ └── default.py └── fit_dynamics_model.py ├── data ├── checkpoint_0 ├── checkpoint_1000 ├── dataset.npy.npz ├── checkpoints │ ├── checkpoint_620 │ ├── checkpoint_31000 │ ├── checkpoint_15500-MLP-swish-pendulum_determinstic_uniform_dataset-5000 │ ├── checkpoint_3000-MLP-swish-pendulum_determinstic_uniform_dataset-1000 │ ├── checkpoint_3000-MLP-swish-pendulum_determinstic_uniform_dataset-500 │ ├── checkpoint_62500-MLP-relu-pendulum_determinstic_uniform_dataset-20000 │ ├── checkpoint_62500-MLP-swish-pendulum_determinstic_uniform_dataset-20000 │ └── checkpoint_31000-MLP-relu-pendulum_determinstic_optimal_policy_dataset-10000 ├── pendulum_determinstic_dataset.npy ├── pendulum_determinstic_NN_state.npy ├── pendulum_determinstic_uniform_dataset-100.npy ├── pendulum_determinstic_uniform_dataset-1000.npy ├── pendulum_determinstic_uniform_dataset-200.npy ├── pendulum_determinstic_uniform_dataset-500.npy ├── pendulum_determinstic_uniform_dataset-5000.npy ├── pendulum_determinstic_uniform_dataset-10000.npy └── pendulum_determinstic_uniform_dataset-20000.npy ├── assets ├── lqr_gtsam.png ├── lqr_watson20.png ├── pendulum_watson20.png ├── pendulum_MPC_watson20.png └── pendulum_MPC_funcapprox_watson20.png ├── .gitignore ├── requirements.txt ├── control ├── __init__.py ├── _variables.py ├── _general_factors.py └── _factors.py ├── LICENSE └── examples ├── collect_data ├── pendulum_uniform.py └── pendulum_determinstic.py ├── deprecated └── lqr_example_deprecated.py ├── lqr ├── lqr_watson20.py └── lqr_gtsam.py └── pendulum ├── nonlinsq_open_loop_known.py ├── nonlinsq_closed_loop_known.py └── nonlinsq_closed_loop_funcaprox.py /README.md: -------------------------------------------------------------------------------- 1 | # bic -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /environments/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /learning/configs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/checkpoint_0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/bic/main/data/checkpoint_0 -------------------------------------------------------------------------------- /assets/lqr_gtsam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/bic/main/assets/lqr_gtsam.png -------------------------------------------------------------------------------- /data/checkpoint_1000: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/bic/main/data/checkpoint_1000 -------------------------------------------------------------------------------- /data/dataset.npy.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/bic/main/data/dataset.npy.npz -------------------------------------------------------------------------------- /assets/lqr_watson20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/bic/main/assets/lqr_watson20.png -------------------------------------------------------------------------------- /assets/pendulum_watson20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/bic/main/assets/pendulum_watson20.png -------------------------------------------------------------------------------- /data/checkpoints/checkpoint_620: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/bic/main/data/checkpoints/checkpoint_620 -------------------------------------------------------------------------------- /assets/pendulum_MPC_watson20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/bic/main/assets/pendulum_MPC_watson20.png -------------------------------------------------------------------------------- /data/checkpoints/checkpoint_31000: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/bic/main/data/checkpoints/checkpoint_31000 -------------------------------------------------------------------------------- /data/pendulum_determinstic_dataset.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/bic/main/data/pendulum_determinstic_dataset.npy -------------------------------------------------------------------------------- /data/pendulum_determinstic_NN_state.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/bic/main/data/pendulum_determinstic_NN_state.npy -------------------------------------------------------------------------------- /assets/pendulum_MPC_funcapprox_watson20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/bic/main/assets/pendulum_MPC_funcapprox_watson20.png -------------------------------------------------------------------------------- /data/pendulum_determinstic_uniform_dataset-100.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/bic/main/data/pendulum_determinstic_uniform_dataset-100.npy -------------------------------------------------------------------------------- /data/pendulum_determinstic_uniform_dataset-1000.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/bic/main/data/pendulum_determinstic_uniform_dataset-1000.npy -------------------------------------------------------------------------------- /data/pendulum_determinstic_uniform_dataset-200.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/bic/main/data/pendulum_determinstic_uniform_dataset-200.npy -------------------------------------------------------------------------------- /data/pendulum_determinstic_uniform_dataset-500.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/bic/main/data/pendulum_determinstic_uniform_dataset-500.npy -------------------------------------------------------------------------------- /data/pendulum_determinstic_uniform_dataset-5000.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/bic/main/data/pendulum_determinstic_uniform_dataset-5000.npy -------------------------------------------------------------------------------- /data/pendulum_determinstic_uniform_dataset-10000.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/bic/main/data/pendulum_determinstic_uniform_dataset-10000.npy -------------------------------------------------------------------------------- /data/pendulum_determinstic_uniform_dataset-20000.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/bic/main/data/pendulum_determinstic_uniform_dataset-20000.npy -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | *.pyc 3 | *.egg-info 4 | __pycache__ 5 | .mypy_cache 6 | .dmypy.json 7 | .pytype 8 | .hypothesis 9 | .ipynb_checkpoints 10 | .DS_Store 11 | .idea 12 | # data/ 13 | -------------------------------------------------------------------------------- /data/checkpoints/checkpoint_15500-MLP-swish-pendulum_determinstic_uniform_dataset-5000: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/bic/main/data/checkpoints/checkpoint_15500-MLP-swish-pendulum_determinstic_uniform_dataset-5000 -------------------------------------------------------------------------------- /data/checkpoints/checkpoint_3000-MLP-swish-pendulum_determinstic_uniform_dataset-1000: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/bic/main/data/checkpoints/checkpoint_3000-MLP-swish-pendulum_determinstic_uniform_dataset-1000 -------------------------------------------------------------------------------- /data/checkpoints/checkpoint_3000-MLP-swish-pendulum_determinstic_uniform_dataset-500: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/bic/main/data/checkpoints/checkpoint_3000-MLP-swish-pendulum_determinstic_uniform_dataset-500 -------------------------------------------------------------------------------- /data/checkpoints/checkpoint_62500-MLP-relu-pendulum_determinstic_uniform_dataset-20000: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/bic/main/data/checkpoints/checkpoint_62500-MLP-relu-pendulum_determinstic_uniform_dataset-20000 -------------------------------------------------------------------------------- /data/checkpoints/checkpoint_62500-MLP-swish-pendulum_determinstic_uniform_dataset-20000: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/bic/main/data/checkpoints/checkpoint_62500-MLP-swish-pendulum_determinstic_uniform_dataset-20000 -------------------------------------------------------------------------------- /data/checkpoints/checkpoint_31000-MLP-relu-pendulum_determinstic_optimal_policy_dataset-10000: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/bic/main/data/checkpoints/checkpoint_31000-MLP-relu-pendulum_determinstic_optimal_policy_dataset-10000 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | absl_py==1.0.0 3 | clu==0.0.6 4 | flax==0.4.1 5 | jax==0.3.4 6 | jax_dataclasses==1.2.1 7 | matplotlib==3.5.1 8 | ml_collections==0.1.1 9 | numpy==1.21.5 10 | optax==0.1.1 11 | overrides==6.1.0 12 | seaborn==0.11.2 13 | tensorflow==2.8.0 14 | -------------------------------------------------------------------------------- /control/__init__.py: -------------------------------------------------------------------------------- 1 | from ._factors import PriorFactor, TransformedPriorFactor, LQRTripletFactor, LQRInitialFactor 2 | from ._general_factors import GeneralFactorSAS, GeneralFactorAS 3 | from ._variables import BoundedRealVectorVariable 4 | 5 | __all__ = ["PriorFactor", "TransformedPriorFactor", "LQRTripletFactor", "LQRInitialFactor", 6 | "GeneralFactorSAS", "GeneralFactorAS", 7 | "BoundedRealVectorVariable"] 8 | 9 | -------------------------------------------------------------------------------- /environments/lqr_simple.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | class LQREnv: 5 | 6 | @staticmethod 7 | @jax.jit 8 | def lqr_simple_watson(s0, a0): 9 | A = jnp.array([[1.1, 0.0], [0.1, 1.1]]) # state transition linear operator on x: Ax 10 | B = jnp.array([[0.1], [0.0]]) # state transition linear operator on u: Bu 11 | c = jnp.array([-1., -2.]) # state transition bias, i.e., x(t+1) = Ax(t) + Bu(t) + c 12 | return jnp.dot(A, s0) + jnp.dot(B, a0) + c 13 | 14 | @staticmethod 15 | @jax.jit 16 | def lqr_simple_gtsam(s0, a0): 17 | A = jnp.array([[1.03]]) # slightly unstable system :) 18 | B = jnp.array([[0.03]]) 19 | return jnp.dot(A, s0) + jnp.dot(B, a0) 20 | 21 | -------------------------------------------------------------------------------- /environments/cartpole.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | 5 | @jax.jit 6 | def cartpole_dynamics(state, action): 7 | 8 | g = 9.81 9 | Mc = 0.37 10 | Mp = 0.127 11 | Mt = Mc + Mp 12 | l = 0.3365 13 | fs_hz = 250.0 14 | dt = 1 / fs_hz 15 | u_mx = 5.0 16 | 17 | if len(state.shape) == 1: 18 | assert len(state.shape) == len(action.shape) 19 | state = jnp.reshape(state, (1, -1)) 20 | action = jnp.reshape(action, (1, -1)) 21 | 22 | x = state 23 | u = action 24 | 25 | _u = jnp.clip(u, -u_mx, u_mx).squeeze() 26 | 27 | th = x[:, 1] 28 | dth2 = jnp.power(x[:, 3], 2) 29 | sth = jnp.sin(th) 30 | cth = jnp.cos(th) 31 | 32 | _num = -Mp * l * sth * cth * dth2 + Mt * g * sth - _u * cth 33 | _denom = l * ((4.0 / 3.0) * Mt - Mp * cth ** 2) 34 | th_acc = _num / _denom 35 | x_acc = (Mp * l * sth * dth2 - Mp * l * th_acc * cth + _u) / Mt 36 | 37 | y1 = x[:, 0] + dt * x[:, 2] 38 | y2 = x[:, 1] + dt * x[:, 3] 39 | y3 = x[:, 2] + dt * x_acc 40 | y4 = x[:, 3] + dt * th_acc 41 | 42 | y = jnp.vstack((y1, y2, y3, y4)).T 43 | return y 44 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Probabilistic machine learning 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /environments/pendulum.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import jax 3 | 4 | @jax.jit 5 | def pendulum_dynamics(state: jnp.ndarray, 6 | action: jnp.ndarray) -> jnp.ndarray: 7 | dt = 0.05 8 | m = 1.0 9 | l = 1.0 10 | d = 1e-2 # damping 11 | g = 9.80665 12 | u_mx = 5. 13 | 14 | batch_mode = len(state.shape) ==2 15 | 16 | if not batch_mode: 17 | assert len(state.shape) == len(action.shape) 18 | state = jnp.reshape(state, (1, -1)) 19 | action = jnp.reshape(action, (1, -1)) 20 | 21 | x = state 22 | u = action 23 | u = jnp.clip(u, -u_mx, u_mx) 24 | th_dot_dot = -3.0 * g / (2 * l) * jnp.sin(x[:, 0] + jnp.pi) - d * x[:, 1] 25 | th_dot_dot += 3.0 / (m * l ** 2) * u.squeeze() 26 | x_dot = x[:, 1] + th_dot_dot * dt 27 | x_pos = x[:, 0] + x_dot * dt 28 | x2 = jnp.vstack((x_pos, x_dot)).T 29 | if not batch_mode: 30 | return x2.reshape(-1) 31 | else: 32 | return x2 33 | 34 | 35 | if __name__ == "__main__": 36 | # print_yay(1, 2) 37 | x = jnp.array([1, 1.]) 38 | u = jnp.array([1.]) 39 | print(pendulum_dynamics(x, u), type(pendulum_dynamics(x, u))) 40 | -------------------------------------------------------------------------------- /learning/configs/default.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Default Hyperparameter configuration.""" 16 | import ml_collections 17 | 18 | 19 | def get_config(): 20 | """Get the default hyperparameter configuration.""" 21 | config = ml_collections.ConfigDict() 22 | 23 | config.learning_rate = 0.01 24 | config.momentum = 0.9 25 | config.batch_size = 128 26 | config.num_epochs = 1000 27 | config.data_path = 'data/pendulum_determinstic_uniform_dataset-200.npy' 28 | config.input_dim = 3 29 | config.activation = 'swish' 30 | config.fc_dims = [128, 64, 20, 2] 31 | return config 32 | -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | import numpy as np 4 | import jax 5 | import jax.numpy as jnp 6 | import flax.linen as nn 7 | 8 | class MLP(nn.Module): 9 | features: Sequence[int] 10 | activation: str 11 | 12 | @nn.compact 13 | def __call__(self, x): 14 | for feat in self.features[:-1]: 15 | if self.activation == 'relu': 16 | x = nn.relu(nn.Dense(feat)(x)) 17 | elif self.activation == 'swish': 18 | x = nn.swish(nn.Dense(feat)(x)) 19 | else: 20 | raise NotImplementedError 21 | x = nn.Dense(self.features[-1])(x) 22 | return x 23 | 24 | # import jax 25 | # import haiku as hk 26 | # import jax.numpy as jnp 27 | # 28 | # from typing import Mapping 29 | # import numpy as np 30 | # 31 | # Batch = Mapping[str, jnp.ndarray] 32 | # 33 | # 34 | # def pendulum_net(batch: Batch) -> jnp.ndarray: 35 | # x, u = batch['x'], batch['u'] 36 | # inp = jnp.concatenate([x, u], axis=1) 37 | # mlp = hk.Sequential([ 38 | # hk.Flatten(), 39 | # hk.Linear(20), jax.nn.relu, 40 | # hk.Linear(20), jax.nn.relu, 41 | # hk.Linear(2), 42 | # ]) 43 | # return mlp(inp) 44 | # 45 | # 46 | # def pendulum_net_wrap(x, u) -> jnp.ndarray: 47 | # batch = {'x': x, 'u': u} 48 | # return pendulum_net 49 | 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /examples/collect_data/pendulum_uniform.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os.path as osp 3 | sys.path.append(osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))) 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | 8 | from environments.pendulum import pendulum_dynamics 9 | from jax import random 10 | from utils.data_utils import make_train_test_split 11 | 12 | if __name__ == '__main__': 13 | key = random.PRNGKey(42) 14 | newkey, subkey = random.split(key) 15 | 16 | x0 = jnp.array([jnp.pi, 0.]) 17 | u0 = jnp.array([0]) 18 | 19 | n_data = 200 # 10 thousands 20 | 21 | x_max, x_min = 2 * jnp.pi, -2 * jnp.pi 22 | u_max, u_min = 5, -5 23 | 24 | states = jax.random.uniform(subkey, shape=(n_data, 2), minval=x_min, maxval=x_max) 25 | actions = jax.random.uniform(subkey, shape=(n_data, 1), minval=u_min, maxval=u_max) 26 | 27 | Xs = jnp.concatenate([states, actions], 1) 28 | 29 | Ys = pendulum_dynamics(states, actions) 30 | # import ipdb; ipdb.set_trace() 31 | 32 | train_X, train_Y, test_X, test_Y = make_train_test_split(Xs, Ys, 0.2) 33 | 34 | dataset = {'train_x': train_X, 'train_y': train_Y, 'test_x': test_X, 'test_y': test_Y} 35 | 36 | jnp.save('data/pendulum_determinstic_uniform_dataset-%d' % n_data, dataset) 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import random 3 | 4 | 5 | def transform_traj_to_input_prediction_format(traj, dim_x, dim_u): 6 | """ 7 | :param traj: T * (dim_x + dim_u) 8 | :return: 9 | """ 10 | xs = [] 11 | ys = [] 12 | T = traj.shape[0] 13 | assert traj.shape[1] == dim_u + dim_x 14 | for t in range(T-1): 15 | x = traj[t] # dim_x + dim_u 16 | y = traj[t+1][:dim_x] 17 | xs.append(x) 18 | ys.append(y) 19 | return jnp.stack(xs, 0), jnp.stack(ys, 0) 20 | 21 | 22 | def transform_trajs_to_training_data(trajs, dim_x, dim_u): 23 | """ 24 | :param trajs: n_trajs * T * (dim_x + dim_u) 25 | :param dim_x: 26 | :param dim_u: 27 | :return: 28 | """ 29 | n_trajs, n_steps, _ = trajs.shape 30 | 31 | Xs = [] # store state action 32 | Ys = [] # store the next state 33 | 34 | for t in range(n_trajs): 35 | traj = trajs[t] 36 | xs, ys = transform_traj_to_input_prediction_format(traj, dim_x, dim_u) 37 | Xs.append(xs) 38 | Ys.append(ys) 39 | return jnp.concatenate(Xs, 0), jnp.concatenate(Ys, 0) 40 | 41 | 42 | def make_train_test_split(Xs, Ys, test_ratio=0.2): 43 | n_data = Xs.shape[0] 44 | indices = list(range(n_data)) 45 | 46 | n_test = int(n_data * test_ratio) 47 | n_train = n_data - n_test 48 | 49 | random.shuffle(indices) 50 | 51 | train_indices = jnp.array(indices[:n_train]) 52 | test_indices = jnp.array(indices[n_train:]) 53 | 54 | train_X = Xs[train_indices] 55 | train_Y = Ys[train_indices] 56 | 57 | test_X = Xs[test_indices] 58 | test_Y = Ys[test_indices] 59 | 60 | return train_X, train_Y, test_X, test_Y 61 | 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /learning/fit_dynamics_model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os.path as osp 3 | sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__)))) 4 | 5 | from absl import app 6 | from absl import flags 7 | from absl import logging 8 | from clu import platform 9 | import jax 10 | import jax.numpy as jnp 11 | from ml_collections import config_flags 12 | import tensorflow as tf 13 | 14 | from utils.training_utils import train_and_evaluate 15 | 16 | FLAGS = flags.FLAGS 17 | 18 | flags.DEFINE_string('workdir', None, 'Directory to store model data.') 19 | config_flags.DEFINE_config_file( 20 | 'config', 21 | None, 22 | 'configs/default.py', 23 | lock_config=True) 24 | 25 | 26 | def main(argv): 27 | if len(argv) > 1: 28 | raise app.UsageError('Too many command-line arguments.') 29 | 30 | # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make 31 | # it unavailable to JAX. 32 | tf.config.experimental.set_visible_devices([], 'GPU') 33 | 34 | logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count()) 35 | logging.info('JAX local devices: %r', jax.local_devices()) 36 | 37 | # Add a note so that we can tell which task is which JAX host. 38 | # (Depending on the platform task 0 is not guaranteed to be host 0) 39 | platform.work_unit().set_task_status(f'process_index: {jax.process_index()}, process_count: {jax.process_count()}') 40 | platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY, 41 | FLAGS.workdir, 'workdir') 42 | 43 | state = train_and_evaluate(FLAGS.config, FLAGS.workdir) 44 | jnp.save('data/pendulum_determinstic_NN_state', state.params) 45 | 46 | 47 | if __name__ == '__main__': 48 | flags.mark_flags_as_required(['config', 'workdir']) 49 | app.run(main) 50 | -------------------------------------------------------------------------------- /control/_variables.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os.path as osp 3 | sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__)))) 4 | 5 | import functools 6 | import jaxfg.hints as hints 7 | from jaxfg.core import VariableBase 8 | import jax.numpy as jnp 9 | from overrides import EnforceOverrides, final, overrides 10 | from typing import Type, TypeVar, Mapping 11 | 12 | VariableType = TypeVar("VariableType", bound="VariableBase") 13 | VariableValueType = TypeVar("VariableValueType", bound=hints.VariableValue) 14 | 15 | 16 | class _BoundedRealVectorVariableTemplate: 17 | """Usage: `RealVectorVariable[N]`, where `N` is an integer dimension.""" 18 | 19 | def __init__(self, min_val, max_val): 20 | self.min_val = min_val 21 | self.max_val = max_val 22 | 23 | # @classmethod 24 | @functools.lru_cache(maxsize=None) 25 | def __getitem__(self, dim: int) -> Type[VariableBase]: 26 | assert isinstance(dim, int) 27 | 28 | class _BoundedRealVectorVariable(VariableBase[hints.Array]): 29 | 30 | @classmethod 31 | @overrides 32 | @final 33 | def get_default_value(cls) -> hints.Array: 34 | return jnp.zeros(dim) 35 | 36 | @classmethod 37 | @overrides 38 | @final 39 | def manifold_retract( 40 | cls, x: VariableValueType, local_delta: hints.LocalVariableValue 41 | ) -> VariableValueType: 42 | r"""Retract local delta to manifold. 43 | Typically written as `x $\oplus$ local_delta` or `x $\boxplus$ local_delta`. 44 | Args: 45 | x: Absolute parameter to update. 46 | local_delta: Delta value in local parameterizaton. 47 | Returns: 48 | Updated parameterization. 49 | """ 50 | return cls.unflatten(jnp.clip(cls.flatten(x) + local_delta, self.min_val, self.max_val)) 51 | 52 | return _BoundedRealVectorVariable 53 | 54 | 55 | BoundedRealVectorVariable: Mapping[int, Type[VariableBase[hints.Array]]] 56 | BoundedRealVectorVariable = _BoundedRealVectorVariableTemplate 57 | -------------------------------------------------------------------------------- /control/_general_factors.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple, Tuple, Callable 2 | 3 | import jax_dataclasses as jdc 4 | from jax import numpy as jnp 5 | from overrides import overrides 6 | 7 | from jaxfg import noises 8 | from jaxfg.core._factor_base import FactorBase 9 | 10 | 11 | class StateActionStateTriplet(NamedTuple): 12 | prev_state: jnp.ndarray 13 | action: jnp.ndarray 14 | next_state: jnp.ndarray 15 | 16 | 17 | class ActionStateTuple(NamedTuple): 18 | action: jnp.ndarray 19 | next_state: jnp.ndarray 20 | 21 | 22 | class GeneralFactorSAS: 23 | 24 | @staticmethod 25 | def make( 26 | prev_state: jnp.ndarray, 27 | action: jnp.ndarray, 28 | next_state: jnp.ndarray, 29 | transit_function: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray], 30 | noise_model: noises.NoiseModelBase, 31 | ): 32 | @jdc.pytree_dataclass 33 | class _Temp(FactorBase[StateActionStateTriplet]): 34 | def transition_function( 35 | self, s0: jnp.ndarray, a0: jnp.ndarray 36 | ) -> jnp.ndarray: 37 | return transit_function(s0, a0) 38 | 39 | @overrides 40 | def compute_residual_vector( 41 | self, variable_values: StateActionStateTriplet 42 | ) -> jnp.ndarray: 43 | s0 = variable_values.prev_state 44 | s1 = variable_values.next_state 45 | a0 = variable_values.action 46 | return transit_function(s0, a0) - s1 47 | 48 | factor = _Temp(variables=(prev_state, action, next_state,), noise_model=noise_model,) 49 | return factor 50 | 51 | 52 | class GeneralFactorAS: 53 | 54 | @staticmethod 55 | def make( 56 | prev_state: jnp.ndarray, 57 | action: jnp.ndarray, 58 | next_state: jnp.ndarray, 59 | transit_function: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray], 60 | noise_model: noises.NoiseModelBase, 61 | ) -> FactorBase[ActionStateTuple]: 62 | @jdc.pytree_dataclass 63 | class _Temp(FactorBase[ActionStateTuple]): 64 | initial_state: jnp.ndarray 65 | 66 | def transition_function( 67 | self, s0: jnp.ndarray, a0: jnp.ndarray 68 | ) -> jnp.ndarray: 69 | return transit_function(s0, a0) 70 | 71 | @overrides 72 | def compute_residual_vector( 73 | self, variable_values: ActionStateTuple 74 | ) -> jnp.ndarray: 75 | s1 = variable_values.next_state 76 | a0 = variable_values.action 77 | return transit_function(self.initial_state, a0) - s1 78 | 79 | factor = _Temp(variables=(action, next_state,), noise_model=noise_model, initial_state=prev_state) 80 | return factor 81 | -------------------------------------------------------------------------------- /examples/deprecated/lqr_example_deprecated.py: -------------------------------------------------------------------------------- 1 | from control import PriorFactor, LQRTripletFactor, LQRInitialFactor 2 | from jax import numpy as jnp 3 | import jaxfg 4 | from jaxfg.core import RealVectorVariable 5 | from jaxfg.solvers import LevenbergMarquardtSolver 6 | from typing import List 7 | from utils.visualization_utils import LQRVis 8 | 9 | dim_x = 2 10 | dim_u = 1 11 | 12 | A = jnp.array([[1.1, 0.0], [0.1, 1.1]]) # state transition linear operator on x: Ax 13 | B = jnp.array([[0.1], [0.0]]) # state transition linear operator on u: Bu 14 | c = jnp.array([-1., -2.]) # state transition bias, i.e., x(t+1) = Ax(t) + Bu(t) + c 15 | 16 | Q_inv = jnp.linalg.inv(jnp.array([[10., 0.], [0., 10.]])) # covariance of state prior 17 | R_inv = jnp.linalg.inv(jnp.array([[1.0]])) # covariance of action prior 18 | 19 | cov_dyn = jnp.array([1e-10]*dim_x) # covariance of dynamics 20 | X0 = jnp.array([5., 5.]) # initial state 21 | Xg = jnp.array([10., 10.]) # goal state 22 | ug = jnp.array([0.]) # goal action 23 | T = 60 # horizon 24 | 25 | # ============================ build factor graphs ========================== 26 | state_variables = [RealVectorVariable[dim_x]() for _ in range(T)] 27 | action_variables = [RealVectorVariable[dim_u]() for _ in range(T)] 28 | 29 | action_state_factors: List[jaxfg.core.FactorBase] = \ 30 | [LQRInitialFactor.make(X0, 31 | action_variables[0], 32 | state_variables[0], A, B, c, 33 | jaxfg.noises.DiagonalGaussian.make_from_covariance(cov_dyn)) 34 | ] 35 | 36 | state_action_state_factors: List[jaxfg.core.FactorBase] = \ 37 | [LQRTripletFactor.make(state_variables[i], 38 | action_variables[i+1], 39 | state_variables[i+1], 40 | A, B, c, 41 | jaxfg.noises.DiagonalGaussian.make_from_covariance(cov_dyn)) 42 | for i in range(T-1)] 43 | 44 | state_prior_factors: List[jaxfg.core.FactorBase] = \ 45 | [PriorFactor.make(state_variables[i], 46 | Xg, 47 | jaxfg.noises.Gaussian.make_from_covariance(Q_inv)) 48 | for i in range(T)] 49 | 50 | action_prior_factors: List[jaxfg.core.FactorBase] = \ 51 | [PriorFactor.make(action_variables[i], 52 | jnp.zeros(dim_u), 53 | jaxfg.noises.Gaussian.make_from_covariance(R_inv)) 54 | for i in range(T)] 55 | 56 | factors: List[jaxfg.core.FactorBase] = action_state_factors \ 57 | + state_action_state_factors \ 58 | + state_prior_factors \ 59 | + action_prior_factors 60 | 61 | state_action_variables = state_variables + action_variables 62 | 63 | graph = jaxfg.core.StackedFactorGraph.make(factors) 64 | initial_assignments = jaxfg.core.VariableAssignments.make_from_defaults(state_action_variables) 65 | print("Initial assignments:") 66 | print(initial_assignments) 67 | 68 | # Solve. Note that the first call to solve() will be much slower than subsequent calls. 69 | with jaxfg.utils.stopwatch("First solve (slower because of JIT compilation)"): 70 | solution_assignments = graph.solve(initial_assignments)#, solver=LevenbergMarquardtSolver()) 71 | solution_assignments.storage.block_until_ready() # type: ignore 72 | 73 | with jaxfg.utils.stopwatch("Solve after initial compilation"): 74 | solution_assignments = graph.solve(initial_assignments)#, solver=LevenbergMarquardtSolver()) 75 | solution_assignments.storage.block_until_ready() # type: ignore 76 | 77 | # Print all solved variable values. 78 | print("Solutions (jaxfg.core.VariableAssignments):") 79 | print(solution_assignments) 80 | print() 81 | # import ipdb;ipdb.set_trace() 82 | 83 | # ======================== below is for visualization ============================== 84 | us = [solution_assignments.get_value(action_variables[i]) for i in range(T)] 85 | xs1 = [solution_assignments.get_value(state_variables[i])[0] for i in range(T)] 86 | xs2 = [solution_assignments.get_value(state_variables[i])[1] for i in range(T)] 87 | xs1 = [X0[0]] + xs1 88 | xs2 = [X0[1]] + xs2 89 | us = us + [0] 90 | 91 | ts = list(range(T+1)) 92 | LQRVis.init_plot() 93 | LQRVis.plot_trajectory(xs1, xs2, us, T+1) 94 | 95 | 96 | 97 | 98 | -------------------------------------------------------------------------------- /examples/lqr/lqr_watson20.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os.path as osp 3 | sys.path.append(osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))) 4 | 5 | from control import PriorFactor, GeneralFactorSAS, GeneralFactorAS 6 | from environments.lqr_simple import LQREnv 7 | from jax import numpy as jnp 8 | import jaxfg 9 | from jaxfg.core import RealVectorVariable 10 | from jaxfg.solvers import LevenbergMarquardtSolver 11 | from typing import List 12 | from utils.visualization_utils import LQRVisWatson 13 | 14 | 15 | # define the configurations of the environment 16 | dim_x = 2 17 | dim_u = 1 18 | 19 | Q_inv = jnp.linalg.inv(jnp.array([[10., 0.], [0., 10.]])) # covariance of state prior 20 | R_inv = jnp.linalg.inv(jnp.array([[1.0]])) # covariance of action prior 21 | 22 | cov_dyn = jnp.array([1e-10]*dim_x) # covariance of dynamics 23 | X0 = jnp.array([5., 5.]) # initial state 24 | Xg = jnp.array([10., 10.]) # goal state 25 | ug = jnp.array([0.]) # goal action 26 | T = 60 # horizon 27 | 28 | # ============================ build factor graphs ========================== 29 | state_variables = [RealVectorVariable[dim_x]() for _ in range(T)] 30 | action_variables = [RealVectorVariable[dim_u]() for _ in range(T)] 31 | 32 | action_state_factors: List[jaxfg.core.FactorBase] = \ 33 | [GeneralFactorAS.make(X0, 34 | action_variables[0], 35 | state_variables[0], 36 | LQREnv.lqr_simple_watson, 37 | jaxfg.noises.DiagonalGaussian.make_from_covariance(cov_dyn)) 38 | ] 39 | 40 | state_action_state_factors: List[jaxfg.core.FactorBase] = \ 41 | [GeneralFactorSAS.make(state_variables[i], 42 | action_variables[i+1], 43 | state_variables[i+1], 44 | LQREnv.lqr_simple_watson, 45 | jaxfg.noises.DiagonalGaussian.make_from_covariance(cov_dyn)) 46 | for i in range(T-1)] 47 | 48 | state_prior_factors: List[jaxfg.core.FactorBase] = \ 49 | [PriorFactor.make(state_variables[i], 50 | Xg, 51 | jaxfg.noises.Gaussian.make_from_covariance(Q_inv)) 52 | for i in range(T)] 53 | 54 | action_prior_factors: List[jaxfg.core.FactorBase] = \ 55 | [PriorFactor.make(action_variables[i], 56 | jnp.zeros(dim_u), 57 | jaxfg.noises.Gaussian.make_from_covariance(R_inv)) 58 | for i in range(T)] 59 | 60 | factors: List[jaxfg.core.FactorBase] = action_state_factors \ 61 | + state_action_state_factors \ 62 | + state_prior_factors \ 63 | + action_prior_factors 64 | 65 | state_action_variables = state_variables + action_variables 66 | 67 | graph = jaxfg.core.StackedFactorGraph.make(factors) 68 | initial_assignments = jaxfg.core.VariableAssignments.make_from_defaults(state_action_variables) 69 | print("Initial assignments:") 70 | print(initial_assignments) 71 | 72 | # Solve. Note that the first call to solve() will be much slower than subsequent calls. 73 | with jaxfg.utils.stopwatch("First solve (slower because of JIT compilation)"): 74 | solution_assignments = graph.solve(initial_assignments)#, solver=LevenbergMarquardtSolver()) 75 | solution_assignments.storage.block_until_ready() # type: ignore 76 | 77 | with jaxfg.utils.stopwatch("Solve after initial compilation"): 78 | solution_assignments = graph.solve(initial_assignments)#, solver=LevenbergMarquardtSolver()) 79 | solution_assignments.storage.block_until_ready() # type: ignore 80 | 81 | # Print all solved variable values. 82 | print("Solutions (jaxfg.core.VariableAssignments):") 83 | print(solution_assignments) 84 | print() 85 | # import ipdb;ipdb.set_trace() 86 | 87 | # ======================== below is for visualization ============================== 88 | us = [solution_assignments.get_value(action_variables[i]) for i in range(T)] 89 | xs1 = [solution_assignments.get_value(state_variables[i])[0] for i in range(T)] 90 | xs2 = [solution_assignments.get_value(state_variables[i])[1] for i in range(T)] 91 | xs1 = [X0[0]] + xs1 92 | xs2 = [X0[1]] + xs2 93 | us = us + [0] 94 | 95 | ts = list(range(T+1)) 96 | LQRVisWatson.init_plot() 97 | LQRVisWatson.plot_trajectory(xs1, xs2, us, T+1, save_path='assets/lqr_watson20.png') 98 | 99 | 100 | 101 | 102 | -------------------------------------------------------------------------------- /utils/training_utils.py: -------------------------------------------------------------------------------- 1 | from absl import logging 2 | from flax import linen as nn 3 | from flax.metrics import tensorboard 4 | from flax.training import train_state 5 | from flax.training import checkpoints 6 | from models.networks import MLP 7 | import jax 8 | import jax.numpy as jnp 9 | import ml_collections 10 | import numpy as np 11 | import optax 12 | # import tensorflow_datasets as tfds 13 | 14 | 15 | @jax.jit 16 | def apply_model(state, x, y): 17 | """Computes gradients, loss and accuracy for a single batch.""" 18 | def loss_fn(params): 19 | pred = MLP([128, 64, 20, 2], 'swish').apply({'params': params}, x) 20 | loss = jnp.mean(optax.l2_loss(pred, y)) 21 | return loss, pred 22 | 23 | grad_fn = jax.value_and_grad(loss_fn, has_aux=True) 24 | (loss, pred), grads = grad_fn(state.params) 25 | return grads, loss 26 | 27 | 28 | @jax.jit 29 | def update_model(state, grads): 30 | return state.apply_gradients(grads=grads) 31 | 32 | 33 | def train_epoch(state, train_ds, batch_size, rng): 34 | """Train for a single epoch.""" 35 | train_ds_size = len(train_ds['x']) 36 | steps_per_epoch = train_ds_size // batch_size 37 | 38 | perms = jax.random.permutation(rng, len(train_ds['x'])) 39 | perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch 40 | perms = perms.reshape((steps_per_epoch, batch_size)) 41 | 42 | epoch_loss = [] 43 | 44 | for perm in perms: 45 | batch_x = train_ds['x'][perm, ...] 46 | batch_y = train_ds['y'][perm, ...] 47 | grads, loss = apply_model(state, batch_x, batch_y) 48 | state = update_model(state, grads) 49 | epoch_loss.append(loss) 50 | train_loss = np.mean(epoch_loss) 51 | return state, train_loss 52 | 53 | 54 | def create_train_state(rng, config): 55 | """Creates initial `TrainState`.""" 56 | mlp = MLP([128, 64, 20, 2], config.activation) 57 | params = mlp.init(rng, jnp.ones([1, config.input_dim]))['params'] 58 | tx = optax.sgd(config.learning_rate, config.momentum) 59 | return train_state.TrainState.create( 60 | apply_fn=mlp.apply, params=params, tx=tx) 61 | 62 | 63 | def get_datasets(path): 64 | data_load = jnp.load(path, allow_pickle=True) 65 | data = data_load.item() 66 | train_x, train_y = data['train_x'], data['train_y'] 67 | test_x, test_y = data['test_x'], data['test_y'] 68 | 69 | train_data = {'x': train_x, 'y': train_y} 70 | test_data = {'x': test_x, 'y': test_y} 71 | return train_data, test_data 72 | 73 | 74 | def train_and_evaluate(config: ml_collections.ConfigDict, 75 | workdir: str) -> train_state.TrainState: 76 | """Execute model training and evaluation loop. 77 | Args: 78 | config: Hyperparameter configuration for training and evaluation. 79 | workdir: Directory where the tensorboard summaries are written to. 80 | Returns: 81 | The train state (which includes the `.params`). 82 | """ 83 | train_ds, test_ds = get_datasets(config.data_path) 84 | rng = jax.random.PRNGKey(0) 85 | 86 | summary_writer = tensorboard.SummaryWriter(workdir) 87 | summary_writer.hparams(dict(config)) 88 | 89 | rng, init_rng = jax.random.split(rng) 90 | state= create_train_state(init_rng, config) 91 | 92 | for epoch in range(1, config.num_epochs + 1): 93 | rng, input_rng = jax.random.split(rng) 94 | state, train_loss = train_epoch(state, train_ds, 95 | config.batch_size, 96 | input_rng) 97 | _, test_loss = apply_model(state, test_ds['x'], 98 | test_ds['y']) 99 | # print(MLP([128, 64, 20, 2]).apply({'params': state.params}, test_ds['x'])) 100 | # print(test_ds['y']) 101 | logging.info( 102 | 'epoch:% 3d, train_loss: %.4f, test_loss: %.4f' 103 | % (epoch, train_loss, test_loss)) 104 | 105 | summary_writer.scalar('train_loss', train_loss, epoch) 106 | summary_writer.scalar('test_loss', test_loss, epoch) 107 | save_checkpoint(state, workdir) 108 | summary_writer.flush() 109 | return state 110 | 111 | 112 | def restore_checkpoint(state, workdir): 113 | return checkpoints.restore_checkpoint(workdir, state) 114 | 115 | 116 | def save_checkpoint(state, workdir): 117 | if jax.process_index() == 0: 118 | # get train state from the first replica 119 | # state = jax.device_get(jax.tree_map(lambda x: x[0], state)) 120 | step = int(state.step) 121 | checkpoints.save_checkpoint(workdir, state, step, keep=3) 122 | -------------------------------------------------------------------------------- /examples/lqr/lqr_gtsam.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os.path as osp 3 | sys.path.append(osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))) 4 | 5 | from control import PriorFactor, GeneralFactorSAS, GeneralFactorAS 6 | from environments.lqr_simple import LQREnv 7 | from jax import numpy as jnp 8 | import jaxfg 9 | from jaxfg.core import RealVectorVariable 10 | from jaxfg.solvers import LevenbergMarquardtSolver 11 | from typing import List 12 | from utils.visualization_utils import LQRVisGTSAM 13 | 14 | 15 | # define the configurations of the environment 16 | dim_x = 1 17 | dim_u = 1 18 | 19 | # The following transition function is wrapped as in LQREnv.lqr_simple_gtsm 20 | # A = jnp.array([[1.03]]) # slightly unstable system :) 21 | # B = jnp.array([[0.03]]) 22 | Q_inv = jnp.linalg.inv(jnp.array([[0.21]])) # covariance of state prior 23 | R_inv = jnp.linalg.inv(jnp.array([[0.05]])) # covariance of action prior 24 | 25 | X0 = jnp.array([-10.]) # initial state 26 | Xg = jnp.array([0.]) # goal state 27 | T = 100 28 | cov_dyn = jnp.array([1e-10]) 29 | 30 | # ============================ build factor graphs ========================== 31 | state_variables = [RealVectorVariable[dim_x]() for _ in range(T)] 32 | action_variables = [RealVectorVariable[dim_u]() for _ in range(T)] 33 | 34 | action_state_factors: List[jaxfg.core.FactorBase] = \ 35 | [GeneralFactorAS.make(X0, 36 | action_variables[0], 37 | state_variables[0], 38 | LQREnv.lqr_simple_gtsam, # define the transition function 39 | jaxfg.noises.DiagonalGaussian.make_from_covariance(cov_dyn)) 40 | ] 41 | 42 | state_action_state_factors: List[jaxfg.core.FactorBase] = \ 43 | [GeneralFactorSAS.make(state_variables[i], 44 | action_variables[i+1], 45 | state_variables[i+1], 46 | LQREnv.lqr_simple_gtsam, # define the transition function 47 | jaxfg.noises.DiagonalGaussian.make_from_covariance(cov_dyn)) 48 | for i in range(T-1)] 49 | 50 | state_prior_factors: List[jaxfg.core.FactorBase] = \ 51 | [PriorFactor.make(state_variables[i], 52 | Xg, 53 | jaxfg.noises.Gaussian.make_from_covariance(Q_inv)) 54 | for i in range(T)] 55 | 56 | action_prior_factors: List[jaxfg.core.FactorBase] = \ 57 | [PriorFactor.make(action_variables[i], 58 | jnp.zeros(dim_u), 59 | jaxfg.noises.Gaussian.make_from_covariance(R_inv)) 60 | for i in range(T)] 61 | 62 | factors: List[jaxfg.core.FactorBase] = action_state_factors \ 63 | + state_action_state_factors \ 64 | + state_prior_factors \ 65 | + action_prior_factors 66 | 67 | state_action_variables = state_variables + action_variables 68 | 69 | graph = jaxfg.core.StackedFactorGraph.make(factors) 70 | initial_assignments = jaxfg.core.VariableAssignments.make_from_defaults(state_action_variables) 71 | print("Initial assignments:") 72 | print(initial_assignments) 73 | 74 | # Solve. Note that the first call to solve() will be much slower than subsequent calls. 75 | # The optimizer is by default using Gauss-Newton. 76 | with jaxfg.utils.stopwatch("First solve (slower because of JIT compilation)"): 77 | solution_assignments = graph.solve(initial_assignments) #, solver=LevenbergMarquardtSolver()) 78 | solution_assignments.storage.block_until_ready() # type: ignore 79 | 80 | with jaxfg.utils.stopwatch("Solve after initial compilation"): 81 | solution_assignments = graph.solve(initial_assignments) #, solver=LevenbergMarquardtSolver()) 82 | solution_assignments.storage.block_until_ready() # type: ignore 83 | 84 | # Print all solved variable values. 85 | print("Solutions (jaxfg.core.VariableAssignments):") 86 | print(solution_assignments) 87 | print() 88 | # import ipdb;ipdb.set_trace() 89 | 90 | # ======================== below is for visualization ============================== 91 | us = [solution_assignments.get_value(action_variables[i]) for i in range(T)] 92 | us = us + [0] 93 | xs = [solution_assignments.get_value(state_variables[i]) for i in range(T)] 94 | xs = [X0] + xs # append the initial location at the beginning 95 | 96 | # computing the true trajectory 97 | true_xs = [X0] 98 | for i in range(T): 99 | prev_x = true_xs[-1] 100 | u = us[i] 101 | next_x = LQREnv.lqr_simple_gtsam(prev_x, u) 102 | true_xs.append(next_x) 103 | 104 | 105 | LQRVisGTSAM.init_plot() 106 | LQRVisGTSAM.plot_trajectory(xs, true_xs, us, T+1, save_path='assets/lqr_gtsam.png') 107 | 108 | 109 | 110 | 111 | 112 | 113 | -------------------------------------------------------------------------------- /examples/pendulum/nonlinsq_open_loop_known.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os.path as osp 3 | sys.path.append(osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))) 4 | 5 | from control import PriorFactor, TransformedPriorFactor, GeneralFactorSAS, GeneralFactorAS 6 | from jax import numpy as jnp 7 | import jaxfg 8 | from jaxfg.solvers import LevenbergMarquardtSolver 9 | from jaxfg.core import RealVectorVariable 10 | from environments.pendulum import pendulum_dynamics 11 | import matplotlib.pyplot as plt 12 | from typing import List 13 | from utils.visualization_utils import PendulumVis 14 | 15 | 16 | dim_x = 2 17 | dim_u = 1 18 | X0 = jnp.array([jnp.pi, 0.0]) # [\theta, \dot{\theta}] 19 | Xag = jnp.array([0., 1., 0.]) # goal state [sin(\theta), cos(\theta), \dot{\theta}] 20 | Q_inv = jnp.diag(jnp.array([100, 1, 100])) # covariance of transformed state 21 | R_inv = jnp.diag(jnp.array([50])) # covariance of action 22 | cov_dyn = jnp.array([[1e-5, 0.], [0., 1e-5]]) # covariance of state transition; small value means deterministic 23 | T = 100 # horizon 24 | 25 | state_variables = [RealVectorVariable[dim_x]() for _ in range(T)] 26 | action_variables = [RealVectorVariable[dim_u]() for _ in range(T)] 27 | 28 | action_state_factors: List[jaxfg.core.FactorBase] = \ 29 | [GeneralFactorAS.make(X0, 30 | action_variables[0], 31 | state_variables[0], 32 | pendulum_dynamics, 33 | jaxfg.noises.Gaussian.make_from_covariance(cov_dyn))] 34 | 35 | state_action_state_factors: List[jaxfg.core.FactorBase] = \ 36 | [GeneralFactorSAS.make(state_variables[i], 37 | action_variables[i+1], 38 | state_variables[i+1], 39 | pendulum_dynamics, 40 | jaxfg.noises.Gaussian.make_from_covariance(cov_dyn)) 41 | for i in range(T-1)] 42 | 43 | state_prior_factors: List[jaxfg.core.FactorBase] = \ 44 | [TransformedPriorFactor.make(state_variables[i], 45 | Xag, 46 | jaxfg.noises.Gaussian.make_from_covariance(Q_inv)) 47 | for i in range(T)] 48 | 49 | action_prior_factors: List[jaxfg.core.FactorBase] = \ 50 | [PriorFactor.make(action_variables[i], 51 | jnp.zeros(dim_u), 52 | jaxfg.noises.Gaussian.make_from_covariance(R_inv)) 53 | for i in range(T)] 54 | 55 | factors: List[jaxfg.core.FactorBase] = action_state_factors \ 56 | + state_action_state_factors \ 57 | + state_prior_factors \ 58 | + action_prior_factors 59 | 60 | state_action_variables = state_variables + action_variables 61 | 62 | # import ipdb; ipdb.set_trace() 63 | graph = jaxfg.core.StackedFactorGraph.make(factors) 64 | initial_assignments = jaxfg.core.VariableAssignments.make_from_defaults(state_action_variables) 65 | 66 | print("Initial assignments:") 67 | print(initial_assignments) 68 | 69 | # Solve. Note that the first call to solve() will be much slower than subsequent calls. 70 | with jaxfg.utils.stopwatch("First solve (slower because of JIT compilation)"): 71 | solution_assignments = graph.solve(initial_assignments, solver=LevenbergMarquardtSolver()) 72 | solution_assignments.storage.block_until_ready() # type: ignore 73 | 74 | with jaxfg.utils.stopwatch("Solve after initial compilation"): 75 | solution_assignments = graph.solve(initial_assignments, solver=LevenbergMarquardtSolver()) 76 | solution_assignments.storage.block_until_ready() # type: ignore 77 | 78 | # Print all solved variable values. 79 | print("Solutions (jaxfg.core.VariableAssignments):") 80 | print(solution_assignments) 81 | print() 82 | # import ipdb;ipdb.set_trace() 83 | 84 | 85 | # ======================== below is for visualization ============================== 86 | us = [solution_assignments.get_value(action_variables[i]) for i in range(T)] 87 | 88 | xs1 = [solution_assignments.get_value(state_variables[i])[0] for i in range(T)] 89 | xs2 = [solution_assignments.get_value(state_variables[i])[1] for i in range(T)] 90 | xs = [X0] + [solution_assignments.get_value(state_variables[i]) for i in range(T)] 91 | 92 | xs1 = [X0[0]] + xs1 # append the initial location at the beginning of the list 93 | xs2 = [X0[1]] + xs2 94 | us = us + [0] # append the last action at the end of the list 95 | 96 | # below is for get the trajectories under the true dynamics. 97 | true_xs = [X0] # append the first location 98 | for i in range(T): 99 | x_prev = true_xs[-1] 100 | next_x = pendulum_dynamics(x_prev, us[i]) 101 | true_xs.append(next_x) 102 | 103 | true_xs1 = [_[0] for _ in true_xs] 104 | true_xs2 = [_[1] for _ in true_xs] 105 | ts = list(range(T+1)) 106 | 107 | PendulumVis.init_plot() 108 | PendulumVis.plot_trajectory(xs1, xs2, true_xs1, true_xs2, us, T+1, save_path='assets/pendulum_watson20.png') 109 | 110 | -------------------------------------------------------------------------------- /control/_factors.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple, Tuple 2 | 3 | import jax_dataclasses as jdc 4 | from jax import numpy as jnp 5 | from overrides import overrides 6 | 7 | from jaxfg import noises 8 | from jaxfg.core._factor_base import FactorBase 9 | 10 | 11 | PriorValueTuple = Tuple[jnp.ndarray] 12 | 13 | 14 | @jdc.pytree_dataclass 15 | class PriorFactor(FactorBase[PriorValueTuple]): 16 | """Factor for defining a fixed prior on a frame. 17 | Residuals are computed as `(variable.inverse() @ mu).log()`. 18 | """ 19 | 20 | mu: jnp.ndarray 21 | 22 | @staticmethod 23 | def make( 24 | variable: jnp.ndarray, 25 | mu: jnp.ndarray, 26 | noise_model: noises.NoiseModelBase, 27 | ) -> "PriorFactor": 28 | return PriorFactor( 29 | variables=(variable,), 30 | mu=mu, 31 | noise_model=noise_model, 32 | ) 33 | 34 | @overrides 35 | def compute_residual_vector(self, variable_values: PriorValueTuple) -> jnp.ndarray: 36 | 37 | T: jnp.ndarray 38 | (T,) = variable_values 39 | 40 | # Equivalent to: return (variable_value.inverse() @ self.mu).log() 41 | # FIXME(CW): does the sign of the residual matter? 42 | return T - self.mu 43 | 44 | 45 | @jdc.pytree_dataclass 46 | class TransformedPriorFactor(FactorBase[PriorValueTuple]): 47 | """Factor for defining a fixed prior on a frame. 48 | Residuals are computed as `(variable.inverse() @ mu).log()`. 49 | """ 50 | 51 | mu: jnp.ndarray 52 | 53 | @staticmethod 54 | def make( 55 | variable: jnp.ndarray, 56 | mu: jnp.ndarray, 57 | noise_model: noises.NoiseModelBase, 58 | ) -> "TransformedPriorFactor": 59 | return TransformedPriorFactor( 60 | variables=(variable,), 61 | mu=mu, 62 | noise_model=noise_model, 63 | ) 64 | 65 | @overrides 66 | def compute_residual_vector(self, variable_values: PriorValueTuple) -> jnp.ndarray: 67 | 68 | T: jnp.ndarray 69 | (T,) = variable_values 70 | T_transformed = jnp.concatenate([jnp.sin(T[0:1]), jnp.cos(T[0:1]), T[1:2]]) 71 | 72 | # Equivalent to: return (variable_value.inverse() @ self.mu).log() 73 | return T_transformed - self.mu 74 | 75 | 76 | class LQRTripletTuple(NamedTuple): 77 | prev_state: jnp.ndarray 78 | action: jnp.ndarray 79 | next_state: jnp.ndarray 80 | 81 | 82 | class LQRInitialTuple(NamedTuple): 83 | action: jnp.ndarray 84 | next_state: jnp.ndarray 85 | 86 | 87 | @jdc.pytree_dataclass 88 | class LQRInitialFactor(FactorBase[LQRInitialTuple]): 89 | 90 | initial_state: jnp.ndarray 91 | A: jnp.ndarray 92 | B: jnp.ndarray 93 | c: jnp.ndarray 94 | 95 | @staticmethod 96 | def make( 97 | prev_state: jnp.ndarray, 98 | action: jnp.ndarray, 99 | next_state: jnp.ndarray, 100 | A: jnp.ndarray, 101 | B: jnp.ndarray, 102 | c: jnp.ndarray, 103 | noise_model: noises.NoiseModelBase, 104 | ) -> "LQRInitialFactor": 105 | return LQRInitialFactor( 106 | variables=( 107 | action, 108 | next_state 109 | ), 110 | A=A, 111 | B=B, 112 | c=c, 113 | initial_state=prev_state, 114 | noise_model=noise_model, 115 | ) 116 | 117 | @overrides 118 | def compute_residual_vector( 119 | self, variable_values: LQRInitialTuple 120 | ) -> jnp.ndarray: 121 | prev_state = self.initial_state 122 | next_state = variable_values.next_state 123 | action = variable_values.action 124 | return jnp.dot(self.A, prev_state) + jnp.dot(self.B, action) + self.c - next_state 125 | 126 | 127 | @jdc.pytree_dataclass 128 | class LQRTripletFactor(FactorBase[LQRTripletTuple]): 129 | 130 | A: jnp.ndarray 131 | B: jnp.ndarray 132 | c: jnp.ndarray 133 | 134 | @staticmethod 135 | def make( 136 | prev_state: jnp.ndarray, 137 | action: jnp.ndarray, 138 | next_state: jnp.ndarray, 139 | A: jnp.ndarray, 140 | B: jnp.ndarray, 141 | c: jnp.ndarray, 142 | noise_model: noises.NoiseModelBase, 143 | ) -> "LQRTripletFactor": 144 | assert type(prev_state) is type(next_state) 145 | return LQRTripletFactor( 146 | variables=( 147 | prev_state, 148 | action, 149 | next_state 150 | ), 151 | A=A, 152 | B=B, 153 | c=c, 154 | noise_model=noise_model, 155 | ) 156 | 157 | @overrides 158 | def compute_residual_vector( 159 | self, variable_values: LQRTripletTuple 160 | ) -> jnp.ndarray: 161 | prev_state = variable_values.prev_state 162 | next_state = variable_values.next_state 163 | action = variable_values.action 164 | return jnp.dot(self.A, prev_state) + jnp.dot(self.B, action) + self.c - next_state 165 | 166 | 167 | 168 | 169 | 170 | -------------------------------------------------------------------------------- /examples/pendulum/nonlinsq_closed_loop_known.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os.path as osp 3 | sys.path.append(osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))) 4 | 5 | from control import PriorFactor, TransformedPriorFactor, GeneralFactorSAS, GeneralFactorAS, BoundedRealVectorVariable 6 | import jax 7 | from jax import random 8 | from jax import numpy as jnp 9 | import jaxfg 10 | from jaxfg.solvers import LevenbergMarquardtSolver 11 | from jaxfg.core import RealVectorVariable 12 | from environments.pendulum import pendulum_dynamics 13 | import matplotlib.pyplot as plt 14 | from typing import List 15 | from utils.visualization_utils import PendulumMPCVis 16 | 17 | H = 15 # how many steps we will lookahead when planning; planning_horizon 18 | 19 | dim_x = 2 20 | dim_u = 1 21 | X0 = jnp.array([jnp.pi, 0.]) # [\theta, \dot{\theta}] 22 | Xag = jnp.array([0., 1., 0.]) # goal state [sin(\theta), cos(\theta), \dot{\theta}] 23 | Q_inv = jnp.diag(jnp.array([100, 1., 100])) # covariance of transformed state 24 | R_inv = jnp.diag(jnp.array([50.])) # covariance of action 25 | cov_dyn = jnp.array([[1e-6, 0.], [0., 1e-6]]) # covariance of state transition; small value means deterministic 26 | T = 100 # horizon 27 | max_u = 5 28 | min_u = -5 29 | 30 | key = random.PRNGKey(42) 31 | 32 | 33 | state_variables = [RealVectorVariable[dim_x]() for _ in range(H)] 34 | action_variables = [BoundedRealVectorVariable(min_u, max_u)[dim_u]() for _ in range(H)] 35 | 36 | action_state_factors: List[jaxfg.core.FactorBase] = \ 37 | [GeneralFactorAS.make(X0, 38 | action_variables[0], 39 | state_variables[0], 40 | pendulum_dynamics, 41 | jaxfg.noises.Gaussian.make_from_covariance(cov_dyn))] 42 | 43 | state_action_state_factors: List[jaxfg.core.FactorBase] = \ 44 | [GeneralFactorSAS.make(state_variables[i], 45 | action_variables[i+1], 46 | state_variables[i+1], 47 | pendulum_dynamics, 48 | jaxfg.noises.Gaussian.make_from_covariance(cov_dyn)) 49 | for i in range(H-1)] 50 | 51 | state_prior_factors: List[jaxfg.core.FactorBase] = \ 52 | [TransformedPriorFactor.make(state_variables[i], 53 | Xag, 54 | jaxfg.noises.Gaussian.make_from_covariance(Q_inv)) 55 | for i in range(H)] 56 | 57 | action_prior_factors: List[jaxfg.core.FactorBase] = \ 58 | [PriorFactor.make(action_variables[i], 59 | jnp.zeros(dim_u), 60 | jaxfg.noises.Gaussian.make_from_covariance(R_inv)) 61 | for i in range(H)] 62 | 63 | 64 | factors: List[jaxfg.core.FactorBase] = action_state_factors \ 65 | + state_action_state_factors \ 66 | + state_prior_factors \ 67 | + action_prior_factors 68 | 69 | state_action_variables = state_variables + action_variables 70 | 71 | # import ipdb; ipdb.set_trace() 72 | graph = jaxfg.core.StackedFactorGraph.make(factors) 73 | # import ipdb; ipdb.set_trace() 74 | # graph.factor_stacks[0].factor.initial_state[:] = X0 75 | # import ipdb; ipdb.set_trace() 76 | initial_assignments = jaxfg.core.VariableAssignments.make_from_defaults(state_action_variables) 77 | print("Initial assignments:") 78 | print(initial_assignments) 79 | 80 | # t = 0 81 | # initial_assignments = jaxfg.core.VariableAssignments.make_from_defaults(state_action_variables) 82 | # initial_assignments.set_value(state_variables[t], X0) 83 | 84 | # Solve. Note that the first call to solve() will be much slower than subsequent calls. 85 | with jaxfg.utils.stopwatch("First solve (slower because of JIT compilation)"): 86 | solution_assignments = graph.solve(initial_assignments, solver=LevenbergMarquardtSolver()) 87 | solution_assignments.storage.block_until_ready() # type: ignore 88 | 89 | states_observed = [X0] 90 | actions_taken = [] 91 | 92 | for t in range(T): 93 | # import ipdb; ipdb.set_trace() 94 | graph.factor_stacks[0].factor.initial_state[:] = states_observed[-1] 95 | with jaxfg.utils.stopwatch("Solve after initial compilation"): 96 | solution_assignments = graph.solve(initial_assignments, solver=LevenbergMarquardtSolver(max_iterations=1000)) 97 | solution_assignments.storage.block_until_ready() # type: ignore 98 | action = solution_assignments.get_value(action_variables[0]) # take the first action 99 | actions_taken.append(action) 100 | # import ipdb; ipdb.set_trace() 101 | new_key, subkey = random.split(key) 102 | del key 103 | next_state = pendulum_dynamics(states_observed[-1], action) \ 104 | + jax.random.multivariate_normal(subkey, jnp.zeros(dim_x), cov_dyn) # run in the true environment 105 | states_observed.append(next_state) 106 | del subkey 107 | key = new_key 108 | 109 | # ======================== below is for visualization ============================== 110 | true_xs1 = [_[0] for _ in states_observed] 111 | true_xs2 = [_[1] for _ in states_observed] 112 | us = actions_taken + [0.] 113 | 114 | PendulumMPCVis.init_plot() 115 | PendulumMPCVis.plot_trajectory(true_xs1, true_xs2, us, T+1) # , save_path='assets/pendulum_MPC_watson20.png') 116 | -------------------------------------------------------------------------------- /examples/pendulum/nonlinsq_closed_loop_funcaprox.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jaxfg 3 | import os.path as osp 4 | import sys 5 | sys.path.append(osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))) 6 | 7 | from control import PriorFactor, TransformedPriorFactor, GeneralFactorSAS, GeneralFactorAS, BoundedRealVectorVariable 8 | from environments.pendulum import pendulum_dynamics 9 | from jaxfg.core import RealVectorVariable 10 | from jax import numpy as jnp 11 | from jax import random 12 | from jaxfg.solvers import LevenbergMarquardtSolver 13 | from learning.configs.default import get_config 14 | from models.networks import MLP 15 | from typing import List 16 | from utils.training_utils import restore_checkpoint, create_train_state 17 | from utils.visualization_utils import PendulumMPCVis 18 | 19 | H = 20 # how many steps we will lookahead when planning; planning_horizon 20 | 21 | dim_x = 2 22 | dim_u = 1 23 | X0 = jnp.array([jnp.pi, 0.]) # [\theta, \dot{\theta}] 24 | Xag = jnp.array([0., 1., 0.]) # goal state [sin(\theta), cos(\theta), \dot{\theta}] 25 | Q_inv = jnp.diag(jnp.array([100, 1., 100])) # covariance of transformed state 26 | R_inv = jnp.diag(jnp.array([50.])) # covariance of action 27 | cov_dyn = jnp.array([[1e-6, 0.], [0., 1e-6]]) # covariance of state transition; small value means deterministic 28 | T = 100 # horizon 29 | max_u = 5 30 | min_u = -5 31 | key = random.PRNGKey(42) 32 | state_variables = [RealVectorVariable[dim_x]() for _ in range(H)] 33 | action_variables = [BoundedRealVectorVariable(min_u, max_u)[dim_u]() for _ in range(H)] 34 | 35 | # params = jnp.load('data/pendulum_determinstic_NN_state.npy') 36 | config = get_config() 37 | workdir = 'data/checkpoint_1000' #'data/checkpoint_62500' 38 | rng = jax.random.PRNGKey(0) 39 | ckpt = create_train_state(rng, config) 40 | ckpt = restore_checkpoint(ckpt, workdir) 41 | params = ckpt.params 42 | 43 | def pendulum_dynamics_learned(state, action): 44 | if len(state.shape) == 1: 45 | assert len(state.shape) == len(action.shape) 46 | state = jnp.reshape(state, (1, -1)) 47 | action = jnp.reshape(action, (1, -1)) 48 | return MLP(config.fc_dims, config.activation).apply({'params': params}, jnp.concatenate([state, action], 1)).reshape(-1) 49 | 50 | action_state_factors: List[jaxfg.core.FactorBase] = \ 51 | [GeneralFactorAS.make(X0, 52 | action_variables[0], 53 | state_variables[0], 54 | pendulum_dynamics_learned, 55 | jaxfg.noises.Gaussian.make_from_covariance(cov_dyn))] 56 | 57 | state_action_state_factors: List[jaxfg.core.FactorBase] = \ 58 | [GeneralFactorSAS.make(state_variables[i], 59 | action_variables[i+1], 60 | state_variables[i+1], 61 | pendulum_dynamics_learned, 62 | jaxfg.noises.Gaussian.make_from_covariance(cov_dyn)) 63 | for i in range(H-1)] 64 | 65 | state_prior_factors: List[jaxfg.core.FactorBase] = \ 66 | [TransformedPriorFactor.make(state_variables[i], 67 | Xag, 68 | jaxfg.noises.Gaussian.make_from_covariance(Q_inv)) 69 | for i in range(H)] 70 | 71 | action_prior_factors: List[jaxfg.core.FactorBase] = \ 72 | [PriorFactor.make(action_variables[i], 73 | jnp.zeros(dim_u), 74 | jaxfg.noises.Gaussian.make_from_covariance(R_inv)) 75 | for i in range(H)] 76 | 77 | 78 | factors: List[jaxfg.core.FactorBase] = action_state_factors \ 79 | + state_action_state_factors \ 80 | + state_prior_factors \ 81 | + action_prior_factors 82 | 83 | state_action_variables = state_variables + action_variables 84 | 85 | # import ipdb; ipdb.set_trace() 86 | graph = jaxfg.core.StackedFactorGraph.make(factors) 87 | initial_assignments = jaxfg.core.VariableAssignments.make_from_defaults(state_action_variables) 88 | print("Initial assignments:") 89 | print(initial_assignments) 90 | 91 | # Solve. Note that the first call to solve() will be much slower than subsequent calls. 92 | with jaxfg.utils.stopwatch("First solve (slower because of JIT compilation)"): 93 | solution_assignments = graph.solve(initial_assignments, solver=LevenbergMarquardtSolver()) 94 | solution_assignments.storage.block_until_ready() # type: ignore 95 | 96 | states_observed = [X0] 97 | actions_taken = [] 98 | 99 | for t in range(T): 100 | # import ipdb; ipdb.set_trace() 101 | graph.factor_stacks[0].factor.initial_state[:] = states_observed[-1] 102 | with jaxfg.utils.stopwatch("Solve after initial compilation"): 103 | solution_assignments = graph.solve(initial_assignments, solver=LevenbergMarquardtSolver(max_iterations=1000)) 104 | solution_assignments.storage.block_until_ready() # type: ignore 105 | action = solution_assignments.get_value(action_variables[0]) # take the first action 106 | actions_taken.append(action) 107 | # import ipdb; ipdb.set_trace() 108 | new_key, subkey = random.split(key) 109 | del key 110 | next_state = pendulum_dynamics(states_observed[-1], action) \ 111 | + jax.random.multivariate_normal(subkey, jnp.zeros(dim_x), cov_dyn) # run in the true environment 112 | states_observed.append(next_state) 113 | del subkey 114 | key = new_key 115 | 116 | # ======================== below is for visualization ============================== 117 | true_xs1 = [_[0] for _ in states_observed] 118 | true_xs2 = [_[1] for _ in states_observed] 119 | us = actions_taken + [0.] 120 | 121 | PendulumMPCVis.init_plot() 122 | PendulumMPCVis.plot_trajectory(true_xs1, true_xs2, us, T+1, save_path='assets/pendulum_MPC_funcapprox_watson20.png') 123 | -------------------------------------------------------------------------------- /examples/collect_data/pendulum_determinstic.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os.path as osp 3 | sys.path.append(osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))) 4 | 5 | from control import PriorFactor, TransformedPriorFactor, GeneralFactorSAS, GeneralFactorAS, BoundedRealVectorVariable 6 | import jax 7 | from jax import random 8 | from jax import numpy as jnp 9 | import jaxfg 10 | from jaxfg.solvers import LevenbergMarquardtSolver 11 | from jaxfg.core import RealVectorVariable 12 | from environments.pendulum import pendulum_dynamics 13 | import matplotlib.pyplot as plt 14 | from typing import List 15 | from utils.visualization_utils import PendulumMPCVis 16 | from utils.data_utils import transform_trajs_to_training_data, make_train_test_split 17 | 18 | H = 50 # how many steps we will lookahead when planning; planning_horizon 19 | 20 | dim_x = 2 21 | dim_u = 1 22 | X0 = jnp.array([jnp.pi, 0.]) # [\theta, \dot{\theta}] 23 | Xag = jnp.array([0., 1., 0.]) # goal state [sin(\theta), cos(\theta), \dot{\theta}] 24 | Q_inv = jnp.diag(jnp.array([100, 1., 100])) # covariance of transformed state 25 | R_inv = jnp.diag(jnp.array([50.])) # covariance of action 26 | cov_dyn = jnp.array([[0.05, 0.], [0., 1e-6]]) # covariance of state transition; small value means deterministic 27 | T = 100 # horizon 28 | max_u = 4 29 | min_u = -4 30 | 31 | key = random.PRNGKey(42) 32 | 33 | 34 | state_variables = [RealVectorVariable[dim_x]() for _ in range(H)] 35 | action_variables = [BoundedRealVectorVariable(min_u, max_u)[dim_u]() for _ in range(H)] 36 | 37 | action_state_factors: List[jaxfg.core.FactorBase] = \ 38 | [GeneralFactorAS.make(X0, 39 | action_variables[0], 40 | state_variables[0], 41 | pendulum_dynamics, 42 | jaxfg.noises.Gaussian.make_from_covariance(cov_dyn))] 43 | 44 | state_action_state_factors: List[jaxfg.core.FactorBase] = \ 45 | [GeneralFactorSAS.make(state_variables[i], 46 | action_variables[i+1], 47 | state_variables[i+1], 48 | pendulum_dynamics, 49 | jaxfg.noises.Gaussian.make_from_covariance(cov_dyn)) 50 | for i in range(H-1)] 51 | 52 | state_prior_factors: List[jaxfg.core.FactorBase] = \ 53 | [TransformedPriorFactor.make(state_variables[i], 54 | Xag, 55 | jaxfg.noises.Gaussian.make_from_covariance(Q_inv)) 56 | for i in range(H)] 57 | 58 | action_prior_factors: List[jaxfg.core.FactorBase] = \ 59 | [PriorFactor.make(action_variables[i], 60 | jnp.zeros(dim_u), 61 | jaxfg.noises.Gaussian.make_from_covariance(R_inv)) 62 | for i in range(H)] 63 | 64 | 65 | factors: List[jaxfg.core.FactorBase] = action_state_factors \ 66 | + state_action_state_factors \ 67 | + state_prior_factors \ 68 | + action_prior_factors 69 | 70 | state_action_variables = state_variables + action_variables 71 | 72 | # import ipdb; ipdb.set_trace() 73 | graph = jaxfg.core.StackedFactorGraph.make(factors) 74 | # import ipdb; ipdb.set_trace() 75 | # graph.factor_stacks[0].factor.initial_state[:] = X0 76 | # import ipdb; ipdb.set_trace() 77 | initial_assignments = jaxfg.core.VariableAssignments.make_from_defaults(state_action_variables) 78 | print("Initial assignments:") 79 | print(initial_assignments) 80 | 81 | # t = 0 82 | # initial_assignments = jaxfg.core.VariableAssignments.make_from_defaults(state_action_variables) 83 | # initial_assignments.set_value(state_variables[t], X0) 84 | 85 | # Solve. Note that the first call to solve() will be much slower than subsequent calls. 86 | with jaxfg.utils.stopwatch("First solve (slower because of JIT compilation)"): 87 | solution_assignments = graph.solve(initial_assignments, solver=LevenbergMarquardtSolver()) 88 | solution_assignments.storage.block_until_ready() # type: ignore 89 | 90 | n_trajs = 100 91 | trajs = [] 92 | 93 | for i in range(n_trajs): 94 | new_key, subkey = random.split(key) 95 | X0 = jax.random.uniform(subkey, shape=X0.shape, minval=-3*jnp.pi, maxval=3*jnp.pi) 96 | states_observed = [X0] 97 | actions_taken = [] 98 | 99 | for t in range(T): 100 | # import ipdb; ipdb.set_trace() 101 | graph.factor_stacks[0].factor.initial_state[:] = states_observed[-1] 102 | with jaxfg.utils.stopwatch("Solve after initial compilation"): 103 | solution_assignments = graph.solve(initial_assignments, solver=LevenbergMarquardtSolver(max_iterations=1000)) 104 | solution_assignments.storage.block_until_ready() # type: ignore 105 | action = solution_assignments.get_value(action_variables[0]) # take the first action 106 | actions_taken.append(action) 107 | # import ipdb; ipdb.set_trace() 108 | new_key, subkey = random.split(key) 109 | del key 110 | next_state = pendulum_dynamics(states_observed[-1], action) \ 111 | + jax.random.multivariate_normal(subkey, jnp.zeros(dim_x), cov_dyn) # run in the true environment 112 | states_observed.append(next_state) 113 | del subkey 114 | key = new_key 115 | # import ipdb; ipdb.set_trace() 116 | Xs = jnp.stack(states_observed, 0) 117 | Us = jnp.stack(actions_taken + [jnp.zeros(dim_u)], 0) 118 | traj = jnp.concatenate([Xs, Us], 1) 119 | trajs.append(traj) 120 | 121 | trajs = jnp.stack(trajs, 0) 122 | Xs, Ys = transform_trajs_to_training_data(trajs, dim_x, dim_u) 123 | 124 | train_X, train_Y, test_X, test_Y = make_train_test_split(Xs, Ys, 0.2) 125 | 126 | dataset = {'train_x': train_X, 'train_y': train_Y, 'test_x': test_X, 'test_y': test_Y} 127 | 128 | jnp.save('pendulum_determinstic_dataset', dataset) 129 | 130 | 131 | 132 | # ======================== below is for visualization ============================== 133 | 134 | 135 | # PendulumMPCVis.init_plot() 136 | # PendulumMPCVis.plot_trajectory(true_xs1, true_xs2, us, T+1) # , save_path='assets/pendulum_MPC_watson20.png') 137 | -------------------------------------------------------------------------------- /utils/visualization_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import matplotlib 3 | import seaborn as sns 4 | import jax.numpy as jnp 5 | 6 | 7 | class PendulumVis: 8 | 9 | @staticmethod 10 | def init_plot(): 11 | PLOT_TIKZ = True 12 | matplotlib.rcParams["font.family"] = "Times New Roman" 13 | matplotlib.rcParams["figure.figsize"] = [25, 10] 14 | matplotlib.rcParams["legend.fontsize"] = 16 15 | matplotlib.rcParams["axes.titlesize"] = 22 16 | matplotlib.rcParams["figure.titlesize"] = 22 17 | matplotlib.rcParams["axes.labelsize"] = 22 18 | 19 | @staticmethod 20 | def plot_trajectory(x1, x2, true_x1, true_x2, u, T, save_path=None): 21 | f, a = plt.subplots(3, 2) 22 | t = range(T) 23 | 24 | # ymax = jnp.pi/8 25 | 26 | a[0, 0].set_title('Predicted/optimized trajectories') 27 | a[0, 0].set_ylabel("$\\theta$") 28 | a[1, 0].set_ylabel("$\dot{\\theta}$") 29 | a[2, 0].set_ylabel("$Nm$") 30 | a[2, 0].set_xlabel("$t$") 31 | a[0, 0].plot(t, x1, "b+-") 32 | a[1, 0].plot(t, x2, "b+-") 33 | a[2, 0].plot(t, u, "b+-") 34 | a[0, 0].hlines(y=0, xmin=0, xmax=T, linewidth=2, color='r', linestyle='dashed') 35 | a[1, 0].hlines(y=0, xmin=0, xmax=T, linewidth=2, color='r',linestyle='dashed') 36 | a[2, 0].hlines(y=0., xmin=0, xmax=T, linewidth=2, color='r',linestyle='dashed') 37 | # a[0, 0].set_ylim(-ymax, ymax) 38 | # a[1, 0].set_ylim(-ymax, ymax) 39 | # a[2, 0].set_ylim(-3, 3) 40 | 41 | a[0, 1].set_title('Trajectory under true dynamics') 42 | a[0, 1].set_ylabel("$\\theta$") 43 | a[1, 1].set_ylabel("$\dot{\\theta}$") 44 | a[2, 1].set_ylabel("$Nm$") 45 | a[2, 1].set_xlabel("$t$") 46 | a[0, 1].plot(t, true_x1, "b+-") 47 | a[1, 1].plot(t, true_x2, "b+-") 48 | a[2, 1].plot(t, u, "b+-") 49 | a[0, 1].hlines(y=0, xmin=0, xmax=T, linewidth=2, color='r', linestyle='dashed') 50 | a[1, 1].hlines(y=0, xmin=0, xmax=T, linewidth=2, color='r', linestyle='dashed') 51 | a[2, 1].hlines(y=0., xmin=0, xmax=T, linewidth=2, color='r', linestyle='dashed') 52 | 53 | # a[0, 1].set_ylim(-ymax, ymax) 54 | # a[1, 1].set_ylim(-ymax, ymax) 55 | # a[2, 1].set_ylim(-3, 3) 56 | 57 | f.suptitle('Pendulum Horizon = %d' % (T-1)) 58 | 59 | if save_path is not None: 60 | plt.savefig(save_path) 61 | else: 62 | plt.show() 63 | 64 | 65 | class PendulumMPCVis: 66 | 67 | @staticmethod 68 | def init_plot(): 69 | PLOT_TIKZ = True 70 | matplotlib.rcParams["font.family"] = "Times New Roman" 71 | matplotlib.rcParams["figure.figsize"] = [15, 10] 72 | matplotlib.rcParams["legend.fontsize"] = 16 73 | matplotlib.rcParams["axes.titlesize"] = 22 74 | matplotlib.rcParams["figure.titlesize"] = 22 75 | matplotlib.rcParams["axes.labelsize"] = 22 76 | 77 | @staticmethod 78 | def plot_trajectory(true_x1, true_x2, u, T, save_path=None): 79 | f, a = plt.subplots(3, 1) 80 | t = range(T) 81 | 82 | ymax = jnp.pi 83 | 84 | # a[0, 0].set_title('Predicted/optimized trajectories') 85 | # a[0, 0].set_ylabel("$\\theta$") 86 | # a[1, 0].set_ylabel("$\dot{\\theta}$") 87 | # a[2, 0].set_ylabel("$Nm$") 88 | # a[2, 0].set_xlabel("$t$") 89 | # a[0, 0].plot(t, x1, "b+-") 90 | # a[1, 0].plot(t, x2, "b+-") 91 | # a[2, 0].plot(t, u, "b+-") 92 | # a[0, 0].hlines(y=0, xmin=0, xmax=T, linewidth=2, color='r', linestyle='dashed') 93 | # a[1, 0].hlines(y=0, xmin=0, xmax=T, linewidth=2, color='r',linestyle='dashed') 94 | # a[2, 0].hlines(y=0., xmin=0, xmax=T, linewidth=2, color='r',linestyle='dashed') 95 | # a[0, 0].set_ylim(-ymax, ymax) 96 | # a[1, 0].set_ylim(-ymax, ymax) 97 | # a[2, 0].set_ylim(-3, 3) 98 | 99 | a[0].set_title('Trajectory under true dynamics') 100 | a[0].set_ylabel("$\\theta$") 101 | a[1].set_ylabel("$\dot{\\theta}$") 102 | a[2].set_ylabel("$Nm$") 103 | a[2].set_xlabel("$t$") 104 | a[0].plot(t, true_x1, "b+-") 105 | a[1].plot(t, true_x2, "b+-") 106 | a[2].plot(t, u, "b+-") 107 | a[0].hlines(y=0, xmin=0, xmax=T, linewidth=2, color='r', linestyle='dashed') 108 | a[1].hlines(y=0, xmin=0, xmax=T, linewidth=2, color='r', linestyle='dashed') 109 | a[2].hlines(y=0., xmin=0, xmax=T, linewidth=2, color='r', linestyle='dashed') 110 | 111 | # a[0].set_ylim(-ymax, ymax) 112 | # a[1].set_ylim(-ymax, ymax) 113 | # a[2].set_ylim(-5, 5) 114 | 115 | f.suptitle('Pendulum Horizon = %d' % (T-1)) 116 | 117 | if save_path is not None: 118 | plt.savefig(save_path) 119 | else: 120 | plt.show() 121 | 122 | 123 | class LQRVisWatson: 124 | 125 | @staticmethod 126 | def init_plot(): 127 | PLOT_TIKZ = True 128 | matplotlib.rcParams["font.family"] = "Times New Roman" 129 | matplotlib.rcParams["figure.figsize"] = [10, 10] 130 | matplotlib.rcParams["legend.fontsize"] = 16 131 | matplotlib.rcParams["axes.titlesize"] = 22 132 | matplotlib.rcParams["axes.labelsize"] = 22 133 | 134 | @staticmethod 135 | def plot_trajectory(x1, x2, u, T, save_path=None): 136 | f, a = plt.subplots(3, 1) 137 | 138 | t = range(T) 139 | a[0].set_title("State Trajectory") 140 | a[0].set_ylabel("$x_1$") 141 | a[1].set_ylabel("$x_2$") 142 | a[2].set_ylabel("$u$") 143 | a[2].set_xlabel("$t$") 144 | 145 | a[0].plot(t, x1, "k+-") 146 | a[1].plot(t, x2, "k+-") 147 | a[2].plot(t, u, "k+-") 148 | 149 | a[0].hlines(y=10, xmin=0, xmax=T, linewidth=2, color='r', linestyle='dashed') 150 | a[1].hlines(y=10, xmin=0, xmax=T, linewidth=2, color='r', linestyle='dashed') 151 | a[2].hlines(y=0., xmin=0, xmax=T, linewidth=2, color='r', linestyle='dashed') 152 | a[1].set_ylim(-3, 10.5) 153 | 154 | a[0].legend() 155 | if save_path is not None: 156 | # 'assets/LQR_watson20.png' 157 | plt.savefig(save_path) 158 | else: 159 | plt.show() 160 | 161 | 162 | class LQRVisGTSAM: 163 | 164 | @staticmethod 165 | def init_plot(): 166 | PLOT_TIKZ = True 167 | matplotlib.rcParams["font.family"] = "Times New Roman" 168 | matplotlib.rcParams["figure.figsize"] = [10, 10] 169 | matplotlib.rcParams["legend.fontsize"] = 16 170 | matplotlib.rcParams["axes.titlesize"] = 22 171 | matplotlib.rcParams["axes.labelsize"] = 22 172 | 173 | @staticmethod 174 | def plot_trajectory(x1, true_x, u, T, save_path=None): 175 | f, a = plt.subplots(2, 1) 176 | 177 | t = range(T) 178 | a[0].set_title("State Trajectory") 179 | a[0].set_ylabel("$x$") 180 | a[1].set_ylabel("$u$") 181 | a[1].set_xlabel("$t$") 182 | 183 | a[0].plot(t, x1, "k+-", label='Predicted trajectory') 184 | a[0].plot(t[::2], true_x[::2], "b-", label='True trajectory') 185 | a[1].plot(t, u, "k+-", label='Planned Action') 186 | 187 | a[0].hlines(y=0, xmin=0, xmax=T, linewidth=2, color='r', linestyle='dashed') 188 | a[1].hlines(y=0, xmin=0, xmax=T, linewidth=2, color='r', linestyle='dashed') 189 | # a[1].set_ylim(-3, 10.5) 190 | plt.legend() 191 | 192 | a[0].legend() 193 | if save_path is not None: 194 | # 'assets/LQR_watson20.png' 195 | plt.savefig(save_path) 196 | else: 197 | plt.show() 198 | --------------------------------------------------------------------------------