├── img
│   ├── crane.png
│   ├── pusht.png
│   ├── walker.png
│   ├── humanoid.png
│   ├── particle.png
│   ├── pendulum.png
│   ├── summary.png
│   ├── cart_pole.png
│   └── double_cart_pole.png
├── .gitignore
├── environment.yml
├── gpc
│   ├── __init__.py
│   ├── envs
│   │   ├── __init__.py
│   │   ├── pendulum.py
│   │   ├── cart_pole.py
│   │   ├── particle.py
│   │   ├── walker.py
│   │   ├── double_cart_pole.py
│   │   ├── humanoid.py
│   │   ├── pusht.py
│   │   ├── crane.py
│   │   └── base.py
│   ├── testing.py
│   ├── augmented.py
│   ├── sampling.py
│   ├── policy.py
│   ├── architectures.py
│   └── training.py
├── .pre-commit-config.yaml
├── LICENSE
├── tests
│   ├── test_augmented_ctrl.py
│   ├── test_env.py
│   ├── test_architectures.py
│   └── test_training.py
├── pyproject.toml
├── examples
│   ├── cart_pole.py
│   ├── pendulum.py
│   ├── particle.py
│   ├── double_cart_pole.py
│   ├── walker.py
│   ├── pusht.py
│   ├── crane.py
│   └── humanoid.py
└── README.md
/img/crane.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vincekurtz/gpc/HEAD/img/crane.png
--------------------------------------------------------------------------------
/img/pusht.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vincekurtz/gpc/HEAD/img/pusht.png
--------------------------------------------------------------------------------
/img/walker.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vincekurtz/gpc/HEAD/img/walker.png
--------------------------------------------------------------------------------
/img/humanoid.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vincekurtz/gpc/HEAD/img/humanoid.png
--------------------------------------------------------------------------------
/img/particle.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vincekurtz/gpc/HEAD/img/particle.png
--------------------------------------------------------------------------------
/img/pendulum.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vincekurtz/gpc/HEAD/img/pendulum.png
--------------------------------------------------------------------------------
/img/summary.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vincekurtz/gpc/HEAD/img/summary.png
--------------------------------------------------------------------------------
/img/cart_pole.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vincekurtz/gpc/HEAD/img/cart_pole.png
--------------------------------------------------------------------------------
/img/double_cart_pole.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vincekurtz/gpc/HEAD/img/double_cart_pole.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | __pycache__
3 | *.zip
4 | .vscode
5 | *.pt
6 | *.pkl
7 | *.egg-info
8 | MUJOCO_LOG.TXT
9 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: gpc
2 | channels:
3 | - nvidia/label/cuda-12.3.0
4 | - conda-forge
5 | dependencies:
6 | - cuda
7 | - cudnn
8 | - python=3.12
9 |
--------------------------------------------------------------------------------
/gpc/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | # Set XLA flags for better performance
4 | os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"
5 | os.environ["XLA_FLAGS"] = "--xla_gpu_triton_gemm_any=true "
6 |
--------------------------------------------------------------------------------
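Note: the flags set in gpc/__init__.py only take effect if they are in the environment before JAX initializes its XLA backend. A minimal usage sketch (illustrative, not part of the repo):

    import gpc  # noqa: F401 -- sets XLA_PYTHON_CLIENT_MEM_FRACTION and XLA_FLAGS
    import jax  # the XLA backend picks up these settings on first use

    print(jax.devices())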
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_language_version:
2 | python: python3
3 | repos:
4 | - repo: https://github.com/pre-commit/pre-commit-hooks
5 | rev: v5.0.0
6 | hooks:
7 | - id: trailing-whitespace
8 | - id: end-of-file-fixer
9 | - repo: https://github.com/charliermarsh/ruff-pre-commit
10 | rev: v0.7.1
11 | hooks:
12 | - id: ruff
13 | types_or: [ python, pyi, jupyter ]
14 | args: [ --fix ]
15 | - id: ruff-format
16 | types_or: [ python, pyi, jupyter ]
17 |
--------------------------------------------------------------------------------
/gpc/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import SimulatorState, TrainingEnv
2 | from .cart_pole import CartPoleEnv
3 | from .crane import CraneEnv
4 | from .double_cart_pole import DoubleCartPoleEnv
5 | from .humanoid import HumanoidEnv
6 | from .particle import ParticleEnv
7 | from .pendulum import PendulumEnv
8 | from .pusht import PushTEnv
9 | from .walker import WalkerEnv
10 |
11 | __all__ = [
12 | "SimulatorState",
13 | "TrainingEnv",
14 | "CartPoleEnv",
15 | "CraneEnv",
16 | "DoubleCartPoleEnv",
17 | "ParticleEnv",
18 | "PendulumEnv",
19 | "PushTEnv",
20 | "WalkerEnv",
21 | "HumanoidEnv",
22 | ]
23 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Vince Kurtz
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/gpc/envs/pendulum.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from hydrax.tasks.pendulum import Pendulum
4 | from mujoco import mjx
5 |
6 | from gpc.envs import TrainingEnv
7 |
8 |
9 | class PendulumEnv(TrainingEnv):
10 | """Training environment for the pendulum swingup task."""
11 |
12 | def __init__(self, episode_length: int) -> None:
13 | """Set up the pendulum training environment."""
14 | super().__init__(
15 | task=Pendulum(planning_horizon=5),
16 | episode_length=episode_length,
17 | )
18 |
19 | def reset(self, data: mjx.Data, rng: jax.Array) -> mjx.Data:
20 | """Reset the simulator to start a new episode."""
21 | rng, pos_rng, vel_rng = jax.random.split(rng, 3)
22 | qpos = jax.random.uniform(pos_rng, (1,), minval=-jnp.pi, maxval=jnp.pi)
23 | qvel = jax.random.uniform(vel_rng, (1,), minval=-8.0, maxval=8.0)
24 | return data.replace(qpos=qpos, qvel=qvel)
25 |
26 | def get_obs(self, data: mjx.Data) -> jax.Array:
27 | """Observe the velocity and sin/cos of the angle."""
28 | theta = data.qpos[0]
29 | theta_dot = data.qvel[0]
30 | return jnp.array([jnp.cos(theta), jnp.sin(theta), theta_dot])
31 |
32 | @property
33 | def observation_size(self) -> int:
34 | """The size of the observation space (sin, cos, theta_dot)."""
35 | return 3
36 |
--------------------------------------------------------------------------------
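A minimal sketch of exercising the environment interface above by hand (illustrative only; the training loop in gpc/training.py and the tests drive these methods automatically):

    import jax
    from mujoco import mjx

    from gpc.envs import PendulumEnv

    env = PendulumEnv(episode_length=200)
    data = mjx.make_data(env.task.model)       # blank MJX simulator state
    data = env.reset(data, jax.random.key(0))  # randomize angle and velocity
    obs = env.get_obs(data)                    # [cos(theta), sin(theta), theta_dot]
    assert obs.shape == (env.observation_size,)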
/gpc/envs/cart_pole.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from hydrax.tasks.cart_pole import CartPole
4 | from mujoco import mjx
5 |
6 | from gpc.envs import TrainingEnv
7 |
8 |
9 | class CartPoleEnv(TrainingEnv):
10 | """Training environment for the cartpole swingup task."""
11 |
12 | def __init__(self, episode_length: int) -> None:
13 | """Set up the cartpole training environment."""
14 | super().__init__(task=CartPole(), episode_length=episode_length)
15 |
16 | def reset(self, data: mjx.Data, rng: jax.Array) -> mjx.Data:
17 | """Reset the simulator to start a new episode."""
18 | rng, theta_rng, pos_rng, vel_rng = jax.random.split(rng, 4)
19 |
20 | theta = jax.random.uniform(theta_rng, (), minval=-3.14, maxval=3.14)
21 | pos = jax.random.uniform(pos_rng, (), minval=-1.8, maxval=1.8)
22 | qvel = jax.random.uniform(vel_rng, (2,), minval=-2.0, maxval=2.0)
23 | qpos = jnp.array([pos, theta])
24 |
25 | return data.replace(qpos=qpos, qvel=qvel)
26 |
27 | def get_obs(self, data: mjx.Data) -> jax.Array:
28 | """Observe the velocity and sin/cos of the angle."""
29 | p = data.qpos[0]
30 | theta = data.qpos[1]
31 | v = data.qvel[0]
32 | theta_dot = data.qvel[1]
33 | return jnp.array([p, jnp.cos(theta), jnp.sin(theta), v, theta_dot])
34 |
35 | @property
36 | def observation_size(self) -> int:
37 | """The size of the observation space (includes sin and cos)."""
38 | return 5
39 |
--------------------------------------------------------------------------------
/gpc/envs/particle.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from hydrax.tasks.particle import Particle
4 | from mujoco import mjx
5 |
6 | from gpc.envs import TrainingEnv
7 |
8 |
9 | class ParticleEnv(TrainingEnv):
10 | """Training environment for the particle task."""
11 |
12 | def __init__(self, episode_length: int = 100) -> None:
13 | """Set up the particle training environment."""
14 | super().__init__(task=Particle(), episode_length=episode_length)
15 |
16 | def reset(self, data: mjx.Data, rng: jax.Array) -> mjx.Data:
17 | """Reset the simulator to start a new episode."""
18 | rng, pos_rng, vel_rng, mocap_rng = jax.random.split(rng, 4)
19 | qpos = jax.random.uniform(pos_rng, (2,), minval=-0.29, maxval=0.29)
20 | qvel = jax.random.uniform(vel_rng, (2,), minval=-0.5, maxval=0.5)
21 | target = jax.random.uniform(mocap_rng, (2,), minval=-0.29, maxval=0.29)
22 | mocap_pos = data.mocap_pos.at[0, 0:2].set(target)
23 | return data.replace(qpos=qpos, qvel=qvel, mocap_pos=mocap_pos)
24 |
25 | def get_obs(self, data: mjx.Data) -> jax.Array:
26 | """Observe the position relative to the target and the velocity."""
27 | pos = (
28 | data.site_xpos[self.task.pointmass_id, 0:2] - data.mocap_pos[0, 0:2]
29 | )
30 | vel = data.qvel[:]
31 | return jnp.concatenate([pos, vel])
32 |
33 | @property
34 | def observation_size(self) -> int:
35 | """The size of the observation space."""
36 | return 4
37 |
--------------------------------------------------------------------------------
/tests/test_augmented_ctrl.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from hydrax.algs import PredictiveSampling
4 | from hydrax.tasks.particle import Particle
5 | from mujoco import mjx
6 |
7 | from gpc.augmented import PolicyAugmentedController
8 |
9 |
10 | def test_augmented() -> None:
11 | """Test the prediction-augmented controller."""
12 | # Task and optimizer setup
13 | task = Particle()
14 | ps = PredictiveSampling(task, num_samples=32, noise_level=0.1)
15 | opt = PolicyAugmentedController(ps, num_policy_samples=32)
16 | jit_opt = jax.jit(opt.optimize)
17 |
18 | # Initialize the system state and policy parameters
19 | state = mjx.make_data(task.model)
20 | state = state.replace(
21 | mocap_pos=state.mocap_pos.at[0, 0:2].set(jnp.array([0.01, 0.01]))
22 | )
23 | params = opt.init_params()
24 | params = params.replace(
25 | policy_samples=jnp.ones((32, task.planning_horizon, task.model.nu))
26 | )
27 |
28 | for _ in range(10):
29 | # Do an optimization step
30 | params, rollouts = jit_opt(state, params)
31 |
32 | # Pick the best rollout
33 | total_costs = jnp.sum(rollouts.costs, axis=1)
34 | best_idx = jnp.argmin(total_costs)
35 | best_ctrl = rollouts.controls[best_idx]
36 |
37 | assert jnp.all(best_ctrl != 0.0)
38 | assert jnp.all(params.policy_samples == 1.0)
39 |
40 | U = opt.get_action_sequence(params)
41 | assert jnp.allclose(U, params.base_params.mean)
42 |
43 |
44 | if __name__ == "__main__":
45 | test_augmented()
46 |
--------------------------------------------------------------------------------
/gpc/envs/walker.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from hydrax.tasks.walker import Walker
4 | from mujoco import mjx
5 |
6 | from gpc.envs import TrainingEnv
7 |
8 |
9 | class WalkerEnv(TrainingEnv):
10 | """Training environment for the walker task."""
11 |
12 | def __init__(self, episode_length: int) -> None:
13 | """Set up the walker training environment."""
14 | super().__init__(task=Walker(), episode_length=episode_length)
15 |
16 | def reset(self, data: mjx.Data, rng: jax.Array) -> mjx.Data:
17 | """Reset the simulator to start a new episode."""
18 | rng, pos_rng, vel_rng = jax.random.split(rng, 3)
19 |
20 | # Joint limits are zero for the floating base
21 | q_min = self.task.model.jnt_range[:, 0]
22 | q_max = self.task.model.jnt_range[:, 1]
23 | q_min = q_min.at[2].set(-1.5) # orientation
24 | q_max = q_max.at[2].set(1.5)
25 | qpos = jax.random.uniform(pos_rng, (9,), minval=q_min, maxval=q_max)
26 | qvel = jax.random.uniform(vel_rng, (9,), minval=-0.1, maxval=0.1)
27 |
28 | return data.replace(qpos=qpos, qvel=qvel)
29 |
30 | def get_obs(self, data: mjx.Data) -> jax.Array:
31 | """Observe everything in the state except the horizontal position."""
32 | pz = data.qpos[0] # base coordinates are (z, x, theta)
33 | theta = data.qpos[2]
34 | base_pos_data = jnp.array([jnp.cos(theta), jnp.sin(theta), pz])
35 | return jnp.concatenate([base_pos_data, data.qpos[3:], data.qvel])
36 |
37 | @property
38 | def observation_size(self) -> int:
39 | """The size of the observation space."""
40 | return 18
41 |
--------------------------------------------------------------------------------
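The observation size of 18 above follows directly from the layout in get_obs; a quick sanity check (illustrative only):

    # [cos(theta), sin(theta), pz] -> 3 entries
    # qpos[3:] (six leg joints)    -> 6 entries
    # qvel                         -> 9 entries
    assert 3 + 6 + 9 == 18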
/tests/test_env.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import jax
4 | import jax.numpy as jnp
5 |
6 | from gpc.envs import ParticleEnv, SimulatorState
7 |
8 |
9 | def test_particle_env() -> None:
10 | """Test the particle environment."""
11 | rng = jax.random.key(0)
12 | env = ParticleEnv()
13 |
14 | state = env.init_state(rng)
15 | assert state.t == 0
16 |
17 | state2 = env._reset_state(state)
18 | assert jnp.all(state.data.qpos != state2.data.qpos)
19 | assert jnp.all(state.data.qvel != state2.data.qvel)
20 | assert jnp.all(state.rng != state2.rng)
21 |
22 | obs = env._get_observation(state)
23 | assert obs.shape == (4,)
24 |
25 | jit_step = jax.jit(env.step)
26 |
27 | state = jit_step(state, jnp.zeros(2))
28 | assert state.t == 1
29 |
30 | state = state.replace(t=100)
31 | assert env.episode_over(state)
32 | state = jit_step(state, jnp.zeros(2))
33 | assert state.t == 0
34 |
35 |
36 | def test_render() -> None:
37 | """Test rendering the particle environment."""
38 | rng = jax.random.key(0)
39 | env = ParticleEnv()
40 |
41 | rng, init_rng = jax.random.split(rng)
42 | state = env.init_state(init_rng)
43 |
44 | def _step(state: SimulatorState, action: jax.Array) -> Tuple:
45 | state = env.step(state, action)
46 | return state, state
47 |
48 | num_steps = 100
49 | rng, act_rng = jax.random.split(rng)
50 | actions = jax.random.normal(act_rng, (num_steps, 2))
51 | _, states = jax.lax.scan(_step, state, actions)
52 |
53 | frames = env.render(states, fps=10)
54 | assert frames.shape == (10, 3, 240, 320)
55 |
56 |
57 | if __name__ == "__main__":
58 | test_particle_env()
59 | test_render()
60 |
--------------------------------------------------------------------------------
/gpc/envs/double_cart_pole.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from hydrax.tasks.double_cart_pole import DoubleCartPole
4 | from mujoco import mjx
5 |
6 | from gpc.envs import TrainingEnv
7 |
8 |
9 | class DoubleCartPoleEnv(TrainingEnv):
10 | """Training environment for the double cart-pole swingup task."""
11 |
12 | def __init__(self, episode_length: int) -> None:
13 | """Set up the training environment."""
14 | super().__init__(task=DoubleCartPole(), episode_length=episode_length)
15 |
16 | def reset(self, data: mjx.Data, rng: jax.Array) -> mjx.Data:
17 | """Reset the simulator to start a new episode."""
18 | rng, theta_rng, pos_rng, vel_rng = jax.random.split(rng, 4)
19 |
20 | thetas = jax.random.uniform(theta_rng, (2,), minval=-3.14, maxval=3.14)
21 | pos = jax.random.uniform(pos_rng, (), minval=-2.8, maxval=2.8)
22 | qvel = jax.random.uniform(vel_rng, (3,), minval=-10.0, maxval=10.0)
23 | qpos = jnp.array([pos, thetas[0], thetas[1]])
24 |
25 | return data.replace(qpos=qpos, qvel=qvel)
26 |
27 | def get_obs(self, data: mjx.Data) -> jax.Array:
28 | """Observe the velocity and sin/cos of the angles."""
29 | p = data.qpos[0]
30 | theta1 = data.qpos[1]
31 | theta2 = data.qpos[2]
32 | q_obs = jnp.array(
33 | [
34 | p,
35 | jnp.cos(theta1),
36 | jnp.sin(theta1),
37 | jnp.cos(theta2),
38 | jnp.sin(theta2),
39 | ]
40 | )
41 | return jnp.concatenate([q_obs, data.qvel])
42 |
43 | @property
44 | def observation_size(self) -> int:
45 | """The size of the observation space (includes sin and cos)."""
46 | return 8
47 |
--------------------------------------------------------------------------------
/gpc/envs/humanoid.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from hydrax.tasks.humanoid_standup import HumanoidStandup
4 | from mujoco import mjx
5 |
6 | from gpc.envs import TrainingEnv
7 |
8 |
9 | class HumanoidEnv(TrainingEnv):
10 | """Training environment for humanoid (Unitree G1) standup."""
11 |
12 | def __init__(self, episode_length: int) -> None:
13 | """Set up the walker training environment."""
14 | super().__init__(task=HumanoidStandup(), episode_length=episode_length)
15 |
16 | def reset(self, data: mjx.Data, rng: jax.Array) -> mjx.Data:
17 | """Reset the simulator to start a new episode."""
18 | rng, pos_rng, vel_rng, ori_rng = jax.random.split(rng, 4)
19 |
20 | # Random positions and velocities
21 | qpos = self.task.qstand + 0.1 * jax.random.normal(
22 | pos_rng, (self.task.model.nq,)
23 | )
24 | qvel = 0.1 * jax.random.normal(vel_rng, (self.task.model.nv,))
25 |
26 | # Random base orientation
27 | u, v, w = jax.random.uniform(ori_rng, (3,))
28 | quat = jnp.array(
29 | [
30 | jnp.sqrt(1 - u) * jnp.sin(2 * jnp.pi * v),
31 | jnp.sqrt(1 - u) * jnp.cos(2 * jnp.pi * v),
32 | jnp.sqrt(u) * jnp.sin(2 * jnp.pi * w),
33 | jnp.sqrt(u) * jnp.cos(2 * jnp.pi * w),
34 | ]
35 | )
36 | qpos = qpos.at[3:7].set(quat)
37 |
38 | return data.replace(qpos=qpos, qvel=qvel)
39 |
40 | def get_obs(self, data: mjx.Data) -> jax.Array:
41 | """Observe the full state, regularized to be agnostic to orientation."""
42 | height = self.task._get_torso_height(data)[None]
43 | orientation = self.task._get_torso_orientation(data) # upright rotation
44 | return jnp.concatenate([height, orientation, data.qpos[7:], data.qvel])
45 |
46 | @property
47 | def observation_size(self) -> int:
48 | """The size of the observations."""
49 | return 56
50 |
--------------------------------------------------------------------------------
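The base-orientation reset above uses the standard construction for a uniformly random rotation quaternion, which is unit-norm by construction since (1 - u) + u = 1. A standalone check (illustrative, not part of the repo):

    import jax
    import jax.numpy as jnp

    u, v, w = jax.random.uniform(jax.random.key(0), (3,))
    quat = jnp.array(
        [
            jnp.sqrt(1 - u) * jnp.sin(2 * jnp.pi * v),
            jnp.sqrt(1 - u) * jnp.cos(2 * jnp.pi * v),
            jnp.sqrt(u) * jnp.sin(2 * jnp.pi * w),
            jnp.sqrt(u) * jnp.cos(2 * jnp.pi * w),
        ]
    )
    assert jnp.allclose(jnp.linalg.norm(quat), 1.0)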
/gpc/envs/pusht.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from hydrax.tasks.pusht import PushT
4 | from mujoco import mjx
5 |
6 | from gpc.envs import TrainingEnv
7 |
8 |
9 | class PushTEnv(TrainingEnv):
10 | """Training environment for the pusher-T task."""
11 |
12 | def __init__(self, episode_length: int) -> None:
13 | """Set up the walker training environment."""
14 | super().__init__(task=PushT(), episode_length=episode_length)
15 |
16 | def reset(self, data: mjx.Data, rng: jax.Array) -> mjx.Data:
17 | """Reset the simulator to start a new episode."""
18 | rng, pos_rng, vel_rng, goal_pos_rng, goal_ori_rng = jax.random.split(
19 | rng, 5
20 | )
21 |
22 | # Random configuration for the pusher and the T
23 | q_min = jnp.array([-0.2, -0.2, -jnp.pi, -0.2, -0.2])
24 | q_max = jnp.array([0.2, 0.2, jnp.pi, 0.2, 0.2])
25 | qpos = jax.random.uniform(pos_rng, (5,), minval=q_min, maxval=q_max)
26 |
27 | # Velocities fixed at zero
28 | qvel = jax.random.uniform(vel_rng, (5,), minval=-0.0, maxval=0.0)
29 |
30 | # Goal position and orientation fixed at zero
31 | goal = jax.random.uniform(goal_pos_rng, (2,), minval=-0.0, maxval=0.0)
32 | mocap_pos = data.mocap_pos.at[0, 0:2].set(goal)
33 | theta = jax.random.uniform(goal_ori_rng, (), minval=0.0, maxval=0.0)
34 | mocap_quat = jnp.array([[jnp.cos(theta / 2), 0, 0, jnp.sin(theta / 2)]])
35 |
36 | return data.replace(
37 | qpos=qpos, qvel=qvel, mocap_pos=mocap_pos, mocap_quat=mocap_quat
38 | )
39 |
40 | def get_obs(self, data: mjx.Data) -> jax.Array:
41 | """Observe positions relative to the target."""
42 | pusher_pos = data.qpos[-2:]
43 | block_pos = data.qpos[0:2]
44 | block_ori = self.task._get_orientation_err(data)[0:1]
45 | return jnp.concatenate([pusher_pos, block_pos, block_ori])
46 |
47 | @property
48 | def observation_size(self) -> int:
49 | """The size of the observation space."""
50 | return 5
51 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "gpc"
7 | version = "0.0.1"
8 | description = "Generative Predictive Control"
9 | readme = "README.md"
10 | license = {text="MIT"}
11 | requires-python = ">=3.12.0"
12 | classifiers = [
13 | "Development Status :: 3 - Alpha",
14 | "Programming Language :: Python",
15 | ]
16 | dependencies = [
17 | "hydrax@git+https://github.com/vincekurtz/hydrax@v0.0.1",
18 | "jax[cuda12]==0.4.35",
19 | "flax==0.10.0",
20 | "pytest==8.3.3",
21 | "ruff==0.7.1",
22 | "pre-commit==4.0.1",
23 | "matplotlib==3.9.2",
24 | "mujoco==3.2.4",
25 | "mujoco-mjx==3.2.4",
26 | "evosax==0.1.6",
27 | "optax==0.2.3",
28 | "tensorboard==2.18.0",
29 | "tensorboardX==2.6.2",
30 | "cloudpickle==3.1.0",
31 | "moviepy==1.0.3",
32 | "imageio==2.27",
33 | ]
34 |
35 | [tool.ruff]
36 | line-length = 80
37 |
38 | [tool.ruff.lint]
39 | pydocstyle.convention = "google"
40 | select = [
41 | "ANN", # annotations
42 | "N", # naming conventions
43 | "D", # docstrings
44 | "B", # flake8 bugbear
45 | "E", # pycodestyle errors
46 | "F", # Pyflakes rules
47 | "I", # isort formatting
48 | "PLC", # Pylint convention warnings
49 | "PLE", # Pylint errors
50 | "PLR", # Pylint refactor recommendations
51 | "PLW", # Pylint warnings
52 | ]
53 | ignore = [
54 | "ANN003", # missing type annotation for **kwargs
55 | "ANN101", # missing type annotation for `self` in method
56 | "ANN202", # missing return type annotation for private function
57 | "ANN204", # missing return type annotation for `__init__`
58 | "ANN401", # dynamically typed expressions (typing.Any) are disallowed
59 | "D100", # missing docstring in public module
60 | "D104", # missing docstring in public package
61 | "D203", # blank line before class docstring
62 | "D211", # no blank line before class
63 | "D212", # multi-line docstring summary at first line
64 | "D213", # multi-line docstring summary at second line
65 | "E731", # assigning to a `lambda` expression
66 | "N806", # only lowercase variables in functions
67 | "PLR0913", # too many arguments
68 | "PLR2004", # magic value used in comparison
69 | ]
70 |
71 | [tool.ruff.lint.isort]
72 | combine-as-imports = true
73 | known-first-party = ["hydra"]
74 | split-on-trailing-comma = false
75 |
76 | [tool.ruff.format]
77 | quote-style = "double"
78 | indent-style = "space"
79 | skip-magic-trailing-comma = false
80 | line-ending = "auto"
81 | docstring-code-format = true
82 | docstring-code-line-length = "dynamic"
83 |
84 | [tool.pytest.ini_options]
85 | testpaths = [
86 | "tests",
87 | ]
88 |
89 | [tool.setuptools]
90 | packages = ["gpc"]
91 |
--------------------------------------------------------------------------------
/examples/cart_pole.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import mujoco
4 | from flax import nnx
5 | from hydrax.algs import PredictiveSampling
6 | from hydrax.simulation.deterministic import run_interactive as run_sampling
7 |
8 | from gpc.architectures import DenoisingMLP
9 | from gpc.envs import CartPoleEnv
10 | from gpc.policy import Policy
11 | from gpc.sampling import BootstrappedPredictiveSampling
12 | from gpc.testing import test_interactive
13 | from gpc.training import train
14 |
15 | if __name__ == "__main__":
16 | # Parse command-line arguments
17 | parser = argparse.ArgumentParser(
18 | description="Balance an inverted pendulum on a cart"
19 | )
20 | subparsers = parser.add_subparsers(
21 | dest="task", help="What to do (choose one)"
22 | )
23 | subparsers.add_parser("train", help="Train (and save) a generative policy")
24 | subparsers.add_parser("test", help="Test a generative policy")
25 | subparsers.add_parser(
26 | "sample", help="Bootstrap sampling-based MPC with a generative policy"
27 | )
28 | args = parser.parse_args()
29 |
30 | # Set up the environment and save file
31 | env = CartPoleEnv(episode_length=200)
32 | save_file = "/tmp/cart_pole_policy.pkl"
33 |
34 | if args.task == "train":
35 | # Train the policy and save it to a file
36 | ctrl = PredictiveSampling(env.task, num_samples=8, noise_level=0.1)
37 | net = DenoisingMLP(
38 | action_size=env.task.model.nu,
39 | observation_size=env.observation_size,
40 | horizon=env.task.planning_horizon,
41 | hidden_layers=[64, 64],
42 | rngs=nnx.Rngs(0),
43 | )
44 | policy = train(
45 | env,
46 | ctrl,
47 | net,
48 | num_policy_samples=2,
49 | log_dir="/tmp/gpc_cart_pole",
50 | num_iters=10,
51 | num_envs=128,
52 | num_epochs=100,
53 | )
54 | policy.save(save_file)
55 | print(f"Saved policy to {save_file}")
56 |
57 | elif args.task == "test":
58 | # Load the policy from a file and test it interactively
59 | print(f"Loading policy from {save_file}")
60 | policy = Policy.load(save_file)
61 | test_interactive(env, policy)
62 |
63 | elif args.task == "sample":
64 | # Use the policy to bootstrap sampling-based MPC
65 | policy = Policy.load(save_file)
66 | ctrl = BootstrappedPredictiveSampling(
67 | policy,
68 | observation_fn=env.get_obs,
69 | num_policy_samples=2,
70 | task=env.task,
71 | num_samples=1,
72 | noise_level=0.1,
73 | )
74 | mj_model = env.task.mj_model
75 | mj_data = mujoco.MjData(mj_model)
76 | run_sampling(ctrl, mj_model, mj_data, frequency=50)
77 |
78 | else:
79 | parser.print_help()
80 |
--------------------------------------------------------------------------------
/examples/pendulum.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import mujoco
4 | from flax import nnx
5 | from hydrax.algs import PredictiveSampling
6 | from hydrax.simulation.deterministic import run_interactive as run_sampling
7 |
8 | from gpc.architectures import DenoisingMLP
9 | from gpc.envs import PendulumEnv
10 | from gpc.policy import Policy
11 | from gpc.sampling import BootstrappedPredictiveSampling
12 | from gpc.testing import test_interactive
13 | from gpc.training import train
14 |
15 | if __name__ == "__main__":
16 | # Parse command line arguments
17 | parser = argparse.ArgumentParser(
18 | description="Swing up an inverted pendulum"
19 | )
20 | subparsers = parser.add_subparsers(
21 | dest="task", help="What to do (choose one)"
22 | )
23 | subparsers.add_parser("train", help="Train (and save) a generative policy")
24 | subparsers.add_parser("test", help="Test a generative policy")
25 | subparsers.add_parser(
26 | "sample", help="Bootstrap sampling-based MPC with a generative policy"
27 | )
28 | args = parser.parse_args()
29 |
30 | # Set up the environment and save file
31 | env = PendulumEnv(episode_length=200)
32 | save_file = "/tmp/pendulum_policy.pkl"
33 |
34 | if args.task == "train":
35 | # Train the policy and save it to a file
36 | ctrl = PredictiveSampling(env.task, num_samples=8, noise_level=0.1)
37 | net = DenoisingMLP(
38 | action_size=env.task.model.nu,
39 | observation_size=env.observation_size,
40 | horizon=env.task.planning_horizon,
41 | hidden_layers=[32, 32],
42 | rngs=nnx.Rngs(0),
43 | )
44 | policy = train(
45 | env,
46 | ctrl,
47 | net,
48 | num_policy_samples=2,
49 | log_dir="/tmp/gpc_pendulum",
50 | num_epochs=10,
51 | num_iters=10,
52 | num_envs=128,
53 | num_videos=2,
54 | strategy="policy",
55 | )
56 | policy.save(save_file)
57 | print(f"Saved policy to {save_file}")
58 |
59 | elif args.task == "test":
60 | # Load the policy from a file and test it interactively
61 | print(f"Loading policy from {save_file}")
62 | policy = Policy.load(save_file)
63 | test_interactive(env, policy)
64 |
65 | elif args.task == "sample":
66 | # Use the policy to bootstrap sampling-based MPC
67 | policy = Policy.load(save_file)
68 | ctrl = BootstrappedPredictiveSampling(
69 | policy,
70 | env.get_obs,
71 | num_policy_samples=4,
72 | task=env.task,
73 | num_samples=4,
74 | noise_level=0.1,
75 | )
76 |
77 | mj_model = env.task.mj_model
78 | mj_data = mujoco.MjData(mj_model)
79 | run_sampling(ctrl, mj_model, mj_data, frequency=50)
80 |
81 | else:
82 | parser.print_help()
83 |
--------------------------------------------------------------------------------
/examples/particle.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import mujoco
4 | from flax import nnx
5 | from hydrax.algs import PredictiveSampling
6 | from hydrax.simulation.deterministic import run_interactive as run_sampling
7 |
8 | from gpc.architectures import DenoisingMLP
9 | from gpc.envs import ParticleEnv
10 | from gpc.policy import Policy
11 | from gpc.sampling import BootstrappedPredictiveSampling
12 | from gpc.testing import test_interactive
13 | from gpc.training import train
14 |
15 | if __name__ == "__main__":
16 | # Parse command line arguments
17 | parser = argparse.ArgumentParser(
18 | description="Drive a point mass to a target position"
19 | )
20 | subparsers = parser.add_subparsers(
21 | dest="task", help="What to do (choose one)"
22 | )
23 | subparsers.add_parser("train", help="Train (and save) a generative policy")
24 | subparsers.add_parser("test", help="Test a generative policy")
25 | subparsers.add_parser(
26 | "sample", help="Bootstrap sampling-based MPC with a generative policy"
27 | )
28 | args = parser.parse_args()
29 |
30 | # Set up the environment and save file
31 | env = ParticleEnv(episode_length=100)
32 | save_file = "/tmp/particle_policy.pkl"
33 |
34 | if args.task == "train":
35 | # Train the policy and save it to a file
36 | ctrl = PredictiveSampling(env.task, num_samples=8, noise_level=0.1)
37 | net = DenoisingMLP(
38 | action_size=env.task.model.nu,
39 | observation_size=env.observation_size,
40 | horizon=env.task.planning_horizon,
41 | hidden_layers=[32, 32],
42 | rngs=nnx.Rngs(0),
43 | )
44 | policy = train(
45 | env,
46 | ctrl,
47 | net,
48 | num_policy_samples=8,
49 | log_dir="/tmp/gpc_particle",
50 | num_iters=10,
51 | num_envs=128,
52 | batch_size=128,
53 | num_epochs=100,
54 | )
55 | policy.save(save_file)
56 | print(f"Saved policy to {save_file}")
57 |
58 | elif args.task == "test":
59 | # Load the policy from a file and test it interactively
60 | print(f"Loading policy from {save_file}")
61 | policy = Policy.load(save_file)
62 | test_interactive(env, policy)
63 |
64 | elif args.task == "sample":
65 | # Use the policy to bootstrap sampling-based MPC
66 | policy = Policy.load(save_file)
67 | ctrl = BootstrappedPredictiveSampling(
68 | policy,
69 | env.get_obs,
70 | inference_timestep=0.1,
71 | num_policy_samples=4,
72 | task=env.task,
73 | num_samples=1,
74 | noise_level=0.1,
75 | )
76 | mj_model = env.task.mj_model
77 | mj_data = mujoco.MjData(mj_model)
78 | run_sampling(ctrl, mj_model, mj_data, frequency=50)
79 |
80 | else:
81 | parser.print_help()
82 |
--------------------------------------------------------------------------------
/examples/double_cart_pole.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import mujoco
4 | from flax import nnx
5 | from hydrax.algs import PredictiveSampling
6 | from hydrax.simulation.deterministic import run_interactive as run_sampling
7 |
8 | from gpc.architectures import DenoisingMLP
9 | from gpc.envs import DoubleCartPoleEnv
10 | from gpc.policy import Policy
11 | from gpc.sampling import BootstrappedPredictiveSampling
12 | from gpc.testing import test_interactive
13 | from gpc.training import train
14 |
15 | if __name__ == "__main__":
16 | # Parse command line arguments
17 | parser = argparse.ArgumentParser(
18 | description="Balance a double inverted pendulum on a cart"
19 | )
20 | subparsers = parser.add_subparsers(
21 | dest="task", help="What to do (choose one)"
22 | )
23 | subparsers.add_parser("train", help="Train (and save) a generative policy")
24 | subparsers.add_parser("test", help="Test a generative policy")
25 | subparsers.add_parser(
26 | "sample", help="Bootstrap sampling-based MPC with a generative policy"
27 | )
28 | args = parser.parse_args()
29 |
30 | # Set up the environment and save file
31 | env = DoubleCartPoleEnv(episode_length=400)
32 | save_file = "/tmp/double_cart_pole_policy.pkl"
33 |
34 | if args.task == "train":
35 | # Train the policy and save it to a file
36 | ctrl = PredictiveSampling(env.task, num_samples=16, noise_level=0.3)
37 | net = DenoisingMLP(
38 | action_size=env.task.model.nu,
39 | observation_size=env.observation_size,
40 | horizon=env.task.planning_horizon,
41 | hidden_layers=[128, 128],
42 | rngs=nnx.Rngs(0),
43 | )
44 | policy = train(
45 | env,
46 | ctrl,
47 | net,
48 | num_policy_samples=16,
49 | log_dir="/tmp/gpc_double_cart_pole",
50 | num_iters=50,
51 | num_envs=256,
52 | num_epochs=100,
53 | checkpoint_every=5,
54 | num_videos=4,
55 | )
56 | policy.save(save_file)
57 | print(f"Saved policy to {save_file}")
58 |
59 | elif args.task == "test":
60 | # Load the policy from a file and test it interactively
61 | print(f"Loading policy from {save_file}")
62 | policy = Policy.load(save_file)
63 | test_interactive(env, policy, inference_timestep=0.1)
64 |
65 | elif args.task == "sample":
66 | # Use the policy to bootstrap sampling-based MPC
67 | policy = Policy.load(save_file)
68 | ctrl = BootstrappedPredictiveSampling(
69 | policy,
70 | env.get_obs,
71 | inference_timestep=0.01,
72 | num_policy_samples=4,
73 | task=env.task,
74 | num_samples=1,
75 | noise_level=0.3,
76 | )
77 | mj_model = env.task.mj_model
78 | mj_data = mujoco.MjData(mj_model)
79 | run_sampling(ctrl, mj_model, mj_data, frequency=50)
80 |
81 | else:
82 | parser.print_help()
83 |
--------------------------------------------------------------------------------
/examples/walker.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import mujoco
4 | from flax import nnx
5 | from hydrax.algs import PredictiveSampling
6 | from hydrax.simulation.deterministic import run_interactive as run_sampling
7 |
8 | from gpc.architectures import DenoisingCNN
9 | from gpc.envs import WalkerEnv
10 | from gpc.policy import Policy
11 | from gpc.sampling import BootstrappedPredictiveSampling
12 | from gpc.testing import test_interactive
13 | from gpc.training import train
14 |
15 | if __name__ == "__main__":
16 | # Parse command line arguments
17 | parser = argparse.ArgumentParser(description="Planar biped locomotion")
18 | subparsers = parser.add_subparsers(
19 | dest="task", help="What to do (choose one)"
20 | )
21 | subparsers.add_parser("train", help="Train (and save) a generative policy")
22 | subparsers.add_parser("test", help="Test a generative policy")
23 | subparsers.add_parser(
24 | "sample", help="Bootstrap sampling-based MPC with a generative policy"
25 | )
26 | args = parser.parse_args()
27 |
28 | # Set up the environment and save file
29 | env = WalkerEnv(episode_length=500)
30 | save_file = "/tmp/walker_policy.pkl"
31 |
32 | if args.task == "train":
33 | # Train the policy and save it to a file
34 | ctrl = PredictiveSampling(env.task, num_samples=16, noise_level=0.3)
35 | net = DenoisingCNN(
36 | action_size=env.task.model.nu,
37 | observation_size=env.observation_size,
38 | horizon=env.task.planning_horizon,
39 | feature_dims=[64, 64],
40 | timestep_embedding_dim=16,
41 | rngs=nnx.Rngs(0),
42 | )
43 | policy = train(
44 | env,
45 | ctrl,
46 | net,
47 | log_dir="/tmp/gpc_walker",
48 | num_policy_samples=16,
49 | num_iters=20,
50 | num_envs=128,
51 | num_epochs=10,
52 | )
53 | policy.save(save_file)
54 | print(f"Saved policy to {save_file}")
55 |
56 | elif args.task == "test":
57 | # Load the policy from a file and test it interactively
58 | print(f"Loading policy from {save_file}")
59 | policy = Policy.load(save_file)
60 | test_interactive(
61 | env, policy, inference_timestep=0.01, warm_start_level=1.0
62 | )
63 |
64 | elif args.task == "sample":
65 | # Use the policy to bootstrap sampling-based MPC
66 | policy = Policy.load(save_file)
67 | ctrl = BootstrappedPredictiveSampling(
68 | policy,
69 | env.get_obs,
70 | warm_start_level=0.5,
71 | inference_timestep=0.1,
72 | num_policy_samples=128,
73 | task=env.task,
74 | num_samples=1,
75 | noise_level=0.3,
76 | )
77 |
78 | mj_model = env.task.mj_model
79 | mj_data = mujoco.MjData(mj_model)
80 | run_sampling(ctrl, mj_model, mj_data, frequency=50, fixed_camera_id=0)
81 |
82 | else:
83 | parser.print_help()
84 |
--------------------------------------------------------------------------------
/examples/pusht.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import mujoco
4 | from flax import nnx
5 | from hydrax.algs import PredictiveSampling
6 | from hydrax.simulation.deterministic import run_interactive as run_sampling
7 |
8 | from gpc.architectures import DenoisingCNN
9 | from gpc.envs import PushTEnv
10 | from gpc.policy import Policy
11 | from gpc.sampling import BootstrappedPredictiveSampling
12 | from gpc.testing import test_interactive
13 | from gpc.training import train
14 |
15 | if __name__ == "__main__":
16 | # Parse command line arguments
17 | parser = argparse.ArgumentParser(
18 | description="Push a T-shaped block on a table"
19 | )
20 | subparsers = parser.add_subparsers(
21 | dest="task", help="What to do (choose one)"
22 | )
23 | subparsers.add_parser("train", help="Train (and save) a generative policy")
24 | subparsers.add_parser("test", help="Test a generative policy")
25 | subparsers.add_parser(
26 | "sample", help="Bootstrap sampling-based MPC with a generative policy"
27 | )
28 | args = parser.parse_args()
29 |
30 | # Set up the environment and save file
31 | env = PushTEnv(episode_length=400)
32 | save_file = "/tmp/pusht_policy.pkl"
33 |
34 | if args.task == "train":
35 | # Train the policy and save it to a file
36 | ctrl = PredictiveSampling(env.task, num_samples=128, noise_level=0.4)
37 | net = DenoisingCNN(
38 | action_size=env.task.model.nu,
39 | observation_size=env.observation_size,
40 | horizon=env.task.planning_horizon,
41 | feature_dims=[32, 32, 32],
42 | timestep_embedding_dim=8,
43 | rngs=nnx.Rngs(0),
44 | )
45 | policy = train(
46 | env,
47 | ctrl,
48 | net,
49 | num_policy_samples=32,
50 | log_dir="/tmp/gpc_pusht",
51 | num_iters=20,
52 | num_envs=128,
53 | num_epochs=10,
54 | checkpoint_every=5,
55 | )
56 | policy.save(save_file)
57 | print(f"Saved policy to {save_file}")
58 |
59 | elif args.task == "test":
60 | # Load the policy from a file and test it interactively
61 | print(f"Loading policy from {save_file}")
62 | policy = Policy.load(save_file)
63 | mj_data = mujoco.MjData(env.task.mj_model)
64 | mj_data.qpos[:] = [0.1, 0.1, 2.0, 0.0, 0.0] # set the initial state
65 | test_interactive(env, policy, mj_data)
66 |
67 | elif args.task == "sample":
68 | # Use the policy to bootstrap sampling-based MPC
69 | policy = Policy.load(save_file)
70 | ctrl = BootstrappedPredictiveSampling(
71 | policy,
72 | env.get_obs,
73 | warm_start_level=0.5,
74 | num_policy_samples=32,
75 | task=env.task,
76 | num_samples=1,
77 | noise_level=0.1,
78 | )
79 | mj_model = env.task.mj_model
80 | mj_data = mujoco.MjData(mj_model)
81 | run_sampling(ctrl, mj_model, mj_data, frequency=50)
82 |
83 | else:
84 | parser.print_help()
85 |
--------------------------------------------------------------------------------
/gpc/envs/crane.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from hydrax.tasks.crane import Crane
4 | from mujoco import mjx
5 |
6 | from gpc.envs import TrainingEnv
7 |
8 |
9 | class CraneEnv(TrainingEnv):
10 | """Training environment for the luffing crane end-effetor tracking task."""
11 |
12 | def __init__(self, episode_length: int = 100) -> None:
13 | """Set up the particle training environment."""
14 | super().__init__(task=Crane(), episode_length=episode_length)
15 |
16 | def reset(self, data: mjx.Data, rng: jax.Array) -> mjx.Data:
17 | """Reset the simulator to start a new episode."""
18 | rng, pos_rng, vel_rng, target_rng = jax.random.split(rng, 4)
19 |
20 | # Crane state
21 | # TODO: figure out a more principled way to initialize these.
22 | # Right now the crane swings wildly at initialization, which prevents
23 | # gathering a lot of data near the target (which is where the system is
24 | # mostly at run time).
25 | q_lim = jnp.array(
26 | [
27 | [-1.0, 1.0], # slew
28 | [0.0, 1.0], # luff
29 | [-1.0, 1.0], # payload x-pos
30 | [1.0, 2.2], # payload y-pos
31 | [0.3, 1.0], # payload z-pos
32 | [1.0, 1.0], # payload orientation (fixed upright)
33 | [0.0, 0.0],
34 | [0.0, 0.0],
35 | [0.0, 0.0],
36 | ]
37 | )
38 | qpos = self.task.model.qpos0 + jax.random.uniform(
39 | pos_rng,
40 | (self.task.model.nq,),
41 | minval=q_lim[:, 0],
42 | maxval=q_lim[:, 1],
43 | )
44 | qvel = jax.random.uniform(
45 | vel_rng, (self.task.model.nv,), minval=-0.1, maxval=0.1
46 | )
47 |
48 | # Target position
49 | # TODO: figure out a better set of potential target positions
50 | pos_min = jnp.array([-1.5, 1.0, 0.0])
51 | pos_max = jnp.array([1.5, 3.0, 1.5])
52 | target_pos = jax.random.uniform(
53 | target_rng, (3,), minval=pos_min, maxval=pos_max
54 | )
55 | mocap_pos = data.mocap_pos.at[0].set(target_pos)
56 |
57 | # Target orientation - this is unused but must be set so vectorization
58 | # (which is determined by the size of rng) works properly.
59 | target_quat = jnp.array([1.0, 0.0, 0.0, 0.0]) + jax.random.uniform(
60 | target_rng, (4,), minval=-0.0, maxval=0.0
61 | )
62 | mocap_quat = data.mocap_quat.at[0].set(target_quat)
63 |
64 | return data.replace(
65 | qpos=qpos, qvel=qvel, mocap_pos=mocap_pos, mocap_quat=mocap_quat
66 | )
67 |
68 | def get_obs(self, data: mjx.Data) -> jax.Array:
69 | """Observe the full crane state, plus end-effector pos/vel."""
70 | ee_pos = self.task._get_payload_position(data)
71 | ee_vel = self.task._get_payload_velocity(data)
72 | return jnp.concatenate([ee_pos, ee_vel, data.qpos, data.qvel])
73 |
74 | @property
75 | def observation_size(self) -> int:
76 | """The size of the observation space."""
77 | return 23
78 |
--------------------------------------------------------------------------------
/examples/crane.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import mujoco
4 | from flax import nnx
5 | from hydrax.algs import PredictiveSampling
6 | from hydrax.simulation.deterministic import run_interactive as run_sampling
7 |
8 | from gpc.architectures import DenoisingCNN
9 | from gpc.envs import CraneEnv
10 | from gpc.policy import Policy
11 | from gpc.sampling import BootstrappedPredictiveSampling
12 | from gpc.testing import test_interactive
13 | from gpc.training import train
14 |
15 | if __name__ == "__main__":
16 | # Parse command line arguments
17 | parser = argparse.ArgumentParser(
18 | description="Swing the payload of a luffing crane to a target position."
19 | )
20 | subparsers = parser.add_subparsers(
21 | dest="task", help="What to do (choose one)"
22 | )
23 | subparsers.add_parser("train", help="Train (and save) a generative policy")
24 | subparsers.add_parser("test", help="Test a generative policy")
25 | subparsers.add_parser(
26 | "sample", help="Bootstrap sampling-based MPC with a generative policy"
27 | )
28 | args = parser.parse_args()
29 |
30 | # Set up the environment and save file
31 | env = CraneEnv(episode_length=500)
32 | save_file = "/tmp/crane_policy.pkl"
33 |
34 | if args.task == "train":
35 | # Train the policy and save it to a file
36 | ctrl = PredictiveSampling(env.task, num_samples=4, noise_level=0.05)
37 | net = DenoisingCNN(
38 | action_size=env.task.model.nu,
39 | observation_size=env.observation_size,
40 | horizon=env.task.planning_horizon,
41 | feature_dims=[64, 64],
42 | timestep_embedding_dim=16,
43 | rngs=nnx.Rngs(0),
44 | )
45 | policy = train(
46 | env,
47 | ctrl,
48 | net,
49 | num_policy_samples=2,
50 | log_dir="/tmp/gpc_crane",
51 | num_iters=10,
52 | num_envs=512,
53 | num_videos=4,
54 | batch_size=1024,
55 | num_epochs=20,
56 | checkpoint_every=5,
57 | strategy="policy",
58 | )
59 | policy.save(save_file)
60 | print(f"Saved policy to {save_file}")
61 |
62 | elif args.task == "test":
63 | # Load the policy from a file and test it interactively
64 | print(f"Loading policy from {save_file}")
65 | policy = Policy.load(save_file)
66 | test_interactive(env, policy)
67 |
68 | elif args.task == "sample":
69 | # Use the policy to bootstrap sampling-based MPC
70 | policy = Policy.load(save_file)
71 | ctrl = BootstrappedPredictiveSampling(
72 | policy,
73 | env.get_obs,
74 | inference_timestep=0.1,
75 | num_policy_samples=4, # samples from the flow-matching policy
76 | task=env.task,
77 | num_samples=4, # samples from a gaussian around the previous mean
78 | noise_level=0.05,
79 | )
80 | mj_model = env.task.mj_model
81 | mj_data = mujoco.MjData(mj_model)
82 | run_sampling(ctrl, mj_model, mj_data, frequency=30)
83 |
84 | else:
85 | parser.print_help()
86 |
--------------------------------------------------------------------------------
/gpc/testing.py:
--------------------------------------------------------------------------------
1 | import time
2 | from functools import partial
3 |
4 | import jax
5 | import jax.numpy as jnp
6 | import mujoco
7 | import mujoco.viewer
8 | from mujoco import mjx
9 |
10 | from gpc.envs import TrainingEnv
11 | from gpc.policy import Policy
12 |
13 |
14 | def test_interactive(
15 | env: TrainingEnv,
16 | policy: Policy,
17 | mj_data: mujoco.MjData = None,
18 | inference_timestep: float = 0.1,
19 | warm_start_level: float = 1.0,
20 | ) -> None:
21 | """Test a GPC policy with an interactive simulation.
22 |
23 | Args:
24 | env: The environment, which defines the system to simulate.
25 | policy: The GPC policy to test.
26 | mj_data: The initial state for the simulation.
27 | inference_timestep: The timestep dt to use for flow matching inference.
28 | warm_start_level: The warm start level to use for the policy.
29 | """
30 | rng = jax.random.key(0)
31 | task = env.task
32 |
33 | # Set up the policy
34 | policy = policy.replace(dt=inference_timestep)
35 | policy.model.eval()
36 | jit_policy = jax.jit(
37 | partial(policy.apply, warm_start_level=warm_start_level)
38 | )
39 |
40 | # Set up the mujoco simulation
41 | mj_model = task.mj_model
42 | if mj_data is None:
43 | mj_data = mujoco.MjData(mj_model)
44 |
45 | # Initialize the action sequence
46 | actions = jnp.zeros((task.planning_horizon, task.model.nu))
47 |
48 | # Set up an observation function
49 | mjx_data = mjx.make_data(task.model)
50 |
51 | @jax.jit
52 | def get_obs(mjx_data: mjx.Data) -> jax.Array:
53 | """Get an observation from the mujoco data."""
54 | mjx_data = mjx.forward(task.model, mjx_data) # update sites & sensors
55 | return env.get_obs(mjx_data)
56 |
57 | # Run the simulation
58 | with mujoco.viewer.launch_passive(mj_model, mj_data) as viewer:
59 | while viewer.is_running():
60 | st = time.time()
61 |
62 | # Get an observation
63 | mjx_data = mjx_data.replace(
64 | qpos=jnp.array(mj_data.qpos),
65 | qvel=jnp.array(mj_data.qvel),
66 | mocap_pos=jnp.array(mj_data.mocap_pos),
67 | mocap_quat=jnp.array(mj_data.mocap_quat),
68 | )
69 | obs = get_obs(mjx_data)
70 |
71 | # Update the action sequence
72 | inference_start = time.time()
73 | rng, policy_rng = jax.random.split(rng)
74 | actions = jit_policy(actions, obs, policy_rng)
75 | mj_data.ctrl[:] = actions[0]
76 |
77 | inference_time = time.time() - inference_start
78 | obs_time = inference_start - st
79 | print(
80 | f" Observation time: {obs_time:.5f}s "
81 | f" Inference time: {inference_time:.5f}s",
82 | end="\r",
83 | )
84 |
85 | mujoco.mj_step(mj_model, mj_data)
86 | viewer.sync()
87 |
88 | elapsed = time.time() - st
89 | if elapsed < mj_model.opt.timestep:
90 | time.sleep(mj_model.opt.timestep - elapsed)
91 |
92 | # Print a newline so the last status line isn't overwritten
93 | print("")
94 |
--------------------------------------------------------------------------------
/examples/humanoid.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import mujoco
4 | from flax import nnx
5 | from hydrax.algs import MPPI
6 | from hydrax.simulation.deterministic import run_interactive as run_sampling
7 |
8 | from gpc.architectures import DenoisingCNN
9 | from gpc.envs import HumanoidEnv
10 | from gpc.policy import Policy
11 | from gpc.sampling import BootstrappedPredictiveSampling
12 | from gpc.testing import test_interactive
13 | from gpc.training import train
14 |
15 | if __name__ == "__main__":
16 | # Parse command line arguments
17 | parser = argparse.ArgumentParser(
18 | description="Humanoid standup from arbitrary initial positions"
19 | )
20 | subparsers = parser.add_subparsers(
21 | dest="task", help="What to do (choose one)"
22 | )
23 | subparsers.add_parser("train", help="Train (and save) a generative policy")
24 | subparsers.add_parser("test", help="Test a generative policy")
25 | subparsers.add_parser(
26 | "sample", help="Bootstrap sampling-based MPC with a generative policy"
27 | )
28 | args = parser.parse_args()
29 |
30 | # Set up the environment and save file
31 | env = HumanoidEnv(episode_length=400)
32 | save_file = "/tmp/humanoid_policy.pkl"
33 |
34 | if args.task == "train":
35 | # Train the policy and save it to a file
36 | ctrl = MPPI(
37 | env.task,
38 | num_samples=32,
39 | noise_level=1.0,
40 | temperature=0.1,
41 | num_randomizations=2,
42 | )
43 | net = DenoisingCNN(
44 | action_size=env.task.model.nu,
45 | observation_size=env.observation_size,
46 | horizon=env.task.planning_horizon,
47 | feature_dims=(128,) * 3,
48 | timestep_embedding_dim=64,
49 | rngs=nnx.Rngs(0),
50 | )
51 | policy = train(
52 | env,
53 | ctrl,
54 | net,
55 | num_policy_samples=32,
56 | log_dir="/tmp/gpc_humanoid",
57 | num_epochs=10,
58 | num_iters=50,
59 | num_envs=128,
60 | num_videos=2,
61 | checkpoint_every=1,
62 | strategy="best",
63 | )
64 | policy.save(save_file)
65 | print(f"Saved policy to {save_file}")
66 |
67 | elif args.task == "test":
68 | # Load the policy from a file and test it interactively
69 | print(f"Loading policy from {save_file}")
70 | policy = Policy.load(save_file)
71 | test_interactive(env, policy)
72 |
73 | elif args.task == "sample":
74 | # Use the policy to bootstrap sampling-based MPC
75 | policy = Policy.load(save_file)
76 | ctrl = BootstrappedPredictiveSampling(
77 | policy,
78 | env.get_obs,
79 | num_policy_samples=128,
80 | warm_start_level=0.9,
81 | task=env.task,
82 | num_samples=128,
83 | noise_level=0.5,
84 | num_randomizations=2,
85 | )
86 |
87 | mj_model = env.task.mj_model
88 | mj_model.opt.timestep = 0.01
89 |
90 | mj_data = mujoco.MjData(mj_model)
91 | mj_data.qpos[3:7] = [-0.7, 0.0, 0.7, 0.0]
92 |
93 | run_sampling(ctrl, mj_model, mj_data, frequency=50, show_traces=False)
94 |
95 | else:
96 | parser.print_help()
97 |
--------------------------------------------------------------------------------
/gpc/augmented.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Tuple
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | from flax.struct import dataclass
6 | from hydrax.alg_base import SamplingBasedController, Trajectory
7 |
8 |
9 | @dataclass
10 | class PACParams:
11 | """Parameters for the policy-augmented controller.
12 |
13 | Attributes:
14 | base_params: The parameters for the base controller.
15 | policy_samples: Control sequences sampled from the policy.
16 | rng: Random number generator key for domain randomization.
17 | """
18 |
19 | base_params: Any
20 | policy_samples: jax.Array
21 | rng: jax.Array
22 |
23 |
24 | class PolicyAugmentedController(SamplingBasedController):
25 | """An SPC generalization where samples are augmented by a learned policy."""
26 |
27 | def __init__(
28 | self,
29 | base_ctrl: SamplingBasedController,
30 | num_policy_samples: int,
31 | ) -> None:
32 | """Initialize the policy-augmented controller.
33 |
34 | Args:
35 | base_ctrl: The base controller to augment.
36 | num_policy_samples: The number of samples to draw from the policy.
37 | """
38 | self.base_ctrl = base_ctrl
39 | self.num_policy_samples = num_policy_samples
40 | super().__init__(
41 | base_ctrl.task,
42 | base_ctrl.num_randomizations,
43 | base_ctrl.risk_strategy,
44 | seed=0,
45 | )
46 |
47 | def init_params(self) -> PACParams:
48 | """Initialize the controller parameters."""
49 | base_params = self.base_ctrl.init_params()
50 | base_rng, our_rng = jax.random.split(base_params.rng)
51 | base_params = base_params.replace(rng=base_rng)
52 | policy_samples = jnp.zeros(
53 | (
54 | self.num_policy_samples,
55 | self.task.planning_horizon,
56 | self.task.model.nu,
57 | )
58 | )
59 | return PACParams(
60 | base_params=base_params,
61 | policy_samples=policy_samples,
62 | rng=our_rng,
63 | )
64 |
65 | def sample_controls(self, params: PACParams) -> Tuple[jax.Array, PACParams]:
66 | """Sample control sequences from the base controller and the policy."""
67 | # Samples from the base controller
68 | base_samples, base_params = self.base_ctrl.sample_controls(
69 | params.base_params
70 | )
71 |
72 | # Include samples from the policy. Assumes that these have already been
73 | # generated and stored in params.policy_samples.
74 | samples = jnp.append(base_samples, params.policy_samples, axis=0)
75 |
76 | return samples, params.replace(base_params=base_params)
77 |
78 | def update_params(
79 | self, params: PACParams, rollouts: Trajectory
80 | ) -> PACParams:
81 | """Update the policy parameters according to the base controller."""
82 | base_params = self.base_ctrl.update_params(params.base_params, rollouts)
83 | return params.replace(base_params=base_params)
84 |
85 | def get_action(self, params: PACParams, t: float) -> jax.Array:
86 | """Get the action from the base controller at a given time."""
87 | return self.base_ctrl.get_action(params.base_params, t)
88 |
89 | def get_action_sequence(self, params: PACParams) -> jax.Array:
90 | """Get the action sequence from the controller."""
91 | timesteps = jnp.arange(self.task.planning_horizon) * self.task.dt
92 | return jax.vmap(self.get_action, in_axes=(None, 0))(params, timesteps)
93 |
--------------------------------------------------------------------------------
/gpc/sampling.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Tuple
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | from hydrax.alg_base import Trajectory
6 | from hydrax.algs.predictive_sampling import PredictiveSampling
7 | from mujoco import mjx
8 |
9 | from gpc.policy import Policy
10 |
11 |
12 | class BootstrappedPredictiveSampling(PredictiveSampling):
13 | """Perform predictive sampling, but add samples from a generative policy."""
14 |
15 | def __init__(
16 | self,
17 | policy: Policy,
18 | observation_fn: Callable[[mjx.Data], jax.Array],
19 | num_policy_samples: int,
20 | warm_start_level: float = 0.0,
21 | inference_timestep: float = 0.1,
22 | **kwargs,
23 | ):
24 | """Initialize the controller.
25 |
26 | Args:
27 | policy: The generative policy to sample from.
28 | observation_fn: A function that produces an observation vector.
29 | num_policy_samples: The number of samples to take from the policy.
30 | warm_start_level: The warm start level in [0, 1] to use for the
31 | policy samples. 0.0 generates samples from scratch, while 1.0
32 |                 seeds all samples from the previous action sequence.
33 | inference_timestep: The timestep dt for flow matching inference.
34 | **kwargs: Constructor arguments for PredictiveSampling.
35 | """
36 | self.observation_fn = observation_fn
37 | self.policy = policy.replace(dt=inference_timestep)
38 | self.policy.model.eval() # Don't update batch statistics
39 | self.warm_start_level = jnp.clip(warm_start_level, 0.0, 1.0)
40 | self.num_policy_samples = num_policy_samples
41 |
42 | super().__init__(**kwargs)
43 |
44 | def optimize(self, state: mjx.Data, params: Any) -> Tuple[Any, Trajectory]:
45 | """Perform an optimization step to update the policy parameters.
46 |
47 | In addition to sampling random control sequences, also sample control
48 | sequences from the generative policy.
49 |
50 | Args:
51 | state: The initial state x₀.
52 | params: The current policy parameters, U ~ π(params).
53 |
54 | Returns:
55 | Updated policy parameters
56 | Rollouts used to update the parameters
57 | """
58 | rng, policy_rng, dr_rng = jax.random.split(params.rng, 3)
59 |
60 | # Sample random control sequences
61 | controls, params = self.sample_controls(params)
62 | controls = jnp.clip(controls, self.task.u_min, self.task.u_max)
63 |
64 | # Update sensor readings and get an observation
65 | state = mjx.forward(self.task.model, state)
66 | y = self.observation_fn(state)
67 |
68 | # Sample from the generative policy, which is conditioned on the latest
69 | # observation.
70 | policy_rngs = jax.random.split(policy_rng, self.num_policy_samples)
71 | policy_controls = jax.vmap(
72 | self.policy.apply, in_axes=(None, None, 0, None)
73 | )(
74 | params.mean,
75 | y,
76 | policy_rngs,
77 | self.warm_start_level,
78 | )
79 |
80 | # Combine the random and policy samples
81 | controls = jnp.concatenate([controls, policy_controls], axis=0)
82 |
83 | # Roll out the control sequences, applying domain randomizations and
84 | # combining costs using self.risk_strategy.
85 | rollouts = self.rollout_with_randomizations(state, controls, dr_rng)
86 |
87 | # Update the policy parameters based on the combined costs
88 | params = params.replace(rng=rng)
89 | params = self.update_params(params, rollouts)
90 | return params, rollouts
91 |
--------------------------------------------------------------------------------
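
As a hedged illustration (not from the repository), `BootstrappedPredictiveSampling` might be assembled from a saved policy as sketched below. The checkpoint path, the observation function, and the keyword arguments forwarded to hydrax's `PredictiveSampling` (`task`, `num_samples`, `noise_level`) are assumptions.

```python
import jax.numpy as jnp
from hydrax.tasks.pendulum import Pendulum  # assumed task import
from mujoco import mjx

from gpc.policy import Policy
from gpc.sampling import BootstrappedPredictiveSampling


def observation_fn(data: mjx.Data) -> jnp.ndarray:
    """Example observation: stacked joint positions and velocities."""
    return jnp.concatenate([data.qpos, data.qvel])


task = Pendulum()                   # constructor arguments assumed
policy = Policy.load("policy.pkl")  # checkpoint path assumed

ctrl = BootstrappedPredictiveSampling(
    policy,
    observation_fn,
    num_policy_samples=4,
    warm_start_level=0.5,   # blend the previous plan with fresh noise
    inference_timestep=0.1,
    task=task,              # forwarded to PredictiveSampling (assumed)
    num_samples=32,         # forwarded (assumed)
    noise_level=0.1,        # forwarded (assumed)
)
```

Each call to `ctrl.optimize(state, params)` then rolls out the random samples and the policy samples together and updates the parameters from the combined costs.
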
/gpc/policy.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Tuple, Union
3 |
4 | import cloudpickle
5 | import jax
6 | import jax.numpy as jnp
7 | from flax import nnx
8 | from flax.struct import dataclass
9 |
10 |
11 | @dataclass
12 | class Policy:
13 | """A pickle-able Generative Predictive Control policy.
14 |
15 | Generates action sequences using flow matching, conditioned on the latest
16 | observation, e.g., samples U = [u_0, u_1, ...] ~ p(U | y).
17 |
18 | Attributes:
19 | model: The flow matching network that generates the action sequence.
20 | normalizer: Observation normalization module.
21 | u_min: The minimum action values.
22 | u_max: The maximum action values.
23 | dt: The integration step size for flow matching.
24 | """
25 |
26 | model: nnx.Module
27 | normalizer: nnx.BatchNorm
28 | u_min: jax.Array
29 | u_max: jax.Array
30 | dt: float = 0.1
31 |
32 | def save(self, path: Union[Path, str]) -> None:
33 | """Save the policy to a file.
34 |
35 | Args:
36 | path: The path to save the policy to.
37 | """
38 | with open(path, "wb") as f:
39 | cloudpickle.dump(self, f)
40 |
41 | @staticmethod
42 | def load(path: Union[Path, str]) -> "Policy":
43 | """Load a policy from a file.
44 |
45 | Args:
46 | path: The path to load the policy from.
47 |
48 | Returns:
49 | The loaded policy instance
50 | """
51 | with open(path, "rb") as f:
52 | policy = cloudpickle.load(f)
53 | return policy
54 |
55 | def apply(
56 | self,
57 | prev: jax.Array,
58 | y: jax.Array,
59 | rng: jax.Array,
60 | warm_start_level: float = 0.0,
61 | ) -> jax.Array:
62 | """Generate an action sequence conditioned on the observation.
63 |
64 | Args:
65 | prev: The previous action sequence.
66 | y: The current observation.
67 | rng: The random number generator key.
68 | warm_start_level: The degree of warm-starting to use, in [0, 1].
69 |
70 | A warm-start level of 0.0 means the action sequence is generated from
71 | scratch, with the seed for flow matching drawn from a random normal
72 | distribution. A warm-start level of 1.0 means the seed is the previous
73 | action sequence. Values in between interpolate between these two, with
74 | larger values giving smoother but less exploratory action sequences.
75 |
76 | Returns:
77 | The updated action sequence
78 | """
79 | # Normalize the observation, but don't update the stored statistics
80 | y = self.normalizer(y, use_running_average=True)
81 |
82 | # Set the initial sample
83 | warm_start_level = jnp.clip(warm_start_level, 0.0, 1.0)
84 | noise = jax.random.normal(rng, prev.shape)
85 | U = warm_start_level * prev + (1 - warm_start_level) * noise
86 |
87 | def _step(args: Tuple[jax.Array, float]) -> Tuple[jax.Array, float]:
88 | """Flow the sample U along the learned vector field."""
89 | U, t = args
90 | U += self.dt * self.model(U, y, t)
91 | U = jax.numpy.clip(U, -1, 1)
92 | return U, t + self.dt
93 |
94 | # While t < 1, U += dt * model(U, y, t)
95 | U, t = jax.lax.while_loop(
96 | lambda args: jnp.all(args[1] < 1.0),
97 | _step,
98 | (U, jnp.zeros(1)),
99 | )
100 |
101 | # Rescale actions from [-1, 1] to [u_min, u_max]
102 | mean = (self.u_max + self.u_min) / 2
103 | scale = (self.u_max - self.u_min) / 2
104 | U = U * scale + mean
105 |
106 | return U
107 |
--------------------------------------------------------------------------------
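
For completeness, a hedged sketch (not part of the repository) of loading a saved `Policy` and sampling an action sequence; the checkpoint path, planning horizon, action dimension, and observation size are placeholder assumptions.

```python
import jax
import jax.numpy as jnp

from gpc.policy import Policy

policy = Policy.load("policy.pkl")  # checkpoint path assumed

horizon, nu = 20, 1              # assumed planning horizon and action dimension
prev = jnp.zeros((horizon, nu))  # previous action sequence, used for warm starts
y = jnp.zeros(3)                 # current observation (size assumed)
rng = jax.random.PRNGKey(0)

# warm_start_level=0.0 integrates the flow from pure noise; 1.0 would start
# the integration from `prev` instead.
U = policy.apply(prev, y, rng, warm_start_level=0.0)
assert U.shape == (horizon, nu)  # actions rescaled to [u_min, u_max]
```

A freshly trained policy can be written to disk the same way with `policy.save("policy.pkl")`.
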
/README.md:
--------------------------------------------------------------------------------
1 | # Generative Predictive Control
2 |
3 | This repository contains code for the paper ["Generative Predictive Control: Flow
4 | Matching Policies for Dynamic and Difficult-to-Demonstrate Tasks"](https://arxiv.org/abs/2502.13406)
5 | by Vince Kurtz and Joel Burdick. [Video summary](https://youtu.be/mjL7CF877Ow).
6 |
7 | This includes code for training and testing flow-matching policies on each of
8 | the robot systems shown below:
9 |
10 | [![cart_pole](img/cart_pole.png)](examples/cart_pole.py)
11 | [![double_cart_pole](img/double_cart_pole.png)](examples/double_cart_pole.py)
12 | [![pusht](img/pusht.png)](examples/pusht.py)
13 | [![walker](img/walker.png)](examples/walker.py)
14 | [![crane](img/crane.png)](examples/crane.py)
15 | [![humanoid](img/humanoid.png)](examples/humanoid.py)
16 |
17 | Generative Predictive Control (GPC) is a supervised learning framework for
18 | training flow-matching policies on tasks that are difficult to demonstrate but
19 | easy to simulate. GPC alternates between generating training data with
20 | [sampling-based predictive control](https://github.com/vincekurtz/hydrax),
21 | fitting a generative model to the data, and using the generative model to
22 | improve the sampling distribution.
23 |
24 |
25 | ![GPC overview](img/summary.png)
26 |