├── img ├── crane.png ├── pusht.png ├── walker.png ├── humanoid.png ├── particle.png ├── pendulum.png ├── summary.png ├── cart_pole.png └── double_cart_pole.png ├── .gitignore ├── environment.yml ├── gpc ├── __init__.py ├── envs │ ├── __init__.py │ ├── pendulum.py │ ├── cart_pole.py │ ├── particle.py │ ├── walker.py │ ├── double_cart_pole.py │ ├── humanoid.py │ ├── pusht.py │ ├── crane.py │ └── base.py ├── testing.py ├── augmented.py ├── sampling.py ├── policy.py ├── architectures.py └── training.py ├── .pre-commit-config.yaml ├── LICENSE ├── tests ├── test_augmented_ctrl.py ├── test_env.py ├── test_architectures.py └── test_training.py ├── pyproject.toml ├── examples ├── cart_pole.py ├── pendulum.py ├── particle.py ├── double_cart_pole.py ├── walker.py ├── pusht.py ├── crane.py └── humanoid.py └── README.md /img/crane.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vincekurtz/gpc/HEAD/img/crane.png -------------------------------------------------------------------------------- /img/pusht.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vincekurtz/gpc/HEAD/img/pusht.png -------------------------------------------------------------------------------- /img/walker.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vincekurtz/gpc/HEAD/img/walker.png -------------------------------------------------------------------------------- /img/humanoid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vincekurtz/gpc/HEAD/img/humanoid.png -------------------------------------------------------------------------------- /img/particle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vincekurtz/gpc/HEAD/img/particle.png -------------------------------------------------------------------------------- /img/pendulum.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vincekurtz/gpc/HEAD/img/pendulum.png -------------------------------------------------------------------------------- /img/summary.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vincekurtz/gpc/HEAD/img/summary.png -------------------------------------------------------------------------------- /img/cart_pole.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vincekurtz/gpc/HEAD/img/cart_pole.png -------------------------------------------------------------------------------- /img/double_cart_pole.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vincekurtz/gpc/HEAD/img/double_cart_pole.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | __pycache__ 3 | *.zip 4 | .vscode 5 | *.pt 6 | *.pkl 7 | *.egg-info 8 | MUJOCO_LOG.TXT 9 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: gpc 2 | channels: 3 | - nvidia/label/cuda-12.3.0 4 | - conda-forge 5 | 
dependencies: 6 | - cuda 7 | - cudnn 8 | - python=3.12 9 | -------------------------------------------------------------------------------- /gpc/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # Set XLA flags for better performance 4 | os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9" 5 | os.environ["XLA_FLAGS"] = "--xla_gpu_triton_gemm_any=true " 6 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v5.0.0 6 | hooks: 7 | - id: trailing-whitespace 8 | - id: end-of-file-fixer 9 | - repo: https://github.com/charliermarsh/ruff-pre-commit 10 | rev: v0.7.1 11 | hooks: 12 | - id: ruff 13 | types_or: [ python, pyi, jupyter ] 14 | args: [ --fix ] 15 | - id: ruff-format 16 | types_or: [ python, pyi, jupyter ] 17 | -------------------------------------------------------------------------------- /gpc/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import SimulatorState, TrainingEnv 2 | from .cart_pole import CartPoleEnv 3 | from .crane import CraneEnv 4 | from .double_cart_pole import DoubleCartPoleEnv 5 | from .humanoid import HumanoidEnv 6 | from .particle import ParticleEnv 7 | from .pendulum import PendulumEnv 8 | from .pusht import PushTEnv 9 | from .walker import WalkerEnv 10 | 11 | __all__ = [ 12 | "SimulatorState", 13 | "TrainingEnv", 14 | "CartPoleEnv", 15 | "CraneEnv", 16 | "DoubleCartPoleEnv", 17 | "ParticleEnv", 18 | "PendulumEnv", 19 | "PushTEnv", 20 | "WalkerEnv", 21 | "HumanoidEnv", 22 | ] 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Vince Kurtz 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /gpc/envs/pendulum.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from hydrax.tasks.pendulum import Pendulum 4 | from mujoco import mjx 5 | 6 | from gpc.envs import TrainingEnv 7 | 8 | 9 | class PendulumEnv(TrainingEnv): 10 | """Training environment for the pendulum swingup task.""" 11 | 12 | def __init__(self, episode_length: int) -> None: 13 | """Set up the pendulum training environment.""" 14 | super().__init__( 15 | task=Pendulum(planning_horizon=5), 16 | episode_length=episode_length, 17 | ) 18 | 19 | def reset(self, data: mjx.Data, rng: jax.Array) -> mjx.Data: 20 | """Reset the simulator to start a new episode.""" 21 | rng, pos_rng, vel_rng = jax.random.split(rng, 3) 22 | qpos = jax.random.uniform(pos_rng, (1,), minval=-jnp.pi, maxval=jnp.pi) 23 | qvel = jax.random.uniform(vel_rng, (1,), minval=-8.0, maxval=8.0) 24 | return data.replace(qpos=qpos, qvel=qvel) 25 | 26 | def get_obs(self, data: mjx.Data) -> jax.Array: 27 | """Observe the velocity and sin/cos of the angle.""" 28 | theta = data.qpos[0] 29 | theta_dot = data.qvel[0] 30 | return jnp.array([jnp.cos(theta), jnp.sin(theta), theta_dot]) 31 | 32 | @property 33 | def observation_size(self) -> int: 34 | """The size of the observation space (sin, cos, theta_dot).""" 35 | return 3 36 | -------------------------------------------------------------------------------- /gpc/envs/cart_pole.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from hydrax.tasks.cart_pole import CartPole 4 | from mujoco import mjx 5 | 6 | from gpc.envs import TrainingEnv 7 | 8 | 9 | class CartPoleEnv(TrainingEnv): 10 | """Training environment for the cartpole swingup task.""" 11 | 12 | def __init__(self, episode_length: int) -> None: 13 | """Set up the cartpole training environment.""" 14 | super().__init__(task=CartPole(), episode_length=episode_length) 15 | 16 | def reset(self, data: mjx.Data, rng: jax.Array) -> mjx.Data: 17 | """Reset the simulator to start a new episode.""" 18 | rng, theta_rng, pos_rng, vel_rng = jax.random.split(rng, 4) 19 | 20 | theta = jax.random.uniform(theta_rng, (), minval=-3.14, maxval=3.14) 21 | pos = jax.random.uniform(pos_rng, (), minval=-1.8, maxval=1.8) 22 | qvel = jax.random.uniform(vel_rng, (2,), minval=-2.0, maxval=2.0) 23 | qpos = jnp.array([pos, theta]) 24 | 25 | return data.replace(qpos=qpos, qvel=qvel) 26 | 27 | def get_obs(self, data: mjx.Data) -> jax.Array: 28 | """Observe the velocity and sin/cos of the angle.""" 29 | p = data.qpos[0] 30 | theta = data.qpos[1] 31 | v = data.qvel[0] 32 | theta_dot = data.qvel[1] 33 | return jnp.array([p, jnp.cos(theta), jnp.sin(theta), v, theta_dot]) 34 | 35 | @property 36 | def observation_size(self) -> int: 37 | """The size of the observation space (includes sin and cos).""" 38 | return 5 39 | -------------------------------------------------------------------------------- /gpc/envs/particle.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from hydrax.tasks.particle import Particle 4 | from mujoco import mjx 5 | 6 | from gpc.envs import TrainingEnv 7 | 8 | 9 | class ParticleEnv(TrainingEnv): 10 | """Training environment for the particle task.""" 11 | 12 | def __init__(self, episode_length: int = 100) -> None: 13 | """Set up the particle training 
environment.""" 14 | super().__init__(task=Particle(), episode_length=episode_length) 15 | 16 | def reset(self, data: mjx.Data, rng: jax.Array) -> mjx.Data: 17 | """Reset the simulator to start a new episode.""" 18 | rng, pos_rng, vel_rng, mocap_rng = jax.random.split(rng, 4) 19 | qpos = jax.random.uniform(pos_rng, (2,), minval=-0.29, maxval=0.29) 20 | qvel = jax.random.uniform(vel_rng, (2,), minval=-0.5, maxval=0.5) 21 | target = jax.random.uniform(mocap_rng, (2,), minval=-0.29, maxval=0.29) 22 | mocap_pos = data.mocap_pos.at[0, 0:2].set(target) 23 | return data.replace(qpos=qpos, qvel=qvel, mocap_pos=mocap_pos) 24 | 25 | def get_obs(self, data: mjx.Data) -> jax.Array: 26 | """Observe the position relative to the target and the velocity.""" 27 | pos = ( 28 | data.site_xpos[self.task.pointmass_id, 0:2] - data.mocap_pos[0, 0:2] 29 | ) 30 | vel = data.qvel[:] 31 | return jnp.concatenate([pos, vel]) 32 | 33 | @property 34 | def observation_size(self) -> int: 35 | """The size of the observation space.""" 36 | return 4 37 | -------------------------------------------------------------------------------- /tests/test_augmented_ctrl.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from hydrax.algs import PredictiveSampling 4 | from hydrax.tasks.particle import Particle 5 | from mujoco import mjx 6 | 7 | from gpc.augmented import PolicyAugmentedController 8 | 9 | 10 | def test_augmented() -> None: 11 | """Test the prediction-augmented controller.""" 12 | # Task and optimizer setup 13 | task = Particle() 14 | ps = PredictiveSampling(task, num_samples=32, noise_level=0.1) 15 | opt = PolicyAugmentedController(ps, num_policy_samples=32) 16 | jit_opt = jax.jit(opt.optimize) 17 | 18 | # Initialize the system state and policy parameters 19 | state = mjx.make_data(task.model) 20 | state = state.replace( 21 | mocap_pos=state.mocap_pos.at[0, 0:2].set(jnp.array([0.01, 0.01])) 22 | ) 23 | params = opt.init_params() 24 | params = params.replace( 25 | policy_samples=jnp.ones((32, task.planning_horizon, task.model.nu)) 26 | ) 27 | 28 | for _ in range(10): 29 | # Do an optimization step 30 | params, rollouts = jit_opt(state, params) 31 | 32 | # Pick the best rollout 33 | total_costs = jnp.sum(rollouts.costs, axis=1) 34 | best_idx = jnp.argmin(total_costs) 35 | best_ctrl = rollouts.controls[best_idx] 36 | 37 | assert jnp.all(best_ctrl != 0.0) 38 | assert jnp.all(params.policy_samples == 1.0) 39 | 40 | U = opt.get_action_sequence(params) 41 | assert jnp.allclose(U, params.base_params.mean) 42 | 43 | 44 | if __name__ == "__main__": 45 | test_augmented() 46 | -------------------------------------------------------------------------------- /gpc/envs/walker.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from hydrax.tasks.walker import Walker 4 | from mujoco import mjx 5 | 6 | from gpc.envs import TrainingEnv 7 | 8 | 9 | class WalkerEnv(TrainingEnv): 10 | """Training environment for the walker task.""" 11 | 12 | def __init__(self, episode_length: int) -> None: 13 | """Set up the walker training environment.""" 14 | super().__init__(task=Walker(), episode_length=episode_length) 15 | 16 | def reset(self, data: mjx.Data, rng: jax.Array) -> mjx.Data: 17 | """Reset the simulator to start a new episode.""" 18 | rng, pos_rng, vel_rng = jax.random.split(rng, 3) 19 | 20 | # Joint limits are zero for the floating base 21 | q_min = self.task.model.jnt_range[:, 0] 22 
| q_max = self.task.model.jnt_range[:, 1] 23 | q_min = q_min.at[2].set(-1.5) # orientation 24 | q_max = q_max.at[2].set(1.5) 25 | qpos = jax.random.uniform(pos_rng, (9,), minval=q_min, maxval=q_max) 26 | qvel = jax.random.uniform(vel_rng, (9,), minval=-0.1, maxval=0.1) 27 | 28 | return data.replace(qpos=qpos, qvel=qvel) 29 | 30 | def get_obs(self, data: mjx.Data) -> jax.Array: 31 | """Observe everything in the state except the horizontal position.""" 32 | pz = data.qpos[0] # base coordinates are (z, x, theta) 33 | theta = data.qpos[2] 34 | base_pos_data = jnp.array([jnp.cos(theta), jnp.sin(theta), pz]) 35 | return jnp.concatenate([base_pos_data, data.qpos[3:], data.qvel]) 36 | 37 | @property 38 | def observation_size(self) -> int: 39 | """The size of the observation space.""" 40 | return 18 41 | -------------------------------------------------------------------------------- /tests/test_env.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | from gpc.envs import ParticleEnv, SimulatorState 7 | 8 | 9 | def test_particle_env() -> None: 10 | """Test the particle environment.""" 11 | rng = jax.random.key(0) 12 | env = ParticleEnv() 13 | 14 | state = env.init_state(rng) 15 | assert state.t == 0 16 | 17 | state2 = env._reset_state(state) 18 | assert jnp.all(state.data.qpos != state2.data.qpos) 19 | assert jnp.all(state.data.qvel != state2.data.qvel) 20 | assert jnp.all(state.rng != state2.rng) 21 | 22 | obs = env._get_observation(state) 23 | assert obs.shape == (4,) 24 | 25 | jit_step = jax.jit(env.step) 26 | 27 | state = jit_step(state, jnp.zeros(2)) 28 | assert state.t == 1 29 | 30 | state = state.replace(t=100) 31 | assert env.episode_over(state) 32 | state = jit_step(state, jnp.zeros(2)) 33 | assert state.t == 0 34 | 35 | 36 | def test_render() -> None: 37 | """Test rendering the particle environment.""" 38 | rng = jax.random.key(0) 39 | env = ParticleEnv() 40 | 41 | rng, init_rng = jax.random.split(rng) 42 | state = env.init_state(init_rng) 43 | 44 | def _step(state: SimulatorState, action: jax.Array) -> Tuple: 45 | state = env.step(state, action) 46 | return state, state 47 | 48 | num_steps = 100 49 | rng, act_rng = jax.random.split(rng) 50 | actions = jax.random.normal(act_rng, (num_steps, 2)) 51 | _, states = jax.lax.scan(_step, state, actions) 52 | 53 | frames = env.render(states, fps=10) 54 | assert frames.shape == (10, 3, 240, 320) 55 | 56 | 57 | if __name__ == "__main__": 58 | test_particle_env() 59 | test_render() 60 | -------------------------------------------------------------------------------- /gpc/envs/double_cart_pole.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from hydrax.tasks.double_cart_pole import DoubleCartPole 4 | from mujoco import mjx 5 | 6 | from gpc.envs import TrainingEnv 7 | 8 | 9 | class DoubleCartPoleEnv(TrainingEnv): 10 | """Training environment for the double cart-pole swingup task.""" 11 | 12 | def __init__(self, episode_length: int) -> None: 13 | """Set up the training environment.""" 14 | super().__init__(task=DoubleCartPole(), episode_length=episode_length) 15 | 16 | def reset(self, data: mjx.Data, rng: jax.Array) -> mjx.Data: 17 | """Reset the simulator to start a new episode.""" 18 | rng, theta_rng, pos_rng, vel_rng = jax.random.split(rng, 4) 19 | 20 | thetas = jax.random.uniform(theta_rng, (2), minval=-3.14, maxval=3.14) 21 | pos = 
jax.random.uniform(pos_rng, (), minval=-2.8, maxval=2.8) 22 | qvel = jax.random.uniform(vel_rng, (3,), minval=-10.0, maxval=10.0) 23 | qpos = jnp.array([pos, thetas[0], thetas[1]]) 24 | 25 | return data.replace(qpos=qpos, qvel=qvel) 26 | 27 | def get_obs(self, data: mjx.Data) -> jax.Array: 28 | """Observe the velocity and sin/cos of the angles.""" 29 | p = data.qpos[0] 30 | theta1 = data.qpos[1] 31 | theta2 = data.qpos[2] 32 | q_obs = jnp.array( 33 | [ 34 | p, 35 | jnp.cos(theta1), 36 | jnp.sin(theta1), 37 | jnp.cos(theta2), 38 | jnp.sin(theta2), 39 | ] 40 | ) 41 | return jnp.concatenate([q_obs, data.qvel]) 42 | 43 | @property 44 | def observation_size(self) -> int: 45 | """The size of the observation space (includes sin and cos).""" 46 | return 8 47 | -------------------------------------------------------------------------------- /gpc/envs/humanoid.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from hydrax.tasks.humanoid_standup import HumanoidStandup 4 | from mujoco import mjx 5 | 6 | from gpc.envs import TrainingEnv 7 | 8 | 9 | class HumanoidEnv(TrainingEnv): 10 | """Training environment for humanoid (Unitree G1) standup.""" 11 | 12 | def __init__(self, episode_length: int) -> None: 13 | """Set up the humanoid training environment.""" 14 | super().__init__(task=HumanoidStandup(), episode_length=episode_length) 15 | 16 | def reset(self, data: mjx.Data, rng: jax.Array) -> mjx.Data: 17 | """Reset the simulator to start a new episode.""" 18 | rng, pos_rng, vel_rng, ori_rng = jax.random.split(rng, 4) 19 | 20 | # Random positions and velocities 21 | qpos = self.task.qstand + 0.1 * jax.random.normal( 22 | pos_rng, (self.task.model.nq,) 23 | ) 24 | qvel = 0.1 * jax.random.normal(vel_rng, (self.task.model.nv,)) 25 | 26 | # Random base orientation 27 | u, v, w = jax.random.uniform(ori_rng, (3,)) 28 | quat = jnp.array( 29 | [ 30 | jnp.sqrt(1 - u) * jnp.sin(2 * jnp.pi * v), 31 | jnp.sqrt(1 - u) * jnp.cos(2 * jnp.pi * v), 32 | jnp.sqrt(u) * jnp.sin(2 * jnp.pi * w), 33 | jnp.sqrt(u) * jnp.cos(2 * jnp.pi * w), 34 | ] 35 | ) 36 | qpos = qpos.at[3:7].set(quat) 37 | 38 | return data.replace(qpos=qpos, qvel=qvel) 39 | 40 | def get_obs(self, data: mjx.Data) -> jax.Array: 41 | """Observe the full state, regularized to be agnostic to orientation.""" 42 | height = self.task._get_torso_height(data)[None] 43 | orientation = self.task._get_torso_orientation(data) # upright rotation 44 | return jnp.concatenate([height, orientation, data.qpos[7:], data.qvel]) 45 | 46 | @property 47 | def observation_size(self) -> int: 48 | """The size of the observations.""" 49 | return 56 50 | -------------------------------------------------------------------------------- /gpc/envs/pusht.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from hydrax.tasks.pusht import PushT 4 | from mujoco import mjx 5 | 6 | from gpc.envs import TrainingEnv 7 | 8 | 9 | class PushTEnv(TrainingEnv): 10 | """Training environment for the pusher-T task.""" 11 | 12 | def __init__(self, episode_length: int) -> None: 13 | """Set up the pusher-T training environment.""" 14 | super().__init__(task=PushT(), episode_length=episode_length) 15 | 16 | def reset(self, data: mjx.Data, rng: jax.Array) -> mjx.Data: 17 | """Reset the simulator to start a new episode.""" 18 | rng, pos_rng, vel_rng, goal_pos_rng, goal_ori_rng = jax.random.split( 19 | rng, 5 20 | ) 21 | 22 | # Random configuration for the
pusher and the T 23 | q_min = jnp.array([-0.2, -0.2, -jnp.pi, -0.2, -0.2]) 24 | q_max = jnp.array([0.2, 0.2, jnp.pi, 0.2, 0.2]) 25 | qpos = jax.random.uniform(pos_rng, (5,), minval=q_min, maxval=q_max) 26 | 27 | # Velocities fixed at zero 28 | qvel = jax.random.uniform(vel_rng, (5,), minval=-0.0, maxval=0.0) 29 | 30 | # Goal position and orientation fixed at zero 31 | goal = jax.random.uniform(goal_pos_rng, (2,), minval=-0.0, maxval=0.0) 32 | mocap_pos = data.mocap_pos.at[0, 0:2].set(goal) 33 | theta = jax.random.uniform(goal_ori_rng, (), minval=0.0, maxval=0.0) 34 | mocap_quat = jnp.array([[jnp.cos(theta / 2), 0, 0, jnp.sin(theta / 2)]]) 35 | 36 | return data.replace( 37 | qpos=qpos, qvel=qvel, mocap_pos=mocap_pos, mocap_quat=mocap_quat 38 | ) 39 | 40 | def get_obs(self, data: mjx.Data) -> jax.Array: 41 | """Observe positions relative to the target.""" 42 | pusher_pos = data.qpos[-2:] 43 | block_pos = data.qpos[0:2] 44 | block_ori = self.task._get_orientation_err(data)[0:1] 45 | return jnp.concatenate([pusher_pos, block_pos, block_ori]) 46 | 47 | @property 48 | def observation_size(self) -> int: 49 | """The size of the observation space.""" 50 | return 5 51 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "gpc" 7 | version = "0.0.1" 8 | description = "Generative Predictive Control" 9 | readme = "README.md" 10 | license = {text="MIT"} 11 | requires-python = ">=3.12.0" 12 | classifiers = [ 13 | "Development Status :: 3 - Alpha", 14 | "Programming Language :: Python", 15 | ] 16 | dependencies = [ 17 | "hydrax@git+https://github.com/vincekurtz/hydrax@v0.0.1", 18 | "jax[cuda12]==0.4.35", 19 | "flax==0.10.0", 20 | "pytest==8.3.3", 21 | "ruff==0.7.1", 22 | "pre-commit==4.0.1", 23 | "matplotlib==3.9.2", 24 | "mujoco==3.2.4", 25 | "mujoco-mjx==3.2.4", 26 | "evosax==0.1.6", 27 | "optax==0.2.3", 28 | "tensorboard==2.18.0", 29 | "tensorboardX==2.6.2", 30 | "cloudpickle==3.1.0", 31 | "moviepy==1.0.3", 32 | "imageio==2.27", 33 | ] 34 | 35 | [tool.ruff] 36 | line-length = 80 37 | 38 | [tool.ruff.lint] 39 | pydocstyle.convention = "google" 40 | select = [ 41 | "ANN", # annotations 42 | "N", # naming conventions 43 | "D", # docstrings 44 | "B", # flake8 bugbear 45 | "E", # pycodestyle errors 46 | "F", # Pyflakes rules 47 | "I", # isort formatting 48 | "PLC", # Pylint convention warnings 49 | "PLE", # Pylint errors 50 | "PLR", # Pylint refactor recommendations 51 | "PLW", # Pylint warnings 52 | ] 53 | ignore = [ 54 | "ANN003", # missing type annotation for **kwargs 55 | "ANN101", # missing type annotation for `self` in method 56 | "ANN202", # missing return type annotation for private function 57 | "ANN204", # missing return type annotation for `__init__` 58 | "ANN401", # dynamically typed expressions (typing.Any) are disallowed 59 | "D100", # missing docstring in public module 60 | "D104", # missing docstring in public package 61 | "D203", # blank line before class docstring 62 | "D211", # no blank line before class 63 | "D212", # multi-line docstring summary at first line 64 | "D213", # multi-line docstring summary at second line 65 | "E731", # assigning to a `lambda` expression 66 | "N806", # only lowercase variables in functions 67 | "PLR0913", # too many arguments 68 | "PLR2004", # magic value used in comparison 69 | ] 70 | 71 | 
[tool.ruff.lint.isort] 72 | combine-as-imports = true 73 | known-first-party = ["hydra"] 74 | split-on-trailing-comma = false 75 | 76 | [tool.ruff.format] 77 | quote-style = "double" 78 | indent-style = "space" 79 | skip-magic-trailing-comma = false 80 | line-ending = "auto" 81 | docstring-code-format = true 82 | docstring-code-line-length = "dynamic" 83 | 84 | [tool.pytest.ini_options] 85 | testpaths = [ 86 | "tests", 87 | ] 88 | 89 | [tool.setuptools] 90 | packages = ["gpc"] 91 | -------------------------------------------------------------------------------- /examples/cart_pole.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import mujoco 4 | from flax import nnx 5 | from hydrax.algs import PredictiveSampling 6 | from hydrax.simulation.deterministic import run_interactive as run_sampling 7 | 8 | from gpc.architectures import DenoisingMLP 9 | from gpc.envs import CartPoleEnv 10 | from gpc.policy import Policy 11 | from gpc.sampling import BootstrappedPredictiveSampling 12 | from gpc.testing import test_interactive 13 | from gpc.training import train 14 | 15 | if __name__ == "__main__": 16 | # Parse command-line arguments 17 | parser = argparse.ArgumentParser( 18 | description="Balance an inverted pendulum on a cart" 19 | ) 20 | subparsers = parser.add_subparsers( 21 | dest="task", help="What to do (choose one)" 22 | ) 23 | subparsers.add_parser("train", help="Train (and save) a generative policy") 24 | subparsers.add_parser("test", help="Test a generative policy") 25 | subparsers.add_parser( 26 | "sample", help="Bootstrap sampling-based MPC with a generative policy" 27 | ) 28 | args = parser.parse_args() 29 | 30 | # Set up the environment and save file 31 | env = CartPoleEnv(episode_length=200) 32 | save_file = "/tmp/cart_pole_policy.pkl" 33 | 34 | if args.task == "train": 35 | # Train the policy and save it to a file 36 | ctrl = PredictiveSampling(env.task, num_samples=8, noise_level=0.1) 37 | net = DenoisingMLP( 38 | action_size=env.task.model.nu, 39 | observation_size=env.observation_size, 40 | horizon=env.task.planning_horizon, 41 | hidden_layers=[64, 64], 42 | rngs=nnx.Rngs(0), 43 | ) 44 | policy = train( 45 | env, 46 | ctrl, 47 | net, 48 | num_policy_samples=2, 49 | log_dir="/tmp/gpc_cart_pole", 50 | num_iters=10, 51 | num_envs=128, 52 | num_epochs=100, 53 | ) 54 | policy.save(save_file) 55 | print(f"Saved policy to {save_file}") 56 | 57 | elif args.task == "test": 58 | # Load the policy from a file and test it interactively 59 | print(f"Loading policy from {save_file}") 60 | policy = Policy.load(save_file) 61 | test_interactive(env, policy) 62 | 63 | elif args.task == "sample": 64 | # Use the policy to bootstrap sampling-based MPC 65 | policy = Policy.load(save_file) 66 | ctrl = BootstrappedPredictiveSampling( 67 | policy, 68 | observation_fn=env.get_obs, 69 | num_policy_samples=2, 70 | task=env.task, 71 | num_samples=1, 72 | noise_level=0.1, 73 | ) 74 | mj_model = env.task.mj_model 75 | mj_data = mujoco.MjData(mj_model) 76 | run_sampling(ctrl, mj_model, mj_data, frequency=50) 77 | 78 | else: 79 | parser.print_help() 80 | -------------------------------------------------------------------------------- /examples/pendulum.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import mujoco 4 | from flax import nnx 5 | from hydrax.algs import PredictiveSampling 6 | from hydrax.simulation.deterministic import run_interactive as run_sampling 7 | 8 | from 
gpc.architectures import DenoisingMLP 9 | from gpc.envs import PendulumEnv 10 | from gpc.policy import Policy 11 | from gpc.sampling import BootstrappedPredictiveSampling 12 | from gpc.testing import test_interactive 13 | from gpc.training import train 14 | 15 | if __name__ == "__main__": 16 | # Parse command line arguments 17 | parser = argparse.ArgumentParser( 18 | description="Swing up an inverted pendulum" 19 | ) 20 | subparsers = parser.add_subparsers( 21 | dest="task", help="What to do (choose one)" 22 | ) 23 | subparsers.add_parser("train", help="Train (and save) a generative policy") 24 | subparsers.add_parser("test", help="Test a generative policy") 25 | subparsers.add_parser( 26 | "sample", help="Bootstrap sampling-based MPC with a generative policy" 27 | ) 28 | args = parser.parse_args() 29 | 30 | # Set up the environment and save file 31 | env = PendulumEnv(episode_length=200) 32 | save_file = "/tmp/pendulum_policy.pkl" 33 | 34 | if args.task == "train": 35 | # Train the policy and save it to a file 36 | ctrl = PredictiveSampling(env.task, num_samples=8, noise_level=0.1) 37 | net = DenoisingMLP( 38 | action_size=env.task.model.nu, 39 | observation_size=env.observation_size, 40 | horizon=env.task.planning_horizon, 41 | hidden_layers=[32, 32], 42 | rngs=nnx.Rngs(0), 43 | ) 44 | policy = train( 45 | env, 46 | ctrl, 47 | net, 48 | num_policy_samples=2, 49 | log_dir="/tmp/gpc_pendulum", 50 | num_epochs=10, 51 | num_iters=10, 52 | num_envs=128, 53 | num_videos=2, 54 | strategy="policy", 55 | ) 56 | policy.save(save_file) 57 | print(f"Saved policy to {save_file}") 58 | 59 | elif args.task == "test": 60 | # Load the policy from a file and test it interactively 61 | print(f"Loading policy from {save_file}") 62 | policy = Policy.load(save_file) 63 | test_interactive(env, policy) 64 | 65 | elif args.task == "sample": 66 | # Use the policy to bootstrap sampling-based MPC 67 | policy = Policy.load(save_file) 68 | ctrl = BootstrappedPredictiveSampling( 69 | policy, 70 | env.get_obs, 71 | num_policy_samples=4, 72 | task=env.task, 73 | num_samples=4, 74 | noise_level=0.1, 75 | ) 76 | 77 | mj_model = env.task.mj_model 78 | mj_data = mujoco.MjData(mj_model) 79 | run_sampling(ctrl, mj_model, mj_data, frequency=50) 80 | 81 | else: 82 | parser.print_help() 83 | -------------------------------------------------------------------------------- /examples/particle.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import mujoco 4 | from flax import nnx 5 | from hydrax.algs import PredictiveSampling 6 | from hydrax.simulation.deterministic import run_interactive as run_sampling 7 | 8 | from gpc.architectures import DenoisingMLP 9 | from gpc.envs import ParticleEnv 10 | from gpc.policy import Policy 11 | from gpc.sampling import BootstrappedPredictiveSampling 12 | from gpc.testing import test_interactive 13 | from gpc.training import train 14 | 15 | if __name__ == "__main__": 16 | # Parse command line arguments 17 | parser = argparse.ArgumentParser( 18 | description="Drive a point mass to a target position" 19 | ) 20 | subparsers = parser.add_subparsers( 21 | dest="task", help="What to do (choose one)" 22 | ) 23 | subparsers.add_parser("train", help="Train (and save) a generative policy") 24 | subparsers.add_parser("test", help="Test a generative policy") 25 | subparsers.add_parser( 26 | "sample", help="Bootstrap sampling-based MPC with a generative policy" 27 | ) 28 | args = parser.parse_args() 29 | 30 | # Set up the environment and save file 
31 | env = ParticleEnv(episode_length=100) 32 | save_file = "/tmp/particle_policy.pkl" 33 | 34 | if args.task == "train": 35 | # Train the policy and save it to a file 36 | ctrl = PredictiveSampling(env.task, num_samples=8, noise_level=0.1) 37 | net = DenoisingMLP( 38 | action_size=env.task.model.nu, 39 | observation_size=env.observation_size, 40 | horizon=env.task.planning_horizon, 41 | hidden_layers=[32, 32], 42 | rngs=nnx.Rngs(0), 43 | ) 44 | policy = train( 45 | env, 46 | ctrl, 47 | net, 48 | num_policy_samples=8, 49 | log_dir="/tmp/gpc_particle", 50 | num_iters=10, 51 | num_envs=128, 52 | batch_size=128, 53 | num_epochs=100, 54 | ) 55 | policy.save(save_file) 56 | print(f"Saved policy to {save_file}") 57 | 58 | elif args.task == "test": 59 | # Load the policy from a file and test it interactively 60 | print(f"Loading policy from {save_file}") 61 | policy = Policy.load(save_file) 62 | test_interactive(env, policy) 63 | 64 | elif args.task == "sample": 65 | # Use the policy to bootstrap sampling-based MPC 66 | policy = Policy.load(save_file) 67 | ctrl = BootstrappedPredictiveSampling( 68 | policy, 69 | env.get_obs, 70 | inference_timestep=0.1, 71 | num_policy_samples=4, 72 | task=env.task, 73 | num_samples=1, 74 | noise_level=0.1, 75 | ) 76 | mj_model = env.task.mj_model 77 | mj_data = mujoco.MjData(mj_model) 78 | run_sampling(ctrl, mj_model, mj_data, frequency=50) 79 | 80 | else: 81 | parser.print_help() 82 | -------------------------------------------------------------------------------- /examples/double_cart_pole.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import mujoco 4 | from flax import nnx 5 | from hydrax.algs import PredictiveSampling 6 | from hydrax.simulation.deterministic import run_interactive as run_sampling 7 | 8 | from gpc.architectures import DenoisingMLP 9 | from gpc.envs import DoubleCartPoleEnv 10 | from gpc.policy import Policy 11 | from gpc.sampling import BootstrappedPredictiveSampling 12 | from gpc.testing import test_interactive 13 | from gpc.training import train 14 | 15 | if __name__ == "__main__": 16 | # Parse command line arguments 17 | parser = argparse.ArgumentParser( 18 | description="Balance a double inverted pendulum on a cart" 19 | ) 20 | subparsers = parser.add_subparsers( 21 | dest="task", help="What to do (choose one)" 22 | ) 23 | subparsers.add_parser("train", help="Train (and save) a generative policy") 24 | subparsers.add_parser("test", help="Test a generative policy") 25 | subparsers.add_parser( 26 | "sample", help="Bootstrap sampling-based MPC with a generative policy" 27 | ) 28 | args = parser.parse_args() 29 | 30 | # Set up the environment and save file 31 | env = DoubleCartPoleEnv(episode_length=400) 32 | save_file = "/tmp/double_cart_pole_policy.pkl" 33 | 34 | if args.task == "train": 35 | # Train the policy and save it to a file 36 | ctrl = PredictiveSampling(env.task, num_samples=16, noise_level=0.3) 37 | net = DenoisingMLP( 38 | action_size=env.task.model.nu, 39 | observation_size=env.observation_size, 40 | horizon=env.task.planning_horizon, 41 | hidden_layers=[128, 128], 42 | rngs=nnx.Rngs(0), 43 | ) 44 | policy = train( 45 | env, 46 | ctrl, 47 | net, 48 | num_policy_samples=16, 49 | log_dir="/tmp/gpc_double_cart_pole", 50 | num_iters=50, 51 | num_envs=256, 52 | num_epochs=100, 53 | checkpoint_every=5, 54 | num_videos=4, 55 | ) 56 | policy.save(save_file) 57 | print(f"Saved policy to {save_file}") 58 | 59 | elif args.task == "test": 60 | # Load the policy from a file 
and test it interactively 61 | print(f"Loading policy from {save_file}") 62 | policy = Policy.load(save_file) 63 | test_interactive(env, policy, inference_timestep=0.1) 64 | 65 | elif args.task == "sample": 66 | # Use the policy to bootstrap sampling-based MPC 67 | policy = Policy.load(save_file) 68 | ctrl = BootstrappedPredictiveSampling( 69 | policy, 70 | env.get_obs, 71 | inference_timestep=0.01, 72 | num_policy_samples=4, 73 | task=env.task, 74 | num_samples=1, 75 | noise_level=0.3, 76 | ) 77 | mj_model = env.task.mj_model 78 | mj_data = mujoco.MjData(mj_model) 79 | run_sampling(ctrl, mj_model, mj_data, frequency=50) 80 | 81 | else: 82 | parser.print_help() 83 | -------------------------------------------------------------------------------- /examples/walker.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import mujoco 4 | from flax import nnx 5 | from hydrax.algs import PredictiveSampling 6 | from hydrax.simulation.deterministic import run_interactive as run_sampling 7 | 8 | from gpc.architectures import DenoisingCNN 9 | from gpc.envs import WalkerEnv 10 | from gpc.policy import Policy 11 | from gpc.sampling import BootstrappedPredictiveSampling 12 | from gpc.testing import test_interactive 13 | from gpc.training import train 14 | 15 | if __name__ == "__main__": 16 | # Parse command line arguments 17 | parser = argparse.ArgumentParser(description="Planar biped locomotion") 18 | subparsers = parser.add_subparsers( 19 | dest="task", help="What to do (choose one)" 20 | ) 21 | subparsers.add_parser("train", help="Train (and save) a generative policy") 22 | subparsers.add_parser("test", help="Test a generative policy") 23 | subparsers.add_parser( 24 | "sample", help="Bootstrap sampling-based MPC with a generative policy" 25 | ) 26 | args = parser.parse_args() 27 | 28 | # Set up the environment and save file 29 | env = WalkerEnv(episode_length=500) 30 | save_file = "/tmp/walker_policy.pkl" 31 | 32 | if args.task == "train": 33 | # Train the policy and save it to a file 34 | ctrl = PredictiveSampling(env.task, num_samples=16, noise_level=0.3) 35 | net = DenoisingCNN( 36 | action_size=env.task.model.nu, 37 | observation_size=env.observation_size, 38 | horizon=env.task.planning_horizon, 39 | feature_dims=[64, 64], 40 | timestep_embedding_dim=16, 41 | rngs=nnx.Rngs(0), 42 | ) 43 | policy = train( 44 | env, 45 | ctrl, 46 | net, 47 | log_dir="/tmp/gpc_walker", 48 | num_policy_samples=16, 49 | num_iters=20, 50 | num_envs=128, 51 | num_epochs=10, 52 | ) 53 | policy.save(save_file) 54 | print(f"Saved policy to {save_file}") 55 | 56 | elif args.task == "test": 57 | # Load the policy from a file and test it interactively 58 | print(f"Loading policy from {save_file}") 59 | policy = Policy.load(save_file) 60 | test_interactive( 61 | env, policy, inference_timestep=0.01, warm_start_level=1.0 62 | ) 63 | 64 | elif args.task == "sample": 65 | # Use the policy to bootstrap sampling-based MPC 66 | policy = Policy.load(save_file) 67 | ctrl = BootstrappedPredictiveSampling( 68 | policy, 69 | env.get_obs, 70 | warm_start_level=0.5, 71 | inference_timestep=0.1, 72 | num_policy_samples=128, 73 | task=env.task, 74 | num_samples=1, 75 | noise_level=0.3, 76 | ) 77 | 78 | mj_model = env.task.mj_model 79 | mj_data = mujoco.MjData(mj_model) 80 | run_sampling(ctrl, mj_model, mj_data, frequency=50, fixed_camera_id=0) 81 | 82 | else: 83 | parser.print_help() 84 | -------------------------------------------------------------------------------- 
/examples/pusht.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import mujoco 4 | from flax import nnx 5 | from hydrax.algs import PredictiveSampling 6 | from hydrax.simulation.deterministic import run_interactive as run_sampling 7 | 8 | from gpc.architectures import DenoisingCNN 9 | from gpc.envs import PushTEnv 10 | from gpc.policy import Policy 11 | from gpc.sampling import BootstrappedPredictiveSampling 12 | from gpc.testing import test_interactive 13 | from gpc.training import train 14 | 15 | if __name__ == "__main__": 16 | # Parse command line arguments 17 | parser = argparse.ArgumentParser( 18 | description="Push a T-shaped block on a table" 19 | ) 20 | subparsers = parser.add_subparsers( 21 | dest="task", help="What to do (choose one)" 22 | ) 23 | subparsers.add_parser("train", help="Train (and save) a generative policy") 24 | subparsers.add_parser("test", help="Test a generative policy") 25 | subparsers.add_parser( 26 | "sample", help="Bootstrap sampling-based MPC with a generative policy" 27 | ) 28 | args = parser.parse_args() 29 | 30 | # Set up the environment and save file 31 | env = PushTEnv(episode_length=400) 32 | save_file = "/tmp/pusht_policy.pkl" 33 | 34 | if args.task == "train": 35 | # Train the policy and save it to a file 36 | ctrl = PredictiveSampling(env.task, num_samples=128, noise_level=0.4) 37 | net = DenoisingCNN( 38 | action_size=env.task.model.nu, 39 | observation_size=env.observation_size, 40 | horizon=env.task.planning_horizon, 41 | feature_dims=[32, 32, 32], 42 | timestep_embedding_dim=8, 43 | rngs=nnx.Rngs(0), 44 | ) 45 | policy = train( 46 | env, 47 | ctrl, 48 | net, 49 | num_policy_samples=32, 50 | log_dir="/tmp/gpc_pusht", 51 | num_iters=20, 52 | num_envs=128, 53 | num_epochs=10, 54 | checkpoint_every=5, 55 | ) 56 | policy.save(save_file) 57 | print(f"Saved policy to {save_file}") 58 | 59 | elif args.task == "test": 60 | # Load the policy from a file and test it interactively 61 | print(f"Loading policy from {save_file}") 62 | policy = Policy.load(save_file) 63 | mj_data = mujoco.MjData(env.task.mj_model) 64 | mj_data.qpos[:] = [0.1, 0.1, 2.0, 0.0, 0.0] # set the initial state 65 | test_interactive(env, policy, mj_data) 66 | 67 | elif args.task == "sample": 68 | # Use the policy to bootstrap sampling-based MPC 69 | policy = Policy.load(save_file) 70 | ctrl = BootstrappedPredictiveSampling( 71 | policy, 72 | env.get_obs, 73 | warm_start_level=0.5, 74 | num_policy_samples=32, 75 | task=env.task, 76 | num_samples=1, 77 | noise_level=0.1, 78 | ) 79 | mj_model = env.task.mj_model 80 | mj_data = mujoco.MjData(mj_model) 81 | run_sampling(ctrl, mj_model, mj_data, frequency=50) 82 | 83 | else: 84 | parser.print_help() 85 | -------------------------------------------------------------------------------- /gpc/envs/crane.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from hydrax.tasks.crane import Crane 4 | from mujoco import mjx 5 | 6 | from gpc.envs import TrainingEnv 7 | 8 | 9 | class CraneEnv(TrainingEnv): 10 | """Training environment for the luffing crane end-effector tracking task.""" 11 | 12 | def __init__(self, episode_length: int = 100) -> None: 13 | """Set up the crane training environment.""" 14 | super().__init__(task=Crane(), episode_length=episode_length) 15 | 16 | def reset(self, data: mjx.Data, rng: jax.Array) -> mjx.Data: 17 | """Reset the simulator to start a new episode.""" 18 | rng, pos_rng,
vel_rng, target_rng = jax.random.split(rng, 4) 19 | 20 | # Crane state 21 | # TODO: figure out a more principled way to initialize these. 22 | # Right now the crane swings wildly at initialization, which prevents 23 | # gathering a lot of data near the target (which is where the system is 24 | # mostly at run time). 25 | q_lim = jnp.array( 26 | [ 27 | [-1.0, 1.0], # slew 28 | [0.0, 1.0], # luff 29 | [-1.0, 1.0], # payload x-pos 30 | [1.0, 2.2], # payload y-pos 31 | [0.3, 1.0], # payload z-pos 32 | [1.0, 1.0], # payload orientation (fixed upright) 33 | [0.0, 0.0], 34 | [0.0, 0.0], 35 | [0.0, 0.0], 36 | ] 37 | ) 38 | qpos = self.task.model.qpos0 + jax.random.uniform( 39 | pos_rng, 40 | (self.task.model.nq,), 41 | minval=q_lim[:, 0], 42 | maxval=q_lim[:, 1], 43 | ) 44 | qvel = jax.random.uniform( 45 | vel_rng, (self.task.model.nv,), minval=-0.1, maxval=0.1 46 | ) 47 | 48 | # Target position 49 | # TODO: figure out a better set of potential target positions 50 | pos_min = jnp.array([-1.5, 1.0, 0.0]) 51 | pos_max = jnp.array([1.5, 3.0, 1.5]) 52 | target_pos = jax.random.uniform( 53 | target_rng, (3,), minval=pos_min, maxval=pos_max 54 | ) 55 | mocap_pos = data.mocap_pos.at[0].set(target_pos) 56 | 57 | # Target orientation - this is unused but must be set so vectorization 58 | # (which is determined by the size of rng) works properly. 59 | target_quat = jnp.array([1.0, 0.0, 0.0, 0.0]) + jax.random.uniform( 60 | target_rng, (4,), minval=-0.0, maxval=0.0 61 | ) 62 | mocap_quat = data.mocap_quat.at[0].set(target_quat) 63 | 64 | return data.replace( 65 | qpos=qpos, qvel=qvel, mocap_pos=mocap_pos, mocap_quat=mocap_quat 66 | ) 67 | 68 | def get_obs(self, data: mjx.Data) -> jax.Array: 69 | """Observe the full crane state, plus end-effector pos/vel.""" 70 | ee_pos = self.task._get_payload_position(data) 71 | ee_vel = self.task._get_payload_velocity(data) 72 | return jnp.concatenate([ee_pos, ee_vel, data.qpos, data.qvel]) 73 | 74 | @property 75 | def observation_size(self) -> int: 76 | """The size of the observation space.""" 77 | return 23 78 | -------------------------------------------------------------------------------- /examples/crane.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import mujoco 4 | from flax import nnx 5 | from hydrax.algs import PredictiveSampling 6 | from hydrax.simulation.deterministic import run_interactive as run_sampling 7 | 8 | from gpc.architectures import DenoisingCNN 9 | from gpc.envs import CraneEnv 10 | from gpc.policy import Policy 11 | from gpc.sampling import BootstrappedPredictiveSampling 12 | from gpc.testing import test_interactive 13 | from gpc.training import train 14 | 15 | if __name__ == "__main__": 16 | # Parse command line arguments 17 | parser = argparse.ArgumentParser( 18 | description="Swing the payload of a luffing crane to a target position." 
19 | ) 20 | subparsers = parser.add_subparsers( 21 | dest="task", help="What to do (choose one)" 22 | ) 23 | subparsers.add_parser("train", help="Train (and save) a generative policy") 24 | subparsers.add_parser("test", help="Test a generative policy") 25 | subparsers.add_parser( 26 | "sample", help="Bootstrap sampling-based MPC with a generative policy" 27 | ) 28 | args = parser.parse_args() 29 | 30 | # Set up the environment and save file 31 | env = CraneEnv(episode_length=500) 32 | save_file = "/tmp/crane_policy.pkl" 33 | 34 | if args.task == "train": 35 | # Train the policy and save it to a file 36 | ctrl = PredictiveSampling(env.task, num_samples=4, noise_level=0.05) 37 | net = DenoisingCNN( 38 | action_size=env.task.model.nu, 39 | observation_size=env.observation_size, 40 | horizon=env.task.planning_horizon, 41 | feature_dims=[64, 64], 42 | timestep_embedding_dim=16, 43 | rngs=nnx.Rngs(0), 44 | ) 45 | policy = train( 46 | env, 47 | ctrl, 48 | net, 49 | num_policy_samples=2, 50 | log_dir="/tmp/gpc_crane", 51 | num_iters=10, 52 | num_envs=512, 53 | num_videos=4, 54 | batch_size=1024, 55 | num_epochs=20, 56 | checkpoint_every=5, 57 | strategy="policy", 58 | ) 59 | policy.save(save_file) 60 | print(f"Saved policy to {save_file}") 61 | 62 | elif args.task == "test": 63 | # Load the policy from a file and test it interactively 64 | print(f"Loading policy from {save_file}") 65 | policy = Policy.load(save_file) 66 | test_interactive(env, policy) 67 | 68 | elif args.task == "sample": 69 | # Use the policy to bootstrap sampling-based MPC 70 | policy = Policy.load(save_file) 71 | ctrl = BootstrappedPredictiveSampling( 72 | policy, 73 | env.get_obs, 74 | inference_timestep=0.1, 75 | num_policy_samples=4, # samples from the flow-matching policy 76 | task=env.task, 77 | num_samples=4, # samples from a gaussian around the previous mean 78 | noise_level=0.05, 79 | ) 80 | mj_model = env.task.mj_model 81 | mj_data = mujoco.MjData(mj_model) 82 | run_sampling(ctrl, mj_model, mj_data, frequency=30) 83 | 84 | else: 85 | parser.print_help() 86 | -------------------------------------------------------------------------------- /gpc/testing.py: -------------------------------------------------------------------------------- 1 | import time 2 | from functools import partial 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | import mujoco 7 | import mujoco.viewer 8 | from mujoco import mjx 9 | 10 | from gpc.envs import TrainingEnv 11 | from gpc.policy import Policy 12 | 13 | 14 | def test_interactive( 15 | env: TrainingEnv, 16 | policy: Policy, 17 | mj_data: mujoco.MjData = None, 18 | inference_timestep: float = 0.1, 19 | warm_start_level: float = 1.0, 20 | ) -> None: 21 | """Test a GPC policy with an interactive simulation. 22 | 23 | Args: 24 | env: The environment, which defines the system to simulate. 25 | policy: The GPC policy to test. 26 | mj_data: The initial state for the simulation. 27 | inference_timestep: The timestep dt to use for flow matching inference. 28 | warm_start_level: The warm start level to use for the policy. 
29 | """ 30 | rng = jax.random.key(0) 31 | task = env.task 32 | 33 | # Set up the policy 34 | policy = policy.replace(dt=inference_timestep) 35 | policy.model.eval() 36 | jit_policy = jax.jit( 37 | partial(policy.apply, warm_start_level=warm_start_level) 38 | ) 39 | 40 | # Set up the mujoco simultion 41 | mj_model = task.mj_model 42 | if mj_data is None: 43 | mj_data = mujoco.MjData(mj_model) 44 | 45 | # Initialize the action sequence 46 | actions = jnp.zeros((task.planning_horizon, task.model.nu)) 47 | 48 | # Set up an observation function 49 | mjx_data = mjx.make_data(task.model) 50 | 51 | @jax.jit 52 | def get_obs(mjx_data: mjx.Data) -> jax.Array: 53 | """Get an observation from the mujoco data.""" 54 | mjx_data = mjx.forward(task.model, mjx_data) # update sites & sensors 55 | return env.get_obs(mjx_data) 56 | 57 | # Run the simulation 58 | with mujoco.viewer.launch_passive(mj_model, mj_data) as viewer: 59 | while viewer.is_running(): 60 | st = time.time() 61 | 62 | # Get an observation 63 | mjx_data = mjx_data.replace( 64 | qpos=jnp.array(mj_data.qpos), 65 | qvel=jnp.array(mj_data.qvel), 66 | mocap_pos=jnp.array(mj_data.mocap_pos), 67 | mocap_quat=jnp.array(mj_data.mocap_quat), 68 | ) 69 | obs = get_obs(mjx_data) 70 | 71 | # Update the action sequence 72 | inference_start = time.time() 73 | rng, policy_rng = jax.random.split(rng) 74 | actions = jit_policy(actions, obs, policy_rng) 75 | mj_data.ctrl[:] = actions[0] 76 | 77 | inference_time = time.time() - inference_start 78 | obs_time = inference_start - st 79 | print( 80 | f" Observation time: {obs_time:.5f}s " 81 | f" Inference time: {inference_time:.5f}s", 82 | end="\r", 83 | ) 84 | 85 | mujoco.mj_step(mj_model, mj_data) 86 | viewer.sync() 87 | 88 | elapsed = time.time() - st 89 | if elapsed < mj_model.opt.timestep: 90 | time.sleep(mj_model.opt.timestep - elapsed) 91 | 92 | # Save what was last in the print buffer 93 | print("") 94 | -------------------------------------------------------------------------------- /examples/humanoid.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import mujoco 4 | from flax import nnx 5 | from hydrax.algs import MPPI 6 | from hydrax.simulation.deterministic import run_interactive as run_sampling 7 | 8 | from gpc.architectures import DenoisingCNN 9 | from gpc.envs import HumanoidEnv 10 | from gpc.policy import Policy 11 | from gpc.sampling import BootstrappedPredictiveSampling 12 | from gpc.testing import test_interactive 13 | from gpc.training import train 14 | 15 | if __name__ == "__main__": 16 | # Parse command line arguments 17 | parser = argparse.ArgumentParser( 18 | description="Humanoid standup from arbitrary initial positions" 19 | ) 20 | subparsers = parser.add_subparsers( 21 | dest="task", help="What to do (choose one)" 22 | ) 23 | subparsers.add_parser("train", help="Train (and save) a generative policy") 24 | subparsers.add_parser("test", help="Test a generative policy") 25 | subparsers.add_parser( 26 | "sample", help="Bootstrap sampling-based MPC with a generative policy" 27 | ) 28 | args = parser.parse_args() 29 | 30 | # Set up the environment and save file 31 | env = HumanoidEnv(episode_length=400) 32 | save_file = "/tmp/humanoid_policy.pkl" 33 | 34 | if args.task == "train": 35 | # Train the policy and save it to a file 36 | ctrl = MPPI( 37 | env.task, 38 | num_samples=32, 39 | noise_level=1.0, 40 | temperature=0.1, 41 | num_randomizations=2, 42 | ) 43 | net = DenoisingCNN( 44 | action_size=env.task.model.nu, 45 | 
observation_size=env.observation_size, 46 | horizon=env.task.planning_horizon, 47 | feature_dims=(128,) * 3, 48 | timestep_embedding_dim=64, 49 | rngs=nnx.Rngs(0), 50 | ) 51 | policy = train( 52 | env, 53 | ctrl, 54 | net, 55 | num_policy_samples=32, 56 | log_dir="/tmp/gpc_humanoid", 57 | num_epochs=10, 58 | num_iters=50, 59 | num_envs=128, 60 | num_videos=2, 61 | checkpoint_every=1, 62 | strategy="best", 63 | ) 64 | policy.save(save_file) 65 | print(f"Saved policy to {save_file}") 66 | 67 | elif args.task == "test": 68 | # Load the policy from a file and test it interactively 69 | print(f"Loading policy from {save_file}") 70 | policy = Policy.load(save_file) 71 | test_interactive(env, policy) 72 | 73 | elif args.task == "sample": 74 | # Use the policy to bootstrap sampling-based MPC 75 | policy = Policy.load(save_file) 76 | ctrl = BootstrappedPredictiveSampling( 77 | policy, 78 | env.get_obs, 79 | num_policy_samples=128, 80 | warm_start_level=0.9, 81 | task=env.task, 82 | num_samples=128, 83 | noise_level=0.5, 84 | num_randomizations=2, 85 | ) 86 | 87 | mj_model = env.task.mj_model 88 | mj_model.opt.timestep = 0.01 89 | 90 | mj_data = mujoco.MjData(mj_model) 91 | mj_data.qpos[3:7] = [-0.7, 0.0, 0.7, 0.0] 92 | 93 | run_sampling(ctrl, mj_model, mj_data, frequency=50, show_traces=False) 94 | 95 | else: 96 | parser.print_help() 97 | -------------------------------------------------------------------------------- /gpc/augmented.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Tuple 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | from flax.struct import dataclass 6 | from hydrax.alg_base import SamplingBasedController, Trajectory 7 | 8 | 9 | @dataclass 10 | class PACParams: 11 | """Parameters for the policy-augmented controller. 12 | 13 | Attributes: 14 | base_params: The parameters for the base controller. 15 | policy_samples: Control sequences sampled from the policy. 16 | rng: Random number generator key for domain randomization. 17 | """ 18 | 19 | base_params: Any 20 | policy_samples: jax.Array 21 | rng: jax.Array 22 | 23 | 24 | class PolicyAugmentedController(SamplingBasedController): 25 | """An SPC generalization where samples are augmented by a learned policy.""" 26 | 27 | def __init__( 28 | self, 29 | base_ctrl: SamplingBasedController, 30 | num_policy_samples: int, 31 | ) -> None: 32 | """Initialize the policy-augmented controller. 33 | 34 | Args: 35 | base_ctrl: The base controller to augment. 36 | num_policy_samples: The number of samples to draw from the policy. 
37 | """ 38 | self.base_ctrl = base_ctrl 39 | self.num_policy_samples = num_policy_samples 40 | super().__init__( 41 | base_ctrl.task, 42 | base_ctrl.num_randomizations, 43 | base_ctrl.risk_strategy, 44 | seed=0, 45 | ) 46 | 47 | def init_params(self) -> PACParams: 48 | """Initialize the controller parameters.""" 49 | base_params = self.base_ctrl.init_params() 50 | base_rng, our_rng = jax.random.split(base_params.rng) 51 | base_params = base_params.replace(rng=base_rng) 52 | policy_samples = jnp.zeros( 53 | ( 54 | self.num_policy_samples, 55 | self.task.planning_horizon, 56 | self.task.model.nu, 57 | ) 58 | ) 59 | return PACParams( 60 | base_params=base_params, 61 | policy_samples=policy_samples, 62 | rng=our_rng, 63 | ) 64 | 65 | def sample_controls(self, params: PACParams) -> Tuple[jax.Array, PACParams]: 66 | """Sample control sequences from the base controller and the policy.""" 67 | # Samples from the base controller 68 | base_samples, base_params = self.base_ctrl.sample_controls( 69 | params.base_params 70 | ) 71 | 72 | # Include samples from the policy. Assumes that thes have already been 73 | # generated and stored in params.policy_samples. 74 | samples = jnp.append(base_samples, params.policy_samples, axis=0) 75 | 76 | return samples, params.replace(base_params=base_params) 77 | 78 | def update_params( 79 | self, params: PACParams, rollouts: Trajectory 80 | ) -> PACParams: 81 | """Update the policy parameters according to the base controller.""" 82 | base_params = self.base_ctrl.update_params(params.base_params, rollouts) 83 | return params.replace(base_params=base_params) 84 | 85 | def get_action(self, params: PACParams, t: float) -> jax.Array: 86 | """Get the action from the base controller at a given time.""" 87 | return self.base_ctrl.get_action(params.base_params, t) 88 | 89 | def get_action_sequence(self, params: PACParams) -> jax.Array: 90 | """Get the action sequence from the controller.""" 91 | timesteps = jnp.arange(self.task.planning_horizon) * self.task.dt 92 | return jax.vmap(self.get_action, in_axes=(None, 0))(params, timesteps) 93 | -------------------------------------------------------------------------------- /gpc/sampling.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Tuple 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | from hydrax.alg_base import Trajectory 6 | from hydrax.algs.predictive_sampling import PredictiveSampling 7 | from mujoco import mjx 8 | 9 | from gpc.policy import Policy 10 | 11 | 12 | class BootstrappedPredictiveSampling(PredictiveSampling): 13 | """Perform predictive sampling, but add samples from a generative policy.""" 14 | 15 | def __init__( 16 | self, 17 | policy: Policy, 18 | observation_fn: Callable[[mjx.Data], jax.Array], 19 | num_policy_samples: int, 20 | warm_start_level: float = 0.0, 21 | inference_timestep: float = 0.1, 22 | **kwargs, 23 | ): 24 | """Initialize the controller. 25 | 26 | Args: 27 | policy: The generative policy to sample from. 28 | observation_fn: A function that produces an observation vector. 29 | num_policy_samples: The number of samples to take from the policy. 30 | warm_start_level: The warm start level in [0, 1] to use for the 31 | policy samples. 0.0 generates samples from scratch, while 1.0 32 | seed all samples from the previous action sequence. 33 | inference_timestep: The timestep dt for flow matching inference. 34 | **kwargs: Constructor arguments for PredictiveSampling. 
35 | """ 36 | self.observation_fn = observation_fn 37 | self.policy = policy.replace(dt=inference_timestep) 38 | self.policy.model.eval() # Don't update batch statistics 39 | self.warm_start_level = jnp.clip(warm_start_level, 0.0, 1.0) 40 | self.num_policy_samples = num_policy_samples 41 | 42 | super().__init__(**kwargs) 43 | 44 | def optimize(self, state: mjx.Data, params: Any) -> Tuple[Any, Trajectory]: 45 | """Perform an optimization step to update the policy parameters. 46 | 47 | In addition to sampling random control sequences, also sample control 48 | sequences from the generative policy. 49 | 50 | Args: 51 | state: The initial state x₀. 52 | params: The current policy parameters, U ~ π(params). 53 | 54 | Returns: 55 | Updated policy parameters 56 | Rollouts used to update the parameters 57 | """ 58 | rng, policy_rng, dr_rng = jax.random.split(params.rng, 3) 59 | 60 | # Sample random control sequences 61 | controls, params = self.sample_controls(params) 62 | controls = jnp.clip(controls, self.task.u_min, self.task.u_max) 63 | 64 | # Update sensor readings and get an observation 65 | state = mjx.forward(self.task.model, state) 66 | y = self.observation_fn(state) 67 | 68 | # Sample from the generative policy, which is conditioned on the latest 69 | # observation. 70 | policy_rngs = jax.random.split(policy_rng, self.num_policy_samples) 71 | policy_controls = jax.vmap( 72 | self.policy.apply, in_axes=(None, None, 0, None) 73 | )( 74 | params.mean, 75 | y, 76 | policy_rngs, 77 | self.warm_start_level, 78 | ) 79 | 80 | # Combine the random and policy samples 81 | controls = jnp.concatenate([controls, policy_controls], axis=0) 82 | 83 | # Roll out the control sequences, applying domain randomizations and 84 | # combining costs using self.risk_strategy. 85 | rollouts = self.rollout_with_randomizations(state, controls, dr_rng) 86 | 87 | # Update the policy parameters based on the combined costs 88 | params = params.replace(rng=rng) 89 | params = self.update_params(params, rollouts) 90 | return params, rollouts 91 | -------------------------------------------------------------------------------- /gpc/policy.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Tuple, Union 3 | 4 | import cloudpickle 5 | import jax 6 | import jax.numpy as jnp 7 | from flax import nnx 8 | from flax.struct import dataclass 9 | 10 | 11 | @dataclass 12 | class Policy: 13 | """A pickle-able Generative Predictive Control policy. 14 | 15 | Generates action sequences using flow matching, conditioned on the latest 16 | observation, e.g., samples U = [u_0, u_1, ...] ~ p(U | y). 17 | 18 | Attributes: 19 | model: The flow matching network that generates the action sequence. 20 | normalizer: Observation normalization module. 21 | u_min: The minimum action values. 22 | u_max: The maximum action values. 23 | dt: The integration step size for flow matching. 24 | """ 25 | 26 | model: nnx.Module 27 | normalizer: nnx.BatchNorm 28 | u_min: jax.Array 29 | u_max: jax.Array 30 | dt: float = 0.1 31 | 32 | def save(self, path: Union[Path, str]) -> None: 33 | """Save the policy to a file. 34 | 35 | Args: 36 | path: The path to save the policy to. 37 | """ 38 | with open(path, "wb") as f: 39 | cloudpickle.dump(self, f) 40 | 41 | @staticmethod 42 | def load(path: Union[Path, str]) -> "Policy": 43 | """Load a policy from a file. 44 | 45 | Args: 46 | path: The path to load the policy from. 
47 | 48 | Returns: 49 | The loaded policy instance 50 | """ 51 | with open(path, "rb") as f: 52 | policy = cloudpickle.load(f) 53 | return policy 54 | 55 | def apply( 56 | self, 57 | prev: jax.Array, 58 | y: jax.Array, 59 | rng: jax.Array, 60 | warm_start_level: float = 0.0, 61 | ) -> jax.Array: 62 | """Generate an action sequence conditioned on the observation. 63 | 64 | Args: 65 | prev: The previous action sequence. 66 | y: The current observation. 67 | rng: The random number generator key. 68 | warm_start_level: The degree of warm-starting to use, in [0, 1]. 69 | 70 | A warm-start level of 0.0 means the action sequence is generated from 71 | scratch, with the seed for flow matching drawn from a random normal 72 | distribution. A warm-start level of 1.0 means the seed is the previous 73 | action sequence. Values in between interpolate between these two, with 74 | larger values giving smoother but less exploratory action sequences. 75 | 76 | Returns: 77 | The updated action sequence 78 | """ 79 | # Normalize the observation, but don't update the stored statistics 80 | y = self.normalizer(y, use_running_average=True) 81 | 82 | # Set the initial sample 83 | warm_start_level = jnp.clip(warm_start_level, 0.0, 1.0) 84 | noise = jax.random.normal(rng, prev.shape) 85 | U = warm_start_level * prev + (1 - warm_start_level) * noise 86 | 87 | def _step(args: Tuple[jax.Array, float]) -> Tuple[jax.Array, float]: 88 | """Flow the sample U along the learned vector field.""" 89 | U, t = args 90 | U += self.dt * self.model(U, y, t) 91 | U = jax.numpy.clip(U, -1, 1) 92 | return U, t + self.dt 93 | 94 | # While t < 1, U += dt * model(U, y, t) 95 | U, t = jax.lax.while_loop( 96 | lambda args: jnp.all(args[1] < 1.0), 97 | _step, 98 | (U, jnp.zeros(1)), 99 | ) 100 | 101 | # Rescale actions from [-1, 1] to [u_min, u_max] 102 | mean = (self.u_max + self.u_min) / 2 103 | scale = (self.u_max - self.u_min) / 2 104 | U = U * scale + mean 105 | 106 | return U 107 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Generative Predictive Control 2 | 3 | This repository contains code for the paper ["Generative Predictive Control: Flow 4 | Matching Policies for Dynamic and Difficult-to-Demonstrate Tasks"](https://arxiv.org/abs/2502.13406) 5 | by Vince Kurtz and Joel Burdick. [Video summary](https://youtu.be/mjL7CF877Ow). 6 | 7 | This includes code for training and testing flow-matching policies on each of 8 | the robot systems shown below: 9 | 10 | [](examples/cart_pole.py) 11 | [](examples/double_cart_pole.py) 12 | [](examples/pusht.py) 13 | [](examples/walker.py) 14 | [](examples/crane.py) 15 | [](examples/humanoid.py) 16 | 17 | Generative Predictive Control (GPC) is a supervised learning framework for 18 | training flow-matching policies on tasks that are difficult to demonstrate but 19 | easy to simulate. GPC alternates between generating training data with 20 | [sampling-based predictive control](https://github.com/vincekurtz/hydrax), 21 | fitting a generative model to the data, and using the generative model to 22 | improve the sampling distribution. 23 | 24 |
25 | ![GPC summary](img/summary.png)
26 |
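Under the hood, `gpc.training.train` repeats two phases: roll out policy-augmented sampling-based MPC to collect observations and optimized action sequences, then regress the flow-matching network onto that data. The snippet below is a minimal, illustrative sketch of one such iteration on the particle task (the real training loop additionally batches episodes across parallel environments, normalizes observations, rescales actions, renders videos, and logs to TensorBoard; the hyperparameters here are arbitrary):

```python
import jax
import optax
from flax import nnx
from hydrax.algs import PredictiveSampling

from gpc.architectures import DenoisingMLP
from gpc.augmented import PolicyAugmentedController
from gpc.envs import ParticleEnv
from gpc.policy import Policy
from gpc.training import fit_policy, simulate_episode

env = ParticleEnv()

# Sampling-based MPC, augmented with extra samples drawn from the policy
ctrl = PolicyAugmentedController(
    PredictiveSampling(env.task, num_samples=8, noise_level=0.1),
    num_policy_samples=8,
)

# Flow-matching network and observation normalizer, wrapped as a Policy
net = DenoisingMLP(
    action_size=env.task.model.nu,
    observation_size=env.observation_size,
    horizon=env.task.planning_horizon,
    hidden_layers=[32, 32],
    rngs=nnx.Rngs(0),
)
normalizer = nnx.BatchNorm(
    env.observation_size, use_bias=False, use_scale=False, rngs=nnx.Rngs(0)
)
policy = Policy(net, normalizer, env.task.u_min, env.task.u_max)
optimizer = nnx.Optimizer(net, optax.adam(1e-3))

rng = jax.random.key(0)
for _ in range(3):
    rng, sim_rng, fit_rng = jax.random.split(rng, 3)

    # 1) Generate data: simulate an episode, recording observations y and the
    #    best action sequences U found by the policy-augmented controller
    y, U, U_guess, J_spc, J_policy, _ = simulate_episode(
        env, ctrl, policy, 0.0, sim_rng
    )

    # 2) Fit the generative model: flow-matching regression onto (y, U) pairs,
    #    weighted by how close the old guess U_guess is to the new target
    loss = fit_policy(y, U, U_guess, net, optimizer, 32, 10, fit_rng)
```

In practice you would simply call `gpc.training.train`, which wraps this loop; see the examples below.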
27 | 
28 | ## Install (Conda)
29 | 
30 | Clone and create the conda env (first time only):
31 | ```bash
32 | git clone https://github.com/vincekurtz/gpc.git
33 | cd gpc
34 | conda env create -f environment.yml
35 | ```
36 | 
37 | Enter the conda env:
38 | 
39 | ```bash
40 | conda activate gpc
41 | ```
42 | 
43 | Install the package and dependencies:
44 | 
45 | ```bash
46 | pip install -e .
47 | ```
48 | 
49 | ## Examples
50 | 
51 | Various examples can be found in the [`examples`](examples) directory. For
52 | example, to train a cart-pole swingup policy using GPC, run:
53 | 
54 | ```bash
55 | python examples/cart_pole.py train
56 | ```
57 | 
58 | This will train a flow-matching policy and save it to
59 | `/tmp/cart_pole_policy.pkl`. To run an interactive simulation with the trained
60 | policy, run:
61 | 
62 | ```bash
63 | python examples/cart_pole.py test
64 | ```
65 | 
66 | To see other command-line options, run:
67 | 
68 | ```bash
69 | python examples/cart_pole.py --help
70 | ```
71 | 
72 | ## Using a Different Robot Model
73 | 
74 | To try GPC on your own robot or task, you will need to:
75 | 
76 | 1. Define a [Hydrax
77 |    task](https://github.com/vincekurtz/hydrax?tab=readme-ov-file#design-your-own-task)
78 |    that encodes the cost function and system dynamics.
79 | 2. Define a training environment that inherits from
80 |    [`gpc.envs.base.TrainingEnv`](gpc/envs/base.py). This must implement the
81 |    `reset` and `get_obs` methods and the `observation_size` property. For example:
82 | 
83 |    ```python
84 |    class MyCustomEnv(TrainingEnv):
85 |        def __init__(self):
86 |            super().__init__(task=MyCustomHydraxTask(), episode_length=100)
87 | 
88 |        def reset(self, data: mjx.Data, rng: jax.Array) -> mjx.Data:
89 |            """Reset the simulator to start a new episode."""
90 |            ...
91 |            return new_data
92 | 
93 |        def get_obs(self, data: mjx.Data) -> jax.Array:
94 |            """Get the observation from the simulator."""
95 |            ...
96 |            return jnp.array([obs1, obs2, ...])
97 | 
98 |        @property
99 |        def observation_size(self) -> int:
100 |            """Return the size of the observation vector."""
101 |            ...
102 |    ```
103 | 
104 | Then you should be able to run `gpc.training.train` to train a flow-matching
105 | policy, and `gpc.testing.test_interactive` to run an interactive simulation with
106 | the trained policy. See the environments in [`gpc.envs`](gpc/envs) for examples
107 | and additional details.
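For reference, here is a minimal end-to-end sketch of that workflow. `MyCustomEnv` is the hypothetical environment from the snippet above, and the controller choice and hyperparameters are only placeholders:

```python
from flax import nnx
from hydrax.algs import PredictiveSampling

from gpc.architectures import DenoisingMLP
from gpc.testing import test_interactive
from gpc.training import train

env = MyCustomEnv()

# Sampling-based predictive controller used to generate training data
ctrl = PredictiveSampling(env.task, num_samples=32, noise_level=0.3)

# Flow-matching network that denoises action sequences conditioned on y
net = DenoisingMLP(
    action_size=env.task.model.nu,
    observation_size=env.observation_size,
    horizon=env.task.planning_horizon,
    hidden_layers=[64, 64],
    rngs=nnx.Rngs(0),
)

# Alternate between data collection and flow-matching regression
policy = train(
    env,
    ctrl,
    net,
    num_policy_samples=8,
    log_dir="/tmp/gpc_my_robot",
    num_iters=20,
    num_envs=128,
)
policy.save("/tmp/my_robot_policy.pkl")

# Run an interactive MuJoCo simulation with the trained policy
test_interactive(env, policy)
```

The scripts in [`examples`](examples) follow this same pattern, with task-specific controllers, network sizes, and training settings.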
108 | 109 | ## Citation 110 | 111 | ```bibtex 112 | @article{kurtz2025generative, 113 | title={Generative Predictive Control: Flow Matching Policies for Dynamic and Difficult-to-Demonstrate Task}, 114 | author={Kurtz, Vince and Burdick, Joel}, 115 | journal={arXiv preprint arXiv:2502.13406}, 116 | year={2025}, 117 | } 118 | ``` 119 | -------------------------------------------------------------------------------- /tests/test_architectures.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import cloudpickle 4 | import jax.numpy as jnp 5 | import matplotlib.pyplot as plt 6 | from flax import nnx 7 | 8 | from gpc.architectures import ( 9 | MLP, 10 | DenoisingCNN, 11 | DenoisingMLP, 12 | PositionalEmbedding, 13 | ) 14 | 15 | 16 | def test_mlp_construction() -> None: 17 | """Create a simple MLP and verify sizes.""" 18 | batch_size = 10 19 | input_size = 2 20 | output_size = 3 21 | 22 | # Create the model 23 | model = MLP([input_size, 128, 32, output_size], nnx.Rngs(0)) 24 | 25 | # Make sure the model is constructed correctly 26 | input = jnp.zeros((batch_size, input_size)) 27 | output = model(input) 28 | assert output.shape == (batch_size, output_size) 29 | 30 | # Print a summary of the model 31 | nnx.display(model) 32 | 33 | 34 | def test_mlp_save_load() -> None: 35 | """Verify that we can pickle an MLP.""" 36 | layer_sizes = [2, 3, 4] 37 | mlp = MLP(layer_sizes, rngs=nnx.Rngs(1)) 38 | dummy_input = jnp.ones((2,)) 39 | original_output = mlp(dummy_input) 40 | 41 | # Create a temporary path for saving stuff 42 | local_dir = Path("_test_mlp") 43 | local_dir.mkdir(parents=True, exist_ok=True) 44 | 45 | model_path = local_dir / "mlp.pkl" 46 | with Path(model_path).open("wb") as f: 47 | cloudpickle.dump(mlp, f) 48 | 49 | with Path(model_path).open("rb") as f: 50 | model_restored = cloudpickle.load(f) 51 | 52 | # Check that the model is still functional 53 | restored_output = model_restored(dummy_input) 54 | assert jnp.allclose(original_output, restored_output) 55 | 56 | # Remove the temporary directory 57 | for p in local_dir.iterdir(): 58 | p.unlink() 59 | local_dir.rmdir() 60 | 61 | 62 | def test_denoising_mlp() -> None: 63 | """Test the denoising MLP.""" 64 | num_steps = 5 65 | action_dim = 3 66 | obs_dim = 4 67 | 68 | # Define the network architecture 69 | net = DenoisingMLP(action_dim, obs_dim, num_steps, (32, 32), nnx.Rngs(0)) 70 | 71 | # Test on some data 72 | U = jnp.ones((num_steps, action_dim)) 73 | y = jnp.ones(obs_dim) 74 | t = jnp.ones(1) 75 | U_out = net(U, y, t) 76 | assert U_out.shape == (num_steps, action_dim) 77 | 78 | # Test on some batched data 79 | U = jnp.ones((14, 24, num_steps, action_dim)) 80 | y = jnp.ones((14, 24, obs_dim)) 81 | t = jnp.ones((14, 24, 1)) 82 | U_out = net(U, y, t) 83 | assert U_out.shape == (14, 24, num_steps, action_dim) 84 | 85 | 86 | def test_positional_embedding() -> None: 87 | """Test our sinusoidal positional embedding.""" 88 | dim = 8 89 | emb = PositionalEmbedding(dim) 90 | 91 | e = emb(jnp.zeros(1)[0]) 92 | assert e.shape == (dim,) 93 | 94 | t = jnp.zeros((24, 14, 1)) 95 | e = emb(t) 96 | assert e.shape == (24, 14, dim) 97 | 98 | t = jnp.linspace(0, 1, 100) 99 | e = emb(t) 100 | assert e.shape == (100, dim) 101 | 102 | if __name__ == "__main__": 103 | # Visualize the positional embedding 104 | plt.plot(t, e) 105 | plt.xlabel("Time") 106 | plt.ylabel("Positional Embedding") 107 | plt.title("Sinusoidal Positional Embedding") 108 | plt.show() 109 | 110 | 111 | def 
test_denoising_cnn() -> None: 112 | """Test the denoising CNN.""" 113 | num_steps = 5 114 | action_dim = 3 115 | obs_dim = 4 116 | 117 | # Define the network architecture 118 | net = DenoisingCNN(action_dim, obs_dim, num_steps, [32, 32], nnx.Rngs(0)) 119 | 120 | # Test on some data 121 | U = jnp.ones((num_steps, action_dim)) 122 | y = jnp.ones(obs_dim) 123 | t = jnp.ones(1) 124 | U_out = net(U, y, t) 125 | assert U_out.shape == (num_steps, action_dim) 126 | 127 | # Test on some batched data 128 | U = jnp.ones((14, 24, num_steps, action_dim)) 129 | y = jnp.ones((14, 24, obs_dim)) 130 | t = jnp.ones((14, 24, 1)) 131 | U_out = net(U, y, t) 132 | assert U_out.shape == (14, 24, num_steps, action_dim) 133 | 134 | 135 | if __name__ == "__main__": 136 | test_mlp_construction() 137 | test_mlp_save_load() 138 | test_denoising_mlp() 139 | test_positional_embedding() 140 | test_denoising_cnn() 141 | -------------------------------------------------------------------------------- /gpc/envs/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import jax 4 | import mujoco 5 | import numpy as np 6 | from flax.struct import dataclass 7 | from hydrax.task_base import Task 8 | from mujoco import mjx 9 | 10 | 11 | @dataclass 12 | class SimulatorState: 13 | """A dataclass for storing the simulator state. 14 | 15 | Attributes: 16 | data: The mjx simulator data. 17 | t: The current time step. 18 | rng: The random number generator key. 19 | """ 20 | 21 | data: mjx.Data 22 | t: int 23 | rng: jax.Array 24 | 25 | 26 | class TrainingEnv(ABC): 27 | """Abstract class defining a training environment.""" 28 | 29 | def __init__(self, task: Task, episode_length: int) -> None: 30 | """Initialize the training environment.""" 31 | self.task = task 32 | self.episode_length = episode_length 33 | self.renderer = mujoco.Renderer(self.task.mj_model) 34 | 35 | # Disable shadows and reflections for faster rendering 36 | self.renderer.scene.flags[mujoco.mjtRndFlag.mjRND_SHADOW] = False 37 | self.renderer.scene.flags[mujoco.mjtRndFlag.mjRND_REFLECTION] = False 38 | self.renderer.scene.flags[mujoco.mjtRndFlag.mjRND_FOG] = False 39 | self.renderer.scene.flags[mujoco.mjtRndFlag.mjRND_HAZE] = False 40 | 41 | def init_state(self, rng: jax.Array) -> SimulatorState: 42 | """Initialize the simulator state.""" 43 | state = SimulatorState( 44 | data=mjx.make_data(self.task.model), t=0, rng=rng 45 | ) 46 | return self._reset_state(state) 47 | 48 | def render(self, states: SimulatorState, fps: int = 10) -> np.ndarray: 49 | """Render video frames from a state trajectory. 50 | 51 | Note that this is not a pure jax function, and should only be used for 52 | visualization. 53 | 54 | Args: 55 | states: Sequence of states (vmapped over time). 56 | fps: The frames per second for the video. 57 | 58 | Returns: 59 | A sequence of video frames, with shape (T, C, H, W). 
60 | """ 61 | sim_dt = self.task.model.opt.timestep 62 | render_dt = 1.0 / fps 63 | render_every = int(round(render_dt / sim_dt)) 64 | steps = np.arange(0, len(states.t), render_every) 65 | 66 | frames = [] 67 | for i in steps: 68 | mjx_data = jax.tree.map(lambda x: x[i], states.data) # noqa: B023 69 | mj_data = mjx.get_data(self.task.mj_model, mjx_data) 70 | self.renderer.update_scene(mj_data) 71 | pixels = self.renderer.render() # H, W, C 72 | frames.append(pixels.transpose(2, 0, 1)) # C, H, W 73 | 74 | return np.stack(frames) 75 | 76 | @abstractmethod 77 | def reset(self, data: mjx.Data, rng: jax.Array) -> mjx.Data: 78 | """Reset the simulator to start a new episode.""" 79 | 80 | @abstractmethod 81 | def get_obs(self, data: mjx.Data) -> jax.Array: 82 | """Get the observation from the simulator state.""" 83 | 84 | @property 85 | @abstractmethod 86 | def observation_size(self) -> int: 87 | """The size of the observation space.""" 88 | 89 | def _reset_state(self, state: SimulatorState) -> SimulatorState: 90 | """Reset the simulator state to start a new episode.""" 91 | rng, reset_rng = jax.random.split(state.rng) 92 | data = self.reset(state.data, reset_rng) 93 | data = mjx.forward(self.task.model, data) # update sensor data 94 | return SimulatorState(data=data, t=0, rng=rng) 95 | 96 | def _update_goal(self, state: SimulatorState) -> SimulatorState: 97 | """Update the goal state during the middle of an episode.""" 98 | rng, goal_rng = jax.random.split(state.rng) 99 | data = self.update_goal(state.data, goal_rng) 100 | return state.replace(data=data, rng=rng) 101 | 102 | def _get_observation(self, state: SimulatorState) -> jax.Array: 103 | """Get the observation from the simulator state.""" 104 | return self.get_obs(state.data) 105 | 106 | def episode_over(self, state: SimulatorState) -> bool: 107 | """Check if the episode is over. 108 | 109 | Override this method if the episode should terminate early. 110 | """ 111 | return state.t >= self.episode_length 112 | 113 | def goal_reached(self, state: SimulatorState) -> bool: 114 | """Check if we've achieved a sub-goal. 115 | 116 | This gives us the opportunity to update the goal before the episode 117 | ends. For example, we might want to choose a new target configuration 118 | once the old one has been reached. 119 | """ 120 | return False 121 | 122 | def update_goal(self, data: mjx.Data, rng: jax.Array) -> mjx.Data: 123 | """Update the goal state during the middle of an episode. 124 | 125 | Typically this is done via mocap_pos and mocap_quat, and by default we 126 | do nothing. 127 | """ 128 | return data 129 | 130 | def step(self, state: SimulatorState, action: jax.Array) -> SimulatorState: 131 | """Take a simulation step. 132 | 133 | Args: 134 | state: The simulator state. 135 | action: The action to take. 136 | 137 | Returns: 138 | The new simulator state and the new time step. 
139 | """ 140 | # Check if the episode is over 141 | next_state = jax.lax.cond( 142 | self.episode_over(state), 143 | lambda _: self._reset_state(state), 144 | lambda _: state.replace( 145 | data=mjx.step(self.task.model, state.data.replace(ctrl=action)), 146 | t=state.t + 1, 147 | ), 148 | operand=None, 149 | ) 150 | 151 | # Check if we've reached a sub-goal that needs updating 152 | next_state = jax.lax.cond( 153 | self.goal_reached(next_state), 154 | lambda _: self._update_goal(next_state), 155 | lambda _: next_state, 156 | operand=None, 157 | ) 158 | 159 | return next_state 160 | -------------------------------------------------------------------------------- /tests/test_training.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import time 3 | from pathlib import Path 4 | 5 | import evosax 6 | import jax 7 | import jax.numpy as jnp 8 | import matplotlib.pyplot as plt 9 | import optax 10 | import pytest 11 | from flax import nnx 12 | from hydrax.algs import Evosax, PredictiveSampling 13 | 14 | from gpc.architectures import DenoisingMLP 15 | from gpc.augmented import PolicyAugmentedController 16 | from gpc.envs import ParticleEnv, SimulatorState 17 | from gpc.policy import Policy 18 | from gpc.training import fit_policy, simulate_episode, train 19 | 20 | 21 | def test_simulate() -> None: 22 | """Test simulating an episode.""" 23 | rng = jax.random.key(0) 24 | env = ParticleEnv(episode_length=13) 25 | ctrl = PolicyAugmentedController( 26 | PredictiveSampling(env.task, num_samples=8, noise_level=0.1), 27 | num_policy_samples=8, 28 | ) 29 | net = DenoisingMLP( 30 | action_size=env.task.model.nu, 31 | observation_size=env.observation_size, 32 | horizon=env.task.planning_horizon, 33 | hidden_layers=[32, 32], 34 | rngs=nnx.Rngs(0), 35 | ) 36 | normalizer = nnx.BatchNorm( 37 | env.observation_size, 38 | momentum=0.1, 39 | epsilon=1e-5, 40 | use_bias=False, 41 | use_scale=False, 42 | rngs=nnx.Rngs(0), 43 | ) 44 | 45 | policy = Policy(net, normalizer, env.task.u_min, env.task.u_max) 46 | 47 | rng, episode_rng = jax.random.split(rng) 48 | y, U, U_guess, J_spc, J_policy, states = simulate_episode( 49 | env, ctrl, policy, 0.0, episode_rng 50 | ) 51 | 52 | assert y.shape == (13, 4) 53 | assert U.shape == (13, 5, 2) 54 | assert U_guess.shape == (13, 5, 2) 55 | assert J_spc.shape == (13,) 56 | assert J_policy.shape == (13,) 57 | assert isinstance(states, SimulatorState) 58 | assert states.t.shape == (13,) 59 | assert states.data.qpos.shape == (13, 2) 60 | 61 | 62 | def test_fit() -> None: 63 | """Test fitting the policy network.""" 64 | rng = jax.random.key(0) 65 | 66 | # Make some fake data 67 | rng, obs_rng, act_rng = jax.random.split(rng, 3) 68 | y1 = jax.random.uniform(obs_rng, (64, 1)) 69 | y2 = jax.random.uniform(obs_rng, (128, 1)) 70 | y = jnp.concatenate([y1, y2], axis=0) 71 | U1 = -0.5 - y1[..., None] + 0.1 * jax.random.normal(act_rng, (64, 1, 1)) 72 | U2 = 0.5 * y2[..., None] + 0.1 * jax.random.normal(act_rng, (128, 1, 1)) 73 | U = jnp.concatenate([U1, U2], axis=0) 74 | 75 | # Plot the training data 76 | if __name__ == "__main__": 77 | plt.scatter(y, U[:, 0, 0]) 78 | plt.xlabel("Observation") 79 | plt.ylabel("Action") 80 | plt.show(block=False) 81 | 82 | # Set up the policy network 83 | net = DenoisingMLP( 84 | action_size=1, 85 | observation_size=1, 86 | horizon=1, 87 | hidden_layers=[32, 32], 88 | rngs=nnx.Rngs(0), 89 | ) 90 | 91 | # Set up the optimizer 92 | optimizer = nnx.Optimizer(net, optax.adam(1e-2)) 93 | batch_size = 512 # 
can be larger than the dataset b/c added noise 94 | num_epochs = 1000 95 | 96 | # Fit the policy network 97 | st = time.time() 98 | rng, fit_rng = jax.random.split(rng) 99 | loss = fit_policy(y, U, U, net, optimizer, batch_size, num_epochs, fit_rng) 100 | print("Final loss:", loss) 101 | assert loss < 1.0 102 | print("Fit time:", time.time() - st) 103 | 104 | # Try generating some actions 105 | rng, test_rng = jax.random.split(rng) 106 | y_test = jnp.linspace(0.0, 1.0, 100)[:, None] 107 | U_test = jax.random.normal(test_rng, (100, 1, 1)) 108 | dt = 0.1 109 | for t in jnp.arange(0.0, 1.0, dt): 110 | v = net(U_test, y_test, jnp.tile(t, (100, 1))) 111 | U_test += v * dt 112 | 113 | if __name__ == "__main__": 114 | plt.scatter(y_test, U_test[:, 0, 0]) 115 | plt.xlabel("Observation") 116 | plt.ylabel("Action") 117 | plt.show() 118 | 119 | 120 | def test_train() -> None: 121 | """Test the training loop.""" 122 | log_dir = Path("_test_train") 123 | log_dir.mkdir(parents=True, exist_ok=True) 124 | 125 | env = ParticleEnv() 126 | net = DenoisingMLP( 127 | action_size=env.task.model.nu, 128 | observation_size=env.observation_size, 129 | horizon=env.task.planning_horizon, 130 | hidden_layers=[32, 32], 131 | rngs=nnx.Rngs(0), 132 | ) 133 | 134 | # Try training with an incompatible controller 135 | with pytest.raises(AssertionError): 136 | invalid_ctrl = Evosax(env.task, evosax.Sep_CMA_ES, num_samples=8) 137 | policy = train( 138 | env, 139 | invalid_ctrl, 140 | net, 141 | num_policy_samples=2, 142 | log_dir=log_dir, 143 | num_iters=1, 144 | num_envs=4, 145 | ) 146 | 147 | # Train with predictive sampling 148 | ctrl = PredictiveSampling(env.task, num_samples=8, noise_level=0.1) 149 | policy = train( 150 | env, 151 | ctrl, 152 | net, 153 | num_policy_samples=8, 154 | log_dir=log_dir, 155 | num_iters=3, 156 | num_envs=128, 157 | checkpoint_every=1, 158 | ) 159 | 160 | assert isinstance(policy, Policy) 161 | 162 | # Test the policy 163 | rng = jax.random.key(0) 164 | y = jnp.array([-0.1, 0.1, 0.0, 0.0]) 165 | U = jnp.zeros((env.task.planning_horizon, env.task.model.nu)) 166 | U = policy.apply(U, y, rng) 167 | 168 | # Check that the policy output points in the right direction 169 | assert U.shape == (env.task.planning_horizon, env.task.model.nu) 170 | assert U[0, 0] > 0.0 171 | assert U[0, 1] < 0.0 172 | 173 | # Cleanup recursively 174 | shutil.rmtree(log_dir) 175 | 176 | 177 | def test_policy() -> None: 178 | """Test the policy helper class.""" 179 | rng = jax.random.key(0) 180 | num_steps = 5 181 | num_actions = 2 182 | num_obs = 3 183 | 184 | # Create a toy network 185 | mlp = DenoisingMLP( 186 | action_size=num_actions, 187 | observation_size=num_obs, 188 | horizon=num_steps, 189 | hidden_layers=[32, 32], 190 | rngs=nnx.Rngs(0), 191 | ) 192 | 193 | # Create an observation normalizer 194 | normalizer = nnx.BatchNorm( 195 | num_obs, 196 | momentum=0.1, 197 | epsilon=1e-5, 198 | use_bias=False, 199 | use_scale=False, 200 | rngs=nnx.Rngs(0), 201 | ) 202 | 203 | # Create the policy 204 | u_min = -2 * jnp.ones(num_actions) 205 | u_max = jnp.ones(num_actions) 206 | policy = Policy(mlp, normalizer, u_min, u_max) 207 | 208 | # Test running the policy 209 | rng, apply_rng = jax.random.split(rng) 210 | U = jnp.zeros((num_steps, num_actions)) 211 | y = jnp.ones((num_obs,)) 212 | U1 = policy.apply(U, y, apply_rng) 213 | assert U1.shape == (num_steps, num_actions) 214 | 215 | assert jnp.all(U1 != 0.0) 216 | assert jnp.all(U1 >= u_min) 217 | assert jnp.all(U1 <= u_max) 218 | 219 | # Save and load the policy 220 | 
local_dir = Path("_test_policy") 221 | local_dir.mkdir(parents=True, exist_ok=True) 222 | 223 | policy.save(local_dir / "policy.pkl") 224 | del policy 225 | 226 | policy2 = Policy.load(local_dir / "policy.pkl") 227 | 228 | U2 = jax.jit(policy2.apply)(U, y, apply_rng) 229 | assert jnp.allclose(U2, U1) 230 | 231 | # Cleanup 232 | for p in local_dir.iterdir(): 233 | p.unlink() 234 | local_dir.rmdir() 235 | 236 | 237 | if __name__ == "__main__": 238 | test_simulate() 239 | test_fit() 240 | test_train() 241 | test_policy() 242 | -------------------------------------------------------------------------------- /gpc/architectures.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | from flax import nnx 6 | 7 | 8 | class MLP(nnx.Module): 9 | """A simple multi-layer perceptron.""" 10 | 11 | def __init__(self, layer_sizes: Sequence[int], rngs: nnx.Rngs): 12 | """Initialize the network. 13 | 14 | Args: 15 | layer_sizes: Sizes of all layers, including input and output. 16 | rngs: Random number generators for initialization. 17 | """ 18 | self.num_hidden = len(layer_sizes) - 2 19 | 20 | # TODO: use nnx.scan to scan over layers, reducing compile times 21 | for i, (input_size, output_size) in enumerate( 22 | zip(layer_sizes[:-1], layer_sizes[1:], strict=False) 23 | ): 24 | setattr( 25 | self, f"l{i}", nnx.Linear(input_size, output_size, rngs=rngs) 26 | ) 27 | 28 | def __call__(self, x: jax.Array) -> jax.Array: 29 | """Forward pass through the network.""" 30 | for i in range(self.num_hidden): 31 | x = getattr(self, f"l{i}")(x) 32 | x = nnx.swish(x) 33 | x = getattr(self, f"l{self.num_hidden}")(x) 34 | return x 35 | 36 | 37 | class DenoisingMLP(nnx.Module): 38 | """A simple multi-layer perceptron for action sequence denoising. 39 | 40 | Computes U* = NNet(U, y, t), where U is the noisy action sequence, y is the 41 | initial observation, and t is the time step in the denoising process. 42 | """ 43 | 44 | def __init__( 45 | self, 46 | action_size: int, 47 | observation_size: int, 48 | horizon: int, 49 | hidden_layers: Sequence[int], 50 | rngs: nnx.Rngs, 51 | ): 52 | """Initialize the network. 53 | 54 | Args: 55 | action_size: Dimension of the actions (u). 56 | observation_size: Dimension of the observations (y). 57 | horizon: Number of steps in the action sequence (U = [u0, u1, ...]). 58 | hidden_layers: Sizes of all hidden layers. 59 | rngs: Random number generators for initialization. 60 | """ 61 | self.action_size = action_size 62 | self.observation_size = observation_size 63 | self.horizon = horizon 64 | self.hidden_layers = hidden_layers 65 | 66 | input_size = horizon * action_size + observation_size + 1 67 | output_size = horizon * action_size 68 | self.mlp = MLP( 69 | [input_size] + list(hidden_layers) + [output_size], rngs=rngs 70 | ) 71 | 72 | def __call__(self, u: jax.Array, y: jax.Array, t: jax.Array) -> jax.Array: 73 | """Forward pass through the network.""" 74 | batches = u.shape[:-2] 75 | u_flat = u.reshape(batches + (self.horizon * self.action_size,)) 76 | x = jnp.concatenate([u_flat, y, t], axis=-1) 77 | x = self.mlp(x) 78 | return x.reshape(batches + (self.horizon, self.action_size)) 79 | 80 | 81 | class PositionalEmbedding(nnx.Module): 82 | """A simple sinusoidal positional embedding layer.""" 83 | 84 | def __init__(self, dim: int): 85 | """Initialize the positional embedding. 86 | 87 | Args: 88 | dim: Dimension to lift the input to. 
89 | """ 90 | self.half_dim = dim // 2 91 | 92 | def __call__(self, t: jax.Array) -> jax.Array: 93 | """Compute the positional embedding.""" 94 | freqs = jnp.arange(1, self.half_dim + 1) * jnp.pi 95 | emb = freqs * jnp.squeeze(t)[..., None] 96 | emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1) 97 | return emb 98 | 99 | 100 | class Conv1DBlock(nnx.Module): 101 | """A simple temporal convolutional block. 102 | 103 | ---------- ------------- -------------------- 104 | ---> | Conv1d | --> | BatchNorm | --> | Swish Activation | ---> 105 | ---------- ------------- -------------------- 106 | 107 | """ 108 | 109 | def __init__( 110 | self, 111 | in_features: int, 112 | out_features: int, 113 | kernel_size: int, 114 | rngs: nnx.Rngs, 115 | ): 116 | """Initialize the block. 117 | 118 | Args: 119 | in_features: Number of input features. 120 | out_features: Number of output features. 121 | kernel_size: Size of the convolutional kernel. 122 | rngs: Random number generators for initialization. 123 | """ 124 | self.c = nnx.Conv( 125 | in_features=in_features, 126 | out_features=out_features, 127 | kernel_size=kernel_size, 128 | padding="SAME", 129 | rngs=rngs, 130 | ) 131 | self.bn = nnx.BatchNorm(num_features=out_features, rngs=rngs) 132 | 133 | def __call__(self, x: jax.Array) -> jax.Array: 134 | """Forward pass through the block.""" 135 | x = self.c(x) 136 | x = self.bn(x) 137 | x = nnx.swish(x) 138 | return x 139 | 140 | 141 | class ConditionalResidualBlock(nnx.Module): 142 | """A temporal convolutional block with FiLM conditional information. 143 | 144 | ------------------------------------------------------------- 145 | | | 146 | | ----------- ----------- ----------- | 147 | x ---> | Encoder | --> (+) --> | Dropout | --> | Decoder | --> (+) --> 148 | ----------- | ----------- ----------- 149 | | 150 | ---------- 151 | y -----------------| Linear | 152 | ---------- 153 | 154 | """ 155 | 156 | def __init__( 157 | self, 158 | in_features: int, 159 | out_features: int, 160 | cond_features: int, 161 | kernel_size: int, 162 | rngs: nnx.Rngs, 163 | ): 164 | """Initialize the block. 165 | 166 | Args: 167 | in_features: Number of input features. 168 | out_features: Number of output features. 169 | cond_features: Number of conditioning features. 170 | kernel_size: Size of the convolutional kernel. 171 | rngs: Random number generators for initialization. 172 | """ 173 | self.encoder = Conv1DBlock(in_features, out_features, kernel_size, rngs) 174 | self.decoder = Conv1DBlock( 175 | out_features, out_features, kernel_size, rngs 176 | ) 177 | self.linear = nnx.LinearGeneral( 178 | cond_features, (1, out_features), rngs=rngs 179 | ) 180 | self.dropout = nnx.Dropout(rate=0.1, rngs=rngs) 181 | self.residual = nnx.Conv( 182 | in_features=in_features, 183 | out_features=out_features, 184 | kernel_size=1, 185 | padding="SAME", 186 | rngs=rngs, 187 | ) 188 | 189 | def __call__(self, x: jax.Array, y: jax.Array) -> jax.Array: 190 | """Forward pass through the block.""" 191 | z = self.encoder(x) 192 | z += self.linear(y) 193 | z = self.dropout(z) 194 | z = self.decoder(z) 195 | return z + self.residual(x) 196 | 197 | 198 | class DenoisingCNN(nnx.Module): 199 | """A denoising convolutional network with FiLM conditioning. 200 | 201 | Based on Diffusion Policy, https://arxiv.org/abs/2303.04137v5. 
202 | """ 203 | 204 | def __init__( 205 | self, 206 | action_size: int, 207 | observation_size: int, 208 | horizon: int, 209 | feature_dims: Sequence[int], 210 | rngs: nnx.Rngs, 211 | kernel_size: int = 3, 212 | timestep_embedding_dim: int = 32, 213 | ): 214 | """Initialize the network. 215 | 216 | Args: 217 | action_size: Dimension of the actions (u). 218 | observation_size: Dimension of the observations (y). 219 | horizon: Number of steps in the action sequence (U = [u0, u1, ...]). 220 | feature_dims: List of feature dimensions. 221 | rngs: Random number generators for initialization. 222 | kernel_size: Size of the convolutional kernel. 223 | timestep_embedding_dim: Dimension of the positional embedding. 224 | """ 225 | self.action_size = action_size 226 | self.observation_size = observation_size 227 | self.horizon = horizon 228 | self.num_layers = len(feature_dims) + 1 229 | self.positional_embedding = PositionalEmbedding(timestep_embedding_dim) 230 | 231 | feature_sizes = [action_size] + list(feature_dims) + [action_size] 232 | for i, (input_size, output_size) in enumerate( 233 | zip(feature_sizes[:-1], feature_sizes[1:], strict=False) 234 | ): 235 | setattr( 236 | self, 237 | f"l{i}", 238 | ConditionalResidualBlock( 239 | input_size, 240 | output_size, 241 | observation_size + timestep_embedding_dim, 242 | kernel_size, 243 | rngs, 244 | ), 245 | ) 246 | 247 | def __call__(self, u: jax.Array, y: jax.Array, t: jax.Array) -> jax.Array: 248 | """Forward pass through the network.""" 249 | emb = self.positional_embedding(t) 250 | y = jnp.concatenate([y, emb], axis=-1) 251 | 252 | x = self.l0(u, y) 253 | for i in range(1, self.num_layers): 254 | x = getattr(self, f"l{i}")(x, y) 255 | 256 | return x + u 257 | -------------------------------------------------------------------------------- /gpc/training.py: -------------------------------------------------------------------------------- 1 | import time 2 | from datetime import datetime 3 | from pathlib import Path 4 | from typing import Any, Tuple, Union 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | import numpy as np 9 | import optax 10 | from flax import nnx 11 | from hydrax.alg_base import SamplingBasedController 12 | from tensorboardX import SummaryWriter 13 | 14 | from gpc.augmented import PACParams, PolicyAugmentedController 15 | from gpc.envs import SimulatorState, TrainingEnv 16 | from gpc.policy import Policy 17 | 18 | Params = Any 19 | 20 | 21 | def simulate_episode( 22 | env: TrainingEnv, 23 | ctrl: PolicyAugmentedController, 24 | policy: Policy, 25 | exploration_noise_level: float, 26 | rng: jax.Array, 27 | strategy: str = "policy", 28 | ) -> Tuple[jax.Array, jax.Array, jax.Array, jax.Array, SimulatorState]: 29 | """Starting from a random initial state, run SPC and record training data. 30 | 31 | Args: 32 | env: The training environment. 33 | ctrl: The sampling-based controller (augmented with a learned policy). 34 | policy: The generative policy network. 35 | exploration_noise_level: Standard deviation of the gaussian noise added 36 | to each action. 37 | rng: The random number generator key. 38 | strategy: The strategy for advancing the simulation. "policy" uses the 39 | first policy sample, while "best" agregates all samples. 40 | 41 | Returns: 42 | y: The observations at each time step. 43 | U: The optimal actions at each time step. 44 | U_guess: The initial guess for the optimal actions at each time step. 45 | J_spc: cost of the best action sequence found by SPC at each time step. 
46 | J_policy: cost of the best action sequence found by the policy. 47 | states: Vmapped simulator states at each time step. 48 | """ 49 | rng, ctrl_rng, env_rng = jax.random.split(rng, 3) 50 | 51 | # Set the initial state of the environment 52 | x = env.init_state(env_rng) 53 | 54 | # Set the initial sampling-based controller parameters 55 | psi = ctrl.init_params() 56 | psi = psi.replace(base_params=psi.base_params.replace(rng=ctrl_rng)) 57 | 58 | def _scan_fn( 59 | carry: Tuple[SimulatorState, jax.Array, PACParams], t: int 60 | ) -> Tuple: 61 | """Take simulation step, and record all data.""" 62 | x, U, psi = carry 63 | 64 | # Sample action sequences from the learned policy 65 | # TODO: consider warm-starting the policy 66 | y = env._get_observation(x) 67 | rng, policy_rng, explore_rng = jax.random.split(psi.base_params.rng, 3) 68 | policy_rngs = jax.random.split(policy_rng, ctrl.num_policy_samples) 69 | warm_start_level = 0.0 70 | Us = jax.vmap(policy.apply, in_axes=(0, None, 0, None))( 71 | U, y, policy_rngs, warm_start_level 72 | ) 73 | 74 | # Place the samples into the predictive control parameters so they 75 | # can be used in the predictive control update 76 | psi = psi.replace( 77 | policy_samples=Us, base_params=psi.base_params.replace(rng=rng) 78 | ) 79 | 80 | # Update the action sequence with sampling-based predictive control 81 | psi, rollouts = ctrl.optimize(x.data, psi) 82 | U_star = ctrl.get_action_sequence(psi) 83 | 84 | # Record the lowest costs achieved by SPC and the policy 85 | # TODO: consider logging something more informative 86 | costs = jnp.sum(rollouts.costs, axis=1) 87 | spc_best_idx = jnp.argmin(costs[: -ctrl.num_policy_samples]) 88 | policy_best_idx = ( 89 | jnp.argmin(costs[ctrl.num_policy_samples :]) 90 | + ctrl.num_policy_samples 91 | ) 92 | spc_best = costs[spc_best_idx] 93 | policy_best = costs[policy_best_idx] 94 | 95 | # Step the simulation 96 | if strategy == "policy": 97 | u = Us[0, 0] 98 | elif strategy == "best": 99 | u = U_star[0] 100 | else: 101 | raise ValueError(f"Unknown strategy: {strategy}") 102 | exploration_noise = exploration_noise_level * jax.random.normal( 103 | explore_rng, u.shape 104 | ) 105 | x = env.step(x, u + exploration_noise) 106 | 107 | # Record the initial guess for the optimal action sequence. This is used 108 | # to weigh the flow matching loss in the policy training. 109 | U_guess = psi.base_params.mean 110 | 111 | return (x, Us, psi), (y, U_star, U_guess, spc_best, policy_best, x) 112 | 113 | rng, u_rng = jax.random.split(rng) 114 | U = jax.random.normal( 115 | u_rng, 116 | (ctrl.num_policy_samples, env.task.planning_horizon, env.task.model.nu), 117 | ) 118 | _, (y, U, U_guess, J_spc, J_policy, states) = jax.lax.scan( 119 | _scan_fn, (x, U, psi), jnp.arange(env.episode_length) 120 | ) 121 | 122 | return y, U, U_guess, J_spc, J_policy, states 123 | 124 | 125 | def fit_policy( 126 | observations: jax.Array, 127 | action_sequences: jax.Array, 128 | old_action_sequences: jax.Array, 129 | model: nnx.Module, 130 | optimizer: nnx.Optimizer, 131 | batch_size: int, 132 | num_epochs: int, 133 | rng: jax.Array, 134 | sigma_min: float = 1e-2, 135 | ) -> jax.Array: 136 | """Fit a flow matching model to the data. 137 | 138 | This model generates samples U ~ π(U|y) from the policy by flowing from 139 | U ~ N(0, I) to the target action sequence U*. 140 | 141 | Args: 142 | observations: The (normalized) observations y. 143 | action_sequences: The corresponding target action sequences U. 
144 | old_action_sequences: The previous action sequences U_guess. 145 | model: The policy network, outputs the flow matching vector field. 146 | optimizer: The optimizer (e.g. Adam). 147 | batch_size: The batch size. 148 | num_epochs: The number of epochs. 149 | rng: The random number generator key. 150 | sigma_min: Target distribution width for flow matching, see 151 | https://arxiv.org/pdf/2210.02747, eq (20-23). 152 | 153 | Returns: 154 | The loss from the last epoch. 155 | 156 | Note that model and optimizer are updated in-place by flax.nnx. 157 | """ 158 | num_data_points = observations.shape[0] 159 | num_batches = max(1, num_data_points // batch_size) 160 | 161 | def _loss_fn( 162 | model: nnx.Module, 163 | obs: jax.Array, 164 | act: jax.Array, 165 | old_act: jax.Array, 166 | noise: jax.Array, 167 | t: jax.Array, 168 | ) -> jax.Array: 169 | """Compute the flow-matching loss.""" 170 | alpha = 1.0 - sigma_min 171 | noised_action = t[..., None] * act + (1 - alpha * t[..., None]) * noise 172 | target = act - alpha * noise 173 | pred = model(noised_action, obs, t) 174 | 175 | # Weigh the loss by how close the noise is to the old action sequence. 176 | # If they are similar (in terms of angle to the target action) then the 177 | # weight is high. Otherwise the noised sample might be approaching the 178 | # target action sequence from a different direction, so this sample 179 | # isn't so informative and we reduce the weight. 180 | v1 = (old_act - act).flatten() 181 | v2 = (noise - act).flatten() 182 | cosine_similarity = jnp.dot(v1, v2) / ( 183 | jnp.linalg.norm(v1) * jnp.linalg.norm(v2) + 1e-8 184 | ) 185 | weight = jax.lax.stop_gradient(jnp.exp(2 * (cosine_similarity - 1))) 186 | 187 | return weight * jnp.mean(jnp.square(pred - target)) 188 | 189 | def _train_step( 190 | model: nnx.Module, 191 | optimizer: nnx.Optimizer, 192 | rng: jax.Array, 193 | ) -> Tuple[jax.Array, jax.Array]: 194 | """Perform a gradient descent step on a batch of data.""" 195 | # Get a random batch of data 196 | rng, batch_rng = jax.random.split(rng) 197 | batch_idx = jax.random.randint( 198 | batch_rng, (batch_size,), 0, num_data_points 199 | ) 200 | batch_obs = observations[batch_idx] 201 | batch_act = action_sequences[batch_idx] 202 | batch_old_act = old_action_sequences[batch_idx] 203 | 204 | # Sample noise and time steps for the flow matching targets 205 | rng, noise_rng, t_rng = jax.random.split(rng, 3) 206 | noise = jax.random.normal(noise_rng, batch_act.shape) 207 | t = jax.random.uniform(t_rng, (batch_size, 1)) 208 | 209 | # Compute the loss and its gradient 210 | loss, grad = nnx.value_and_grad(_loss_fn)( 211 | model, batch_obs, batch_act, batch_old_act, noise, t 212 | ) 213 | 214 | # Update the optimizer and model parameters in-place via flax.nnx 215 | optimizer.update(grad) 216 | 217 | return rng, loss 218 | 219 | # for i in range(num_batches * num_epochs): take a training step 220 | @nnx.scan 221 | def _scan_fn(carry: Tuple, i: int) -> Tuple: 222 | model, optimizer, rng = carry 223 | rng, loss = _train_step(model, optimizer, rng) 224 | return (model, optimizer, rng), loss 225 | 226 | _, losses = _scan_fn( 227 | (model, optimizer, rng), jnp.arange(num_batches * num_epochs) 228 | ) 229 | 230 | return losses[-1] 231 | 232 | 233 | def train( # noqa: PLR0915 this is a long function, don't limit to 50 lines 234 | env: TrainingEnv, 235 | ctrl: SamplingBasedController, 236 | net: nnx.Module, 237 | num_policy_samples: int, 238 | log_dir: Union[Path, str], 239 | num_iters: int, 240 | num_envs: int, 241 | 
    learning_rate: float = 1e-3,
242 |     batch_size: int = 128,
243 |     num_epochs: int = 10,
244 |     checkpoint_every: int = 10,
245 |     exploration_noise_level: float = 0.0,
246 |     normalize_observations: bool = True,
247 |     num_videos: int = 2,
248 |     video_fps: int = 10,
249 |     strategy: str = "policy",
250 | ) -> Policy:
251 |     """Train a generative predictive controller.
252 | 
253 |     Args:
254 |         env: The training environment.
255 |         ctrl: The sampling-based predictive control method to use.
256 |         net: The flow matching network architecture.
257 |         num_policy_samples: The number of samples to draw from the policy.
258 |         log_dir: The directory to log TensorBoard data to.
259 |         num_iters: The number of training iterations.
260 |         num_envs: The number of parallel environments to simulate.
261 |         learning_rate: The learning rate for the policy network.
262 |         batch_size: The batch size for training the policy network.
263 |         num_epochs: The number of epochs to train the policy network.
264 |         checkpoint_every: Number of iterations between policy checkpoint saves.
265 |         exploration_noise_level: Standard deviation of the Gaussian noise added
266 |             to each action during episode simulation.
267 |         normalize_observations: Flag for observation normalization.
268 |         num_videos: Number of videos to render for visualization.
269 |         video_fps: Frames per second for rendered videos.
270 |         strategy: The strategy for choosing a control action to advance the
271 |             simulation during the data collection phase. "policy" uses the
272 |             first policy sample, while "best" aggregates all samples.
273 |     Returns: The trained policy.
274 |     """
275 |     rng = jax.random.key(0)
276 | 
277 |     # Check that the task has finite input bounds
278 |     assert jnp.all(jnp.isfinite(env.task.u_min))
279 |     assert jnp.all(jnp.isfinite(env.task.u_max))
280 | 
281 |     # Check that the sampling-based predictive controller is compatible. In
282 |     # particular, we need access to the mean of the sampling distribution.
283 |     _spc_params = ctrl.init_params()
284 |     assert hasattr(
285 |         _spc_params, "mean"
286 |     ), f"Controller '{type(ctrl).__name__}' is not compatible with GPC."
287 | 288 | # Print some information about the training setup 289 | episode_seconds = env.episode_length * env.task.model.opt.timestep 290 | horizon_seconds = env.task.planning_horizon * env.task.dt 291 | num_samples = num_policy_samples + ctrl.num_samples 292 | print("Training with:") 293 | print( 294 | f" episode length: {episode_seconds} seconds" 295 | f" ({env.episode_length} simulation steps)" 296 | ) 297 | ( 298 | print( 299 | f" planning horizon: {horizon_seconds} seconds" 300 | f" ({env.task.planning_horizon} knots)" 301 | ), 302 | ) 303 | print( 304 | " Parallel rollouts per simulation step:" 305 | f" {num_samples * ctrl.num_randomizations * num_envs}" 306 | f" (= {num_samples} x {ctrl.num_randomizations} x {num_envs})" 307 | ) 308 | print("") 309 | 310 | # Print some info about the policy architecture 311 | params = nnx.state(net, nnx.Param) 312 | total_params = sum([np.prod(x.shape) for x in jax.tree.leaves(params)], 0) 313 | print(f"Policy: {type(net).__name__} with {total_params} parameters") 314 | print("") 315 | 316 | # Set up the sampling-based controller and policy network 317 | ctrl = PolicyAugmentedController(ctrl, num_policy_samples) 318 | assert env.task == ctrl.task 319 | 320 | # Set up the policy 321 | normalizer = nnx.BatchNorm( 322 | num_features=env.observation_size, 323 | momentum=0.1, 324 | use_bias=False, 325 | use_scale=False, 326 | use_fast_variance=False, 327 | rngs=nnx.Rngs(0), 328 | ) 329 | policy = Policy(net, normalizer, env.task.u_min, env.task.u_max) 330 | 331 | # Set up the optimizer 332 | optimizer = nnx.Optimizer(net, optax.adamw(learning_rate)) 333 | 334 | # Set up the TensorBoard logger 335 | log_dir = Path(log_dir) / time.strftime("%Y%m%d_%H%M%S") 336 | print("Logging to", log_dir) 337 | tb_writer = SummaryWriter(log_dir) 338 | 339 | # Set up some helper functions 340 | @nnx.jit 341 | def jit_simulate( 342 | policy: Policy, rng: jax.Array 343 | ) -> Tuple[ 344 | jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, SimulatorState 345 | ]: 346 | """Simulate episodes in parallel. 347 | 348 | Args: 349 | policy: The policy network. 350 | rng: The random number generator key. 351 | 352 | Returns: 353 | The observations at each time step. 354 | The best action sequence at each time step. 355 | Average cost of SPC's best action sequence. 356 | Average cost of the policy's best action sequence. 357 | Fraction of times the policy generated the best action sequence. 358 | First four simulation trajectories for visualization. 359 | """ 360 | rngs = jax.random.split(rng, num_envs) 361 | 362 | y, U, U_guess, J_spc, J_policy, states = jax.vmap( 363 | simulate_episode, in_axes=(None, None, None, None, 0, None) 364 | )(env, ctrl, policy, exploration_noise_level, rngs, strategy) 365 | 366 | # Get the first few simulated trajectories 367 | selected_states = jax.tree.map(lambda x: x[:num_videos], states) 368 | 369 | frac = jnp.mean(J_policy < J_spc) 370 | return ( 371 | y, 372 | U, 373 | U_guess, 374 | jnp.mean(J_spc), 375 | jnp.mean(J_policy), 376 | frac, 377 | selected_states, 378 | ) 379 | 380 | @nnx.jit 381 | def jit_fit( 382 | policy: Policy, 383 | optimizer: nnx.Optimizer, 384 | observations: jax.Array, 385 | actions: jax.Array, 386 | previous_actions: jax.Array, 387 | rng: jax.Array, 388 | ) -> jax.Array: 389 | """Fit the policy network to the data. 390 | 391 | Args: 392 | policy: The policy network (updated in place). 393 | optimizer: The optimizer (updated in place). 394 | observations: The observations. 395 | actions: The best action sequences. 
396 | previous_actions: The initial/guessed action sequences. 397 | rng: The random number generator key. 398 | 399 | Returns: 400 | The loss from the last epoch. 401 | """ 402 | # Flatten across timesteps and initial conditions 403 | y = observations.reshape(-1, observations.shape[-1]) 404 | U = actions.reshape(-1, env.task.planning_horizon, env.task.model.nu) 405 | U_guess = previous_actions.reshape( 406 | -1, env.task.planning_horizon, env.task.model.nu 407 | ) 408 | 409 | # Rescale the actions from [u_min, u_max] to [-1, 1] 410 | mean = (env.task.u_max + env.task.u_min) / 2 411 | scale = (env.task.u_max - env.task.u_min) / 2 412 | U = (U - mean) / scale 413 | U_guess = (U_guess - mean) / scale 414 | 415 | # Normalize the observations, updating the running statistics stored 416 | # in the policy 417 | y = policy.normalizer(y, use_running_average=not normalize_observations) 418 | 419 | # Do the regression 420 | return fit_policy( 421 | y, 422 | U, 423 | U_guess, 424 | policy.model, 425 | optimizer, 426 | batch_size, 427 | num_epochs, 428 | rng, 429 | ) 430 | 431 | train_start = datetime.now() 432 | for i in range(num_iters): 433 | # Simulate and record the best action sequences. Some of the action 434 | # samples are generated via SPC and others are generated by the policy. 435 | policy.model.eval() 436 | sim_start = time.time() 437 | rng, episode_rng = jax.random.split(rng) 438 | y, U, U_guess, J_spc, J_policy, frac, traj = jit_simulate( 439 | policy, episode_rng 440 | ) 441 | y.block_until_ready() 442 | sim_time = time.time() - sim_start 443 | 444 | # Render the first few trajectories for visualization 445 | # N.B. this uses CPU mujoco's rendering utils, so we need to do it 446 | # sequentially and outside a jit-compiled function 447 | if num_videos > 0: 448 | render_start = time.time() 449 | video_frames = [] 450 | for j in range(num_videos): 451 | states = jax.tree.map(lambda x: x[j], traj) # noqa: B023 452 | video_frames.append(env.render(states, video_fps)) 453 | video_frames = np.stack(video_frames) 454 | render_time = time.time() - render_start 455 | 456 | # Fit the policy network U = NNet(y) to the data 457 | policy.model.train() 458 | fit_start = time.time() 459 | rng, fit_rng = jax.random.split(rng) 460 | loss = jit_fit(policy, optimizer, y, U, U_guess, fit_rng) 461 | loss.block_until_ready() 462 | fit_time = time.time() - fit_start 463 | 464 | # TODO: run some evaluation tests 465 | 466 | # Save a policy checkpoint 467 | if i % checkpoint_every == 0 and i > 0: 468 | ckpt_path = log_dir / f"policy_ckpt_{i}.pkl" 469 | policy.save(ckpt_path) 470 | print(f"Saved policy checkpoint to {ckpt_path}") 471 | 472 | # Print a performance summary 473 | time_elapsed = datetime.now() - train_start 474 | print( 475 | f" {i+1}/{num_iters} |" 476 | f" policy cost {J_policy:.4f} |" 477 | f" spc cost {J_spc:.4f} |" 478 | f" {100 * frac:.2f}% policy is best |" 479 | f" loss {loss:.4f} |" 480 | f" {time_elapsed} elapsed" 481 | ) 482 | 483 | # Tensorboard logging 484 | tb_writer.add_scalar("sim/policy_cost", J_policy, i) 485 | tb_writer.add_scalar("sim/spc_cost", J_spc, i) 486 | tb_writer.add_scalar("sim/time", sim_time, i) 487 | tb_writer.add_scalar("sim/policy_best_frac", frac, i) 488 | tb_writer.add_scalar("fit/loss", loss, i) 489 | tb_writer.add_scalar("fit/time", fit_time, i) 490 | if num_videos > 0: 491 | tb_writer.add_scalar("render/time", render_time, i) 492 | tb_writer.add_video( 493 | "render/trajectories", video_frames, i, fps=video_fps 494 | ) 495 | tb_writer.flush() 496 | 497 | return 
policy 498 | --------------------------------------------------------------------------------