├── .gitignore ├── CHANGELOG.md ├── LICENSE ├── README.md ├── assets ├── env_teaser.png └── ogbench.svg ├── data_gen_scripts ├── commands.sh ├── generate_antsoccer.py ├── generate_locomaze.py ├── generate_manipspace.py ├── generate_powderworld.py ├── main_sac.py ├── online_env_utils.py └── viz_utils.py ├── impls ├── agents │ ├── __init__.py │ ├── crl.py │ ├── gcbc.py │ ├── gciql.py │ ├── gcivl.py │ ├── hiql.py │ ├── qrl.py │ └── sac.py ├── hyperparameters.sh ├── main.py ├── requirements.txt └── utils │ ├── __init__.py │ ├── datasets.py │ ├── encoders.py │ ├── env_utils.py │ ├── evaluation.py │ ├── flax_utils.py │ ├── log_utils.py │ └── networks.py ├── ogbench ├── __init__.py ├── locomaze │ ├── __init__.py │ ├── ant.py │ ├── assets │ │ ├── ant.xml │ │ ├── humanoid.xml │ │ └── point.xml │ ├── humanoid.py │ ├── maze.py │ └── point.py ├── manipspace │ ├── __init__.py │ ├── controllers │ │ ├── __init__.py │ │ └── diff_ik.py │ ├── descriptions │ │ ├── button_inner.xml │ │ ├── button_outer.xml │ │ ├── buttons.xml │ │ ├── cube.xml │ │ ├── cube_inner.xml │ │ ├── cube_outer.xml │ │ ├── drawer.xml │ │ ├── floor_wall.xml │ │ ├── metaworld │ │ │ ├── button │ │ │ │ ├── button.stl │ │ │ │ ├── buttonring.stl │ │ │ │ ├── metal1.png │ │ │ │ ├── stopbot.stl │ │ │ │ ├── stopbutton.stl │ │ │ │ ├── stopbuttonrim.stl │ │ │ │ ├── stopbuttonrod.stl │ │ │ │ └── stoptop.stl │ │ │ ├── drawer │ │ │ │ ├── drawer.stl │ │ │ │ ├── drawercase.stl │ │ │ │ └── drawerhandle.stl │ │ │ └── window │ │ │ │ ├── window_base.stl │ │ │ │ ├── window_frame.stl │ │ │ │ ├── window_h_base.stl │ │ │ │ ├── window_h_frame.stl │ │ │ │ ├── windowa_frame.stl │ │ │ │ ├── windowa_glass.stl │ │ │ │ ├── windowa_h_frame.stl │ │ │ │ ├── windowa_h_glass.stl │ │ │ │ ├── windowb_frame.stl │ │ │ │ ├── windowb_glass.stl │ │ │ │ ├── windowb_h_frame.stl │ │ │ │ └── windowb_h_glass.stl │ │ ├── robotiq_2f85 │ │ │ ├── 2f85.png │ │ │ ├── 2f85.xml │ │ │ ├── LICENSE │ │ │ ├── README.md │ │ │ ├── assets │ │ │ │ ├── base.stl │ │ │ │ ├── base_mount.stl │ │ │ │ ├── coupler.stl │ │ │ │ ├── driver.stl │ │ │ │ ├── follower.stl │ │ │ │ ├── pad.stl │ │ │ │ ├── silicone_pad.stl │ │ │ │ └── spring_link.stl │ │ │ └── scene.xml │ │ ├── universal_robots_ur5e │ │ │ ├── LICENSE │ │ │ ├── README.md │ │ │ ├── assets │ │ │ │ ├── base_0.obj │ │ │ │ ├── base_1.obj │ │ │ │ ├── forearm_0.obj │ │ │ │ ├── forearm_1.obj │ │ │ │ ├── forearm_2.obj │ │ │ │ ├── forearm_3.obj │ │ │ │ ├── shoulder_0.obj │ │ │ │ ├── shoulder_1.obj │ │ │ │ ├── shoulder_2.obj │ │ │ │ ├── upperarm_0.obj │ │ │ │ ├── upperarm_1.obj │ │ │ │ ├── upperarm_2.obj │ │ │ │ ├── upperarm_3.obj │ │ │ │ ├── wrist1_0.obj │ │ │ │ ├── wrist1_1.obj │ │ │ │ ├── wrist1_2.obj │ │ │ │ ├── wrist2_0.obj │ │ │ │ ├── wrist2_1.obj │ │ │ │ ├── wrist2_2.obj │ │ │ │ └── wrist3.obj │ │ │ ├── scene.xml │ │ │ ├── ur5e.png │ │ │ └── ur5e.xml │ │ └── window.xml │ ├── envs │ │ ├── __init__.py │ │ ├── cube_env.py │ │ ├── env.py │ │ ├── manipspace_env.py │ │ ├── puzzle_env.py │ │ └── scene_env.py │ ├── lie │ │ ├── __init__.py │ │ ├── se3.py │ │ ├── so3.py │ │ └── utils.py │ ├── mjcf_utils.py │ ├── oracles │ │ ├── __init__.py │ │ ├── markov │ │ │ ├── __init__.py │ │ │ ├── button_markov.py │ │ │ ├── cube_markov.py │ │ │ ├── drawer_markov.py │ │ │ ├── markov_oracle.py │ │ │ └── window_markov.py │ │ └── plan │ │ │ ├── __init__.py │ │ │ ├── button_plan.py │ │ │ ├── cube_plan.py │ │ │ ├── drawer_plan.py │ │ │ ├── plan_oracle.py │ │ │ └── window_plan.py │ └── viewer_utils.py ├── online_locomotion │ ├── __init__.py │ ├── ant.py │ ├── ant_ball.py │ ├── assets 
│ │ ├── ant.xml │ │ └── humanoid.xml │ ├── humanoid.py │ └── wrappers.py ├── powderworld │ ├── __init__.py │ ├── behaviors.py │ ├── powderworld_env.py │ └── sim.py ├── relabel_utils.py └── utils.py └── pyproject.toml /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | dist/ 3 | *.py[cod] 4 | *$py.class 5 | *.egg-info/ 6 | .DS_Store 7 | .idea/ 8 | .ruff_cache/ 9 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Change log 2 | 3 | ## ogbench 1.1.3 (2025-06-03) 4 | - Add the `cube-octuple` task. 5 | 6 | ## ogbench 1.1.2 (2025-03-30) 7 | - Improve compatibility with `gymnasium`. 8 | 9 | ## ogbench 1.1.1 (2025-03-02) 10 | - Make it compatible with the latest version of `gymnasium` (1.1.0). 11 | 12 | ## ogbench 1.1.0 (2025-02-13) 13 | - Added `-singletask` environments for standard (i.e., non-goal-conditioned) offline RL. 14 | - Added `-oraclerep` environments for offline goal-conditioned RL with oracle goal representations. 15 | 16 | ## ogbench 1.0.1 (2024-10-28) 17 | - Fixed a bug in the reward function of manipulation tasks. 18 | 19 | ## ogbench 1.0.0 (2024-10-25) 20 | - Initial release. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2024 OGBench Authors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 
22 | -------------------------------------------------------------------------------- /assets/env_teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/assets/env_teaser.png -------------------------------------------------------------------------------- /assets/ogbench.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /data_gen_scripts/generate_powderworld.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from collections import defaultdict 3 | 4 | import gymnasium 5 | import numpy as np 6 | from absl import app, flags 7 | from tqdm import trange 8 | 9 | import ogbench.powderworld # noqa 10 | from ogbench.powderworld.behaviors import FillBehavior, LineBehavior, SquareBehavior 11 | 12 | FLAGS = flags.FLAGS 13 | 14 | flags.DEFINE_integer('seed', 0, 'Random seed.') 15 | flags.DEFINE_string('env_name', 'powderworld-v0', 'Environment name.') 16 | flags.DEFINE_string('dataset_type', 'play', 'Dataset type.') 17 | flags.DEFINE_string('save_path', None, 'Save path.') 18 | flags.DEFINE_integer('num_episodes', 1000, 'Number of episodes.') 19 | flags.DEFINE_integer('max_episode_steps', 1001, 'Maximum number of steps in an episode.') 20 | flags.DEFINE_float('p_random_action', 0.5, 'Probability of selecting a random action.') 21 | 22 | 23 | def main(_): 24 | assert FLAGS.dataset_type in ['play'] 25 | 26 | # Initialize environment. 27 | env = gymnasium.make( 28 | FLAGS.env_name, 29 | mode='data_collection', 30 | max_episode_steps=FLAGS.max_episode_steps, 31 | ) 32 | env.reset() 33 | 34 | # Initialize agents. 35 | agents = [ 36 | FillBehavior(env=env), 37 | LineBehavior(env=env), 38 | SquareBehavior(env=env), 39 | ] 40 | probs = np.array([1, 3, 3]) # Agent selection probabilities. 41 | probs = probs / probs.sum() 42 | 43 | # Collect data. 44 | dataset = defaultdict(list) 45 | total_steps = 0 46 | total_train_steps = 0 47 | num_train_episodes = FLAGS.num_episodes 48 | num_val_episodes = FLAGS.num_episodes // 10 49 | for ep_idx in trange(num_train_episodes + num_val_episodes): 50 | ob, info = env.reset() 51 | agent = np.random.choice(agents, p=probs) 52 | agent.reset(ob, info) 53 | 54 | done = False 55 | step = 0 56 | 57 | action_step = 0 # Action cycle counter (0, 1, 2). 58 | while not done: 59 | if action_step == 0: 60 | # Select an action every 3 steps. 61 | if np.random.rand() < FLAGS.p_random_action: 62 | # Sample a random action. 63 | semantic_action = env.unwrapped.sample_semantic_action() 64 | else: 65 | # Get an action from the agent. 
66 | semantic_action = agent.select_action(ob, info) 67 | action = env.unwrapped.semantic_action_to_action(*semantic_action) 68 | next_ob, reward, terminated, truncated, info = env.step(action) 69 | done = terminated or truncated 70 | 71 | if agent.done and FLAGS.dataset_type == 'play': 72 | agent = np.random.choice(agents, p=probs) 73 | agent.reset(ob, info) 74 | 75 | dataset['observations'].append(ob) 76 | dataset['actions'].append(action) 77 | dataset['terminals'].append(done) 78 | 79 | ob = next_ob 80 | step += 1 81 | action_step = (action_step + 1) % 3 82 | 83 | total_steps += step 84 | if ep_idx < num_train_episodes: 85 | total_train_steps += step 86 | 87 | print('Total steps:', total_steps) 88 | 89 | train_path = FLAGS.save_path 90 | val_path = FLAGS.save_path.replace('.npz', '-val.npz') 91 | pathlib.Path(train_path).parent.mkdir(parents=True, exist_ok=True) 92 | 93 | # Split the dataset into training and validation sets. 94 | train_dataset = {} 95 | val_dataset = {} 96 | for k, v in dataset.items(): 97 | if 'observations' in k and v[0].dtype == np.uint8: 98 | dtype = np.uint8 99 | elif k == 'actions': 100 | dtype = np.int32 101 | elif k == 'terminals': 102 | dtype = bool 103 | else: 104 | dtype = np.float32 105 | train_dataset[k] = np.array(v[:total_train_steps], dtype=dtype) 106 | val_dataset[k] = np.array(v[total_train_steps:], dtype=dtype) 107 | 108 | for path, dataset in [(train_path, train_dataset), (val_path, val_dataset)]: 109 | np.savez_compressed(path, **dataset) 110 | 111 | 112 | if __name__ == '__main__': 113 | app.run(main) 114 | -------------------------------------------------------------------------------- /data_gen_scripts/online_env_utils.py: -------------------------------------------------------------------------------- 1 | import gymnasium 2 | from utils.env_utils import EpisodeMonitor 3 | 4 | 5 | def make_online_env(env_name): 6 | """Make online environment. 7 | 8 | If the environment name contains the '-xy' suffix, the environment will be wrapped with a directional locomotion 9 | wrapper. For example, 'online-ant-xy-v0' will return an 'online-ant-v0' environment wrapped with GymXYWrapper. 10 | 11 | Args: 12 | env_name: Name of the environment. 13 | """ 14 | import ogbench.online_locomotion # noqa 15 | 16 | # Manually recognize the '-xy' suffix, which indicates that the environment should be wrapped with a directional 17 | # locomotion wrapper. 18 | if '-xy' in env_name: 19 | env_name = env_name.replace('-xy', '') 20 | apply_xy_wrapper = True 21 | else: 22 | apply_xy_wrapper = False 23 | 24 | # Set camera. 25 | if 'humanoid' in env_name: 26 | extra_kwargs = dict(camera_id=0) 27 | else: 28 | extra_kwargs = dict() 29 | 30 | # Make environment. 31 | env = gymnasium.make(env_name, render_mode='rgb_array', height=200, width=200, **extra_kwargs) 32 | 33 | if apply_xy_wrapper: 34 | # Apply the directional locomotion wrapper.
35 | from ogbench.online_locomotion.wrappers import DMCHumanoidXYWrapper, GymXYWrapper 36 | 37 | if 'humanoid' in env_name: 38 | env = DMCHumanoidXYWrapper(env, resample_interval=200) 39 | else: 40 | env = GymXYWrapper(env, resample_interval=100) 41 | 42 | env = EpisodeMonitor(env) 43 | 44 | return env 45 | -------------------------------------------------------------------------------- /data_gen_scripts/viz_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import numpy as np 3 | from matplotlib import figure 4 | from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas 5 | 6 | 7 | def get_2d_colors(points, min_point, max_point): 8 | """Get colors corresponding to 2-D points.""" 9 | points = np.array(points) 10 | min_point = np.array(min_point) 11 | max_point = np.array(max_point) 12 | 13 | colors = (points - min_point) / (max_point - min_point) 14 | colors = np.hstack((colors, (2 - np.sum(colors, axis=1, keepdims=True)) / 2)) 15 | colors = np.clip(colors, 0, 1) 16 | colors = np.c_[colors, np.full(len(colors), 0.8)] 17 | 18 | return colors 19 | 20 | 21 | def visualize_trajs(env_name, trajs): 22 | """Visualize x-y trajectories in locomotion environments. 23 | 24 | It reads 'xy' and 'direction' from the 'info' field of the trajectories. 25 | """ 26 | matplotlib.use('Agg') 27 | 28 | fig = figure.Figure(tight_layout=True) 29 | canvas = FigureCanvas(fig) 30 | if 'xy' in trajs[0]['info'][0]: 31 | ax = fig.add_subplot() 32 | 33 | max_xy = 0.0 34 | for traj in trajs: 35 | xy = np.array([info['xy'] for info in traj['info']]) 36 | direction = np.array([info['direction'] for info in traj['info']]) 37 | color = get_2d_colors(direction, [-1, -1], [1, 1]) 38 | for i in range(len(xy) - 1): 39 | ax.plot(xy[i : i + 2, 0], xy[i : i + 2, 1], color=color[i], linewidth=0.7) 40 | max_xy = max(max_xy, np.abs(xy).max() * 1.2) 41 | 42 | plot_axis = [-max_xy, max_xy, -max_xy, max_xy] 43 | ax.axis(plot_axis) 44 | ax.set_aspect('equal') 45 | else: 46 | return None 47 | 48 | fig.tight_layout() 49 | canvas.draw() 50 | out_image = np.frombuffer(canvas.tostring_rgb(), dtype='uint8') 51 | out_image = out_image.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 52 | return out_image 53 | -------------------------------------------------------------------------------- /impls/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from agents.crl import CRLAgent 2 | from agents.gcbc import GCBCAgent 3 | from agents.gciql import GCIQLAgent 4 | from agents.gcivl import GCIVLAgent 5 | from agents.hiql import HIQLAgent 6 | from agents.qrl import QRLAgent 7 | from agents.sac import SACAgent 8 | 9 | agents = dict( 10 | crl=CRLAgent, 11 | gcbc=GCBCAgent, 12 | gciql=GCIQLAgent, 13 | gcivl=GCIVLAgent, 14 | hiql=HIQLAgent, 15 | qrl=QRLAgent, 16 | sac=SACAgent, 17 | ) 18 | -------------------------------------------------------------------------------- /impls/agents/gcbc.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import flax 4 | import jax 5 | import jax.numpy as jnp 6 | import ml_collections 7 | import optax 8 | from utils.encoders import GCEncoder, encoder_modules 9 | from utils.flax_utils import ModuleDict, TrainState, nonpytree_field 10 | from utils.networks import GCActor, GCDiscreteActor 11 | 12 | 13 | class GCBCAgent(flax.struct.PyTreeNode): 14 | """Goal-conditioned behavioral cloning (GCBC) agent.""" 15 | 16 | rng: Any 
17 | network: Any 18 | config: Any = nonpytree_field() 19 | 20 | def actor_loss(self, batch, grad_params, rng=None): 21 | """Compute the BC actor loss.""" 22 | dist = self.network.select('actor')(batch['observations'], batch['actor_goals'], params=grad_params) 23 | log_prob = dist.log_prob(batch['actions']) 24 | 25 | actor_loss = -log_prob.mean() 26 | 27 | actor_info = { 28 | 'actor_loss': actor_loss, 29 | 'bc_log_prob': log_prob.mean(), 30 | } 31 | if not self.config['discrete']: 32 | actor_info.update( 33 | { 34 | 'mse': jnp.mean((dist.mode() - batch['actions']) ** 2), 35 | 'std': jnp.mean(dist.scale_diag), 36 | } 37 | ) 38 | 39 | return actor_loss, actor_info 40 | 41 | @jax.jit 42 | def total_loss(self, batch, grad_params, rng=None): 43 | """Compute the total loss.""" 44 | info = {} 45 | rng = rng if rng is not None else self.rng 46 | 47 | rng, actor_rng = jax.random.split(rng) 48 | actor_loss, actor_info = self.actor_loss(batch, grad_params, actor_rng) 49 | for k, v in actor_info.items(): 50 | info[f'actor/{k}'] = v 51 | 52 | loss = actor_loss 53 | return loss, info 54 | 55 | @jax.jit 56 | def update(self, batch): 57 | """Update the agent and return a new agent with information dictionary.""" 58 | new_rng, rng = jax.random.split(self.rng) 59 | 60 | def loss_fn(grad_params): 61 | return self.total_loss(batch, grad_params, rng=rng) 62 | 63 | new_network, info = self.network.apply_loss_fn(loss_fn=loss_fn) 64 | 65 | return self.replace(network=new_network, rng=new_rng), info 66 | 67 | @jax.jit 68 | def sample_actions( 69 | self, 70 | observations, 71 | goals=None, 72 | seed=None, 73 | temperature=1.0, 74 | ): 75 | """Sample actions from the actor.""" 76 | dist = self.network.select('actor')(observations, goals, temperature=temperature) 77 | actions = dist.sample(seed=seed) 78 | if not self.config['discrete']: 79 | actions = jnp.clip(actions, -1, 1) 80 | return actions 81 | 82 | @classmethod 83 | def create( 84 | cls, 85 | seed, 86 | ex_observations, 87 | ex_actions, 88 | config, 89 | ): 90 | """Create a new agent. 91 | 92 | Args: 93 | seed: Random seed. 94 | ex_observations: Example batch of observations. 95 | ex_actions: Example batch of actions. In discrete-action MDPs, this should contain the maximum action value. 96 | config: Configuration dictionary. 97 | """ 98 | rng = jax.random.PRNGKey(seed) 99 | rng, init_rng = jax.random.split(rng, 2) 100 | 101 | ex_goals = ex_observations 102 | if config['discrete']: 103 | action_dim = ex_actions.max() + 1 104 | else: 105 | action_dim = ex_actions.shape[-1] 106 | 107 | # Define encoder. 108 | encoders = dict() 109 | if config['encoder'] is not None: 110 | encoder_module = encoder_modules[config['encoder']] 111 | encoders['actor'] = GCEncoder(concat_encoder=encoder_module()) 112 | 113 | # Define actor network. 
114 | if config['discrete']: 115 | actor_def = GCDiscreteActor( 116 | hidden_dims=config['actor_hidden_dims'], 117 | action_dim=action_dim, 118 | gc_encoder=encoders.get('actor'), 119 | ) 120 | else: 121 | actor_def = GCActor( 122 | hidden_dims=config['actor_hidden_dims'], 123 | action_dim=action_dim, 124 | state_dependent_std=False, 125 | const_std=config['const_std'], 126 | gc_encoder=encoders.get('actor'), 127 | ) 128 | 129 | network_info = dict( 130 | actor=(actor_def, (ex_observations, ex_goals)), 131 | ) 132 | networks = {k: v[0] for k, v in network_info.items()} 133 | network_args = {k: v[1] for k, v in network_info.items()} 134 | 135 | network_def = ModuleDict(networks) 136 | network_tx = optax.adam(learning_rate=config['lr']) 137 | network_params = network_def.init(init_rng, **network_args)['params'] 138 | network = TrainState.create(network_def, network_params, tx=network_tx) 139 | 140 | return cls(rng, network=network, config=flax.core.FrozenDict(**config)) 141 | 142 | 143 | def get_config(): 144 | config = ml_collections.ConfigDict( 145 | dict( 146 | # Agent hyperparameters. 147 | agent_name='gcbc', # Agent name. 148 | lr=3e-4, # Learning rate. 149 | batch_size=1024, # Batch size. 150 | actor_hidden_dims=(512, 512, 512), # Actor network hidden dimensions. 151 | discount=0.99, # Discount factor (unused by default; can be used for geometric goal sampling in GCDataset). 152 | const_std=True, # Whether to use constant standard deviation for the actor. 153 | discrete=False, # Whether the action space is discrete. 154 | encoder=ml_collections.config_dict.placeholder(str), # Visual encoder name (None, 'impala_small', etc.). 155 | # Dataset hyperparameters. 156 | dataset_class='GCDataset', # Dataset class name. 157 | value_p_curgoal=0.0, # Unused (defined for compatibility with GCDataset). 158 | value_p_trajgoal=1.0, # Unused (defined for compatibility with GCDataset). 159 | value_p_randomgoal=0.0, # Unused (defined for compatibility with GCDataset). 160 | value_geom_sample=False, # Unused (defined for compatibility with GCDataset). 161 | actor_p_curgoal=0.0, # Probability of using the current state as the actor goal. 162 | actor_p_trajgoal=1.0, # Probability of using a future state in the same trajectory as the actor goal. 163 | actor_p_randomgoal=0.0, # Probability of using a random state as the actor goal. 164 | actor_geom_sample=False, # Whether to use geometric sampling for future actor goals. 165 | gc_negative=True, # Unused (defined for compatibility with GCDataset). 166 | p_aug=0.0, # Probability of applying image augmentation. 167 | frame_stack=ml_collections.config_dict.placeholder(int), # Number of frames to stack. 
168 | ) 169 | ) 170 | return config 171 | -------------------------------------------------------------------------------- /impls/main.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | import time 5 | from collections import defaultdict 6 | 7 | import jax 8 | import numpy as np 9 | import tqdm 10 | import wandb 11 | from absl import app, flags 12 | from agents import agents 13 | from ml_collections import config_flags 14 | from utils.datasets import Dataset, GCDataset, HGCDataset 15 | from utils.env_utils import make_env_and_datasets 16 | from utils.evaluation import evaluate 17 | from utils.flax_utils import restore_agent, save_agent 18 | from utils.log_utils import CsvLogger, get_exp_name, get_flag_dict, get_wandb_video, setup_wandb 19 | 20 | FLAGS = flags.FLAGS 21 | 22 | flags.DEFINE_string('run_group', 'Debug', 'Run group.') 23 | flags.DEFINE_integer('seed', 0, 'Random seed.') 24 | flags.DEFINE_string('env_name', 'antmaze-large-navigate-v0', 'Environment (dataset) name.') 25 | flags.DEFINE_string('save_dir', 'exp/', 'Save directory.') 26 | flags.DEFINE_string('restore_path', None, 'Restore path.') 27 | flags.DEFINE_integer('restore_epoch', None, 'Restore epoch.') 28 | 29 | flags.DEFINE_integer('train_steps', 1000000, 'Number of training steps.') 30 | flags.DEFINE_integer('log_interval', 5000, 'Logging interval.') 31 | flags.DEFINE_integer('eval_interval', 100000, 'Evaluation interval.') 32 | flags.DEFINE_integer('save_interval', 1000000, 'Saving interval.') 33 | 34 | flags.DEFINE_integer('eval_tasks', None, 'Number of tasks to evaluate (None for all).') 35 | flags.DEFINE_integer('eval_episodes', 20, 'Number of episodes for each task.') 36 | flags.DEFINE_float('eval_temperature', 0, 'Actor temperature for evaluation.') 37 | flags.DEFINE_float('eval_gaussian', None, 'Action Gaussian noise for evaluation.') 38 | flags.DEFINE_integer('video_episodes', 1, 'Number of video episodes for each task.') 39 | flags.DEFINE_integer('video_frame_skip', 3, 'Frame skip for videos.') 40 | flags.DEFINE_integer('eval_on_cpu', 1, 'Whether to evaluate on CPU.') 41 | 42 | config_flags.DEFINE_config_file('agent', 'agents/gciql.py', lock_config=False) 43 | 44 | 45 | def main(_): 46 | # Set up logger. 47 | exp_name = get_exp_name(FLAGS.seed) 48 | setup_wandb(project='OGBench', group=FLAGS.run_group, name=exp_name) 49 | 50 | FLAGS.save_dir = os.path.join(FLAGS.save_dir, wandb.run.project, FLAGS.run_group, exp_name) 51 | os.makedirs(FLAGS.save_dir, exist_ok=True) 52 | flag_dict = get_flag_dict() 53 | with open(os.path.join(FLAGS.save_dir, 'flags.json'), 'w') as f: 54 | json.dump(flag_dict, f) 55 | 56 | # Set up environment and dataset. 57 | config = FLAGS.agent 58 | env, train_dataset, val_dataset = make_env_and_datasets(FLAGS.env_name, frame_stack=config['frame_stack']) 59 | 60 | dataset_class = { 61 | 'GCDataset': GCDataset, 62 | 'HGCDataset': HGCDataset, 63 | }[config['dataset_class']] 64 | train_dataset = dataset_class(Dataset.create(**train_dataset), config) 65 | if val_dataset is not None: 66 | val_dataset = dataset_class(Dataset.create(**val_dataset), config) 67 | 68 | # Initialize agent. 69 | random.seed(FLAGS.seed) 70 | np.random.seed(FLAGS.seed) 71 | 72 | example_batch = train_dataset.sample(1) 73 | if config['discrete']: 74 | # Fill with the maximum action to let the agent know the action space size. 
75 | example_batch['actions'] = np.full_like(example_batch['actions'], env.action_space.n - 1) 76 | 77 | agent_class = agents[config['agent_name']] 78 | agent = agent_class.create( 79 | FLAGS.seed, 80 | example_batch['observations'], 81 | example_batch['actions'], 82 | config, 83 | ) 84 | 85 | # Restore agent. 86 | if FLAGS.restore_path is not None: 87 | agent = restore_agent(agent, FLAGS.restore_path, FLAGS.restore_epoch) 88 | 89 | # Train agent. 90 | train_logger = CsvLogger(os.path.join(FLAGS.save_dir, 'train.csv')) 91 | eval_logger = CsvLogger(os.path.join(FLAGS.save_dir, 'eval.csv')) 92 | first_time = time.time() 93 | last_time = time.time() 94 | for i in tqdm.tqdm(range(1, FLAGS.train_steps + 1), smoothing=0.1, dynamic_ncols=True): 95 | # Update agent. 96 | batch = train_dataset.sample(config['batch_size']) 97 | agent, update_info = agent.update(batch) 98 | 99 | # Log metrics. 100 | if i % FLAGS.log_interval == 0: 101 | train_metrics = {f'training/{k}': v for k, v in update_info.items()} 102 | if val_dataset is not None: 103 | val_batch = val_dataset.sample(config['batch_size']) 104 | _, val_info = agent.total_loss(val_batch, grad_params=None) 105 | train_metrics.update({f'validation/{k}': v for k, v in val_info.items()}) 106 | train_metrics['time/epoch_time'] = (time.time() - last_time) / FLAGS.log_interval 107 | train_metrics['time/total_time'] = time.time() - first_time 108 | last_time = time.time() 109 | wandb.log(train_metrics, step=i) 110 | train_logger.log(train_metrics, step=i) 111 | 112 | # Evaluate agent. 113 | if i == 1 or i % FLAGS.eval_interval == 0: 114 | if FLAGS.eval_on_cpu: 115 | eval_agent = jax.device_put(agent, device=jax.devices('cpu')[0]) 116 | else: 117 | eval_agent = agent 118 | renders = [] 119 | eval_metrics = {} 120 | overall_metrics = defaultdict(list) 121 | task_infos = env.unwrapped.task_infos if hasattr(env.unwrapped, 'task_infos') else env.task_infos 122 | num_tasks = FLAGS.eval_tasks if FLAGS.eval_tasks is not None else len(task_infos) 123 | for task_id in tqdm.trange(1, num_tasks + 1): 124 | task_name = task_infos[task_id - 1]['task_name'] 125 | eval_info, trajs, cur_renders = evaluate( 126 | agent=eval_agent, 127 | env=env, 128 | task_id=task_id, 129 | config=config, 130 | num_eval_episodes=FLAGS.eval_episodes, 131 | num_video_episodes=FLAGS.video_episodes, 132 | video_frame_skip=FLAGS.video_frame_skip, 133 | eval_temperature=FLAGS.eval_temperature, 134 | eval_gaussian=FLAGS.eval_gaussian, 135 | ) 136 | renders.extend(cur_renders) 137 | metric_names = ['success'] 138 | eval_metrics.update( 139 | {f'evaluation/{task_name}_{k}': v for k, v in eval_info.items() if k in metric_names} 140 | ) 141 | for k, v in eval_info.items(): 142 | if k in metric_names: 143 | overall_metrics[k].append(v) 144 | for k, v in overall_metrics.items(): 145 | eval_metrics[f'evaluation/overall_{k}'] = np.mean(v) 146 | 147 | if FLAGS.video_episodes > 0: 148 | video = get_wandb_video(renders=renders, n_cols=num_tasks) 149 | eval_metrics['video'] = video 150 | 151 | wandb.log(eval_metrics, step=i) 152 | eval_logger.log(eval_metrics, step=i) 153 | 154 | # Save agent. 
155 | if i % FLAGS.save_interval == 0: 156 | save_agent(agent, FLAGS.save_dir, i) 157 | 158 | train_logger.close() 159 | eval_logger.close() 160 | 161 | 162 | if __name__ == '__main__': 163 | app.run(main) 164 | -------------------------------------------------------------------------------- /impls/requirements.txt: -------------------------------------------------------------------------------- 1 | ogbench # Use the PyPI version of OGBench. Replace this with `pip install -e .` if you want to use the local version. 2 | jax[cuda12] >= 0.4.26 3 | flax >= 0.8.4 4 | distrax >= 0.1.5 5 | ml_collections 6 | matplotlib 7 | moviepy 8 | wandb -------------------------------------------------------------------------------- /impls/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/impls/utils/__init__.py -------------------------------------------------------------------------------- /impls/utils/encoders.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from typing import Sequence 3 | 4 | import flax.linen as nn 5 | import jax.numpy as jnp 6 | 7 | from utils.networks import MLP 8 | 9 | 10 | class ResnetStack(nn.Module): 11 | """ResNet stack module.""" 12 | 13 | num_features: int 14 | num_blocks: int 15 | max_pooling: bool = True 16 | 17 | @nn.compact 18 | def __call__(self, x): 19 | initializer = nn.initializers.xavier_uniform() 20 | conv_out = nn.Conv( 21 | features=self.num_features, 22 | kernel_size=(3, 3), 23 | strides=1, 24 | kernel_init=initializer, 25 | padding='SAME', 26 | )(x) 27 | 28 | if self.max_pooling: 29 | conv_out = nn.max_pool( 30 | conv_out, 31 | window_shape=(3, 3), 32 | padding='SAME', 33 | strides=(2, 2), 34 | ) 35 | 36 | for _ in range(self.num_blocks): 37 | block_input = conv_out 38 | conv_out = nn.relu(conv_out) 39 | conv_out = nn.Conv( 40 | features=self.num_features, 41 | kernel_size=(3, 3), 42 | strides=1, 43 | padding='SAME', 44 | kernel_init=initializer, 45 | )(conv_out) 46 | 47 | conv_out = nn.relu(conv_out) 48 | conv_out = nn.Conv( 49 | features=self.num_features, 50 | kernel_size=(3, 3), 51 | strides=1, 52 | padding='SAME', 53 | kernel_init=initializer, 54 | )(conv_out) 55 | conv_out += block_input 56 | 57 | return conv_out 58 | 59 | 60 | class ImpalaEncoder(nn.Module): 61 | """IMPALA encoder.""" 62 | 63 | width: int = 1 64 | stack_sizes: tuple = (16, 32, 32) 65 | num_blocks: int = 2 66 | dropout_rate: float = None 67 | mlp_hidden_dims: Sequence[int] = (512,) 68 | layer_norm: bool = False 69 | 70 | def setup(self): 71 | stack_sizes = self.stack_sizes 72 | self.stack_blocks = [ 73 | ResnetStack( 74 | num_features=stack_sizes[i] * self.width, 75 | num_blocks=self.num_blocks, 76 | ) 77 | for i in range(len(stack_sizes)) 78 | ] 79 | if self.dropout_rate is not None: 80 | self.dropout = nn.Dropout(rate=self.dropout_rate) 81 | 82 | @nn.compact 83 | def __call__(self, x, train=True, cond_var=None): 84 | x = x.astype(jnp.float32) / 255.0 85 | 86 | conv_out = x 87 | 88 | for idx in range(len(self.stack_blocks)): 89 | conv_out = self.stack_blocks[idx](conv_out) 90 | if self.dropout_rate is not None: 91 | conv_out = self.dropout(conv_out, deterministic=not train) 92 | 93 | conv_out = nn.relu(conv_out) 94 | if self.layer_norm: 95 | conv_out = nn.LayerNorm()(conv_out) 96 | out = conv_out.reshape((*x.shape[:-3], -1)) 97 | 98 | out = MLP(self.mlp_hidden_dims, activate_final=True, 
layer_norm=self.layer_norm)(out) 99 | 100 | return out 101 | 102 | 103 | class GCEncoder(nn.Module): 104 | """Helper module to handle inputs to goal-conditioned networks. 105 | 106 | It takes in observations (s) and goals (g) and returns the concatenation of `state_encoder(s)`, `goal_encoder(g)`, 107 | and `concat_encoder([s, g])`. It ignores the encoders that are not provided. This way, the module can handle both 108 | early and late fusion (or their variants) of state and goal information. 109 | """ 110 | 111 | state_encoder: nn.Module = None 112 | goal_encoder: nn.Module = None 113 | concat_encoder: nn.Module = None 114 | 115 | @nn.compact 116 | def __call__(self, observations, goals=None, goal_encoded=False): 117 | """Returns the representations of observations and goals. 118 | 119 | If `goal_encoded` is True, `goals` is assumed to be already encoded representations. In this case, either 120 | `goal_encoder` or `concat_encoder` must be None. 121 | """ 122 | reps = [] 123 | if self.state_encoder is not None: 124 | reps.append(self.state_encoder(observations)) 125 | if goals is not None: 126 | if goal_encoded: 127 | # Can't have both goal_encoder and concat_encoder in this case. 128 | assert self.goal_encoder is None or self.concat_encoder is None 129 | reps.append(goals) 130 | else: 131 | if self.goal_encoder is not None: 132 | reps.append(self.goal_encoder(goals)) 133 | if self.concat_encoder is not None: 134 | reps.append(self.concat_encoder(jnp.concatenate([observations, goals], axis=-1))) 135 | reps = jnp.concatenate(reps, axis=-1) 136 | return reps 137 | 138 | 139 | encoder_modules = { 140 | 'impala': ImpalaEncoder, 141 | 'impala_debug': functools.partial(ImpalaEncoder, num_blocks=1, stack_sizes=(4, 4)), 142 | 'impala_small': functools.partial(ImpalaEncoder, num_blocks=1), 143 | 'impala_large': functools.partial(ImpalaEncoder, stack_sizes=(64, 128, 128), mlp_hidden_dims=(1024,)), 144 | } 145 | -------------------------------------------------------------------------------- /impls/utils/env_utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os 3 | import platform 4 | import time 5 | 6 | import gymnasium 7 | import numpy as np 8 | from gymnasium.spaces import Box 9 | 10 | import ogbench 11 | from utils.datasets import Dataset 12 | 13 | 14 | class EpisodeMonitor(gymnasium.Wrapper): 15 | """Environment wrapper to monitor episode statistics.""" 16 | 17 | def __init__(self, env): 18 | super().__init__(env) 19 | self._reset_stats() 20 | self.total_timesteps = 0 21 | 22 | def _reset_stats(self): 23 | self.reward_sum = 0.0 24 | self.episode_length = 0 25 | self.start_time = time.time() 26 | 27 | def step(self, action): 28 | observation, reward, terminated, truncated, info = self.env.step(action) 29 | 30 | self.reward_sum += reward 31 | self.episode_length += 1 32 | self.total_timesteps += 1 33 | info['total'] = {'timesteps': self.total_timesteps} 34 | 35 | if terminated or truncated: 36 | info['episode'] = {} 37 | info['episode']['return'] = self.reward_sum 38 | info['episode']['length'] = self.episode_length 39 | info['episode']['duration'] = time.time() - self.start_time 40 | 41 | return observation, reward, terminated, truncated, info 42 | 43 | def reset(self, *args, **kwargs): 44 | self._reset_stats() 45 | return self.env.reset(*args, **kwargs) 46 | 47 | 48 | class FrameStackWrapper(gymnasium.Wrapper): 49 | """Environment wrapper to stack observations.""" 50 | 51 | def __init__(self, env, num_stack): 52 | 
super().__init__(env) 53 | 54 | self.num_stack = num_stack 55 | self.frames = collections.deque(maxlen=num_stack) 56 | 57 | low = np.concatenate([self.observation_space.low] * num_stack, axis=-1) 58 | high = np.concatenate([self.observation_space.high] * num_stack, axis=-1) 59 | self.observation_space = Box(low=low, high=high, dtype=self.observation_space.dtype) 60 | 61 | def get_observation(self): 62 | assert len(self.frames) == self.num_stack 63 | return np.concatenate(list(self.frames), axis=-1) 64 | 65 | def reset(self, **kwargs): 66 | ob, info = self.env.reset(**kwargs) 67 | for _ in range(self.num_stack): 68 | self.frames.append(ob) 69 | if 'goal' in info: 70 | info['goal'] = np.concatenate([info['goal']] * self.num_stack, axis=-1) 71 | return self.get_observation(), info 72 | 73 | def step(self, action): 74 | ob, reward, terminated, truncated, info = self.env.step(action) 75 | self.frames.append(ob) 76 | return self.get_observation(), reward, terminated, truncated, info 77 | 78 | 79 | def make_env_and_datasets(dataset_name, frame_stack=None): 80 | """Make OGBench environment and datasets. 81 | 82 | Args: 83 | dataset_name: Name of the dataset. 84 | frame_stack: Number of frames to stack. 85 | 86 | Returns: 87 | A tuple of the environment, training dataset, and validation dataset. 88 | """ 89 | # Use compact dataset to save memory. 90 | env, train_dataset, val_dataset = ogbench.make_env_and_datasets(dataset_name, compact_dataset=True) 91 | train_dataset = Dataset.create(**train_dataset) 92 | val_dataset = Dataset.create(**val_dataset) 93 | 94 | if frame_stack is not None: 95 | env = FrameStackWrapper(env, frame_stack) 96 | 97 | env.reset() 98 | 99 | return env, train_dataset, val_dataset 100 | -------------------------------------------------------------------------------- /impls/utils/evaluation.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import jax 4 | import numpy as np 5 | from tqdm import trange 6 | 7 | 8 | def supply_rng(f, rng=jax.random.PRNGKey(0)): 9 | """Helper function to split the random number generator key before each call to the function.""" 10 | 11 | def wrapped(*args, **kwargs): 12 | nonlocal rng 13 | rng, key = jax.random.split(rng) 14 | return f(*args, seed=key, **kwargs) 15 | 16 | return wrapped 17 | 18 | 19 | def flatten(d, parent_key='', sep='.'): 20 | """Flatten a dictionary.""" 21 | items = [] 22 | for k, v in d.items(): 23 | new_key = parent_key + sep + k if parent_key else k 24 | if hasattr(v, 'items'): 25 | items.extend(flatten(v, new_key, sep=sep).items()) 26 | else: 27 | items.append((new_key, v)) 28 | return dict(items) 29 | 30 | 31 | def add_to(dict_of_lists, single_dict): 32 | """Append values to the corresponding lists in the dictionary.""" 33 | for k, v in single_dict.items(): 34 | dict_of_lists[k].append(v) 35 | 36 | 37 | def evaluate( 38 | agent, 39 | env, 40 | task_id=None, 41 | config=None, 42 | num_eval_episodes=50, 43 | num_video_episodes=0, 44 | video_frame_skip=3, 45 | eval_temperature=0, 46 | eval_gaussian=None, 47 | ): 48 | """Evaluate the agent in the environment. 49 | 50 | Args: 51 | agent: Agent. 52 | env: Environment. 53 | task_id: Task ID to be passed to the environment. 54 | config: Configuration dictionary. 55 | num_eval_episodes: Number of episodes to evaluate the agent. 56 | num_video_episodes: Number of episodes to render. These episodes are not included in the statistics. 57 | video_frame_skip: Number of frames to skip between renders. 
58 | eval_temperature: Action sampling temperature. 59 | eval_gaussian: Standard deviation of the Gaussian noise to add to the actions. 60 | 61 | Returns: 62 | A tuple containing the statistics, trajectories, and rendered videos. 63 | """ 64 | actor_fn = supply_rng(agent.sample_actions, rng=jax.random.PRNGKey(np.random.randint(0, 2**32))) 65 | trajs = [] 66 | stats = defaultdict(list) 67 | 68 | renders = [] 69 | for i in trange(num_eval_episodes + num_video_episodes): 70 | traj = defaultdict(list) 71 | should_render = i >= num_eval_episodes 72 | 73 | observation, info = env.reset(options=dict(task_id=task_id, render_goal=should_render)) 74 | goal = info.get('goal') 75 | goal_frame = info.get('goal_rendered') 76 | done = False 77 | step = 0 78 | render = [] 79 | while not done: 80 | action = actor_fn(observations=observation, goals=goal, temperature=eval_temperature) 81 | action = np.array(action) 82 | if not config.get('discrete'): 83 | if eval_gaussian is not None: 84 | action = np.random.normal(action, eval_gaussian) 85 | action = np.clip(action, -1, 1) 86 | 87 | next_observation, reward, terminated, truncated, info = env.step(action) 88 | done = terminated or truncated 89 | step += 1 90 | 91 | if should_render and (step % video_frame_skip == 0 or done): 92 | frame = env.render().copy() 93 | if goal_frame is not None: 94 | render.append(np.concatenate([goal_frame, frame], axis=0)) 95 | else: 96 | render.append(frame) 97 | 98 | transition = dict( 99 | observation=observation, 100 | next_observation=next_observation, 101 | action=action, 102 | reward=reward, 103 | done=done, 104 | info=info, 105 | ) 106 | add_to(traj, transition) 107 | observation = next_observation 108 | if i < num_eval_episodes: 109 | add_to(stats, flatten(info)) 110 | trajs.append(traj) 111 | else: 112 | renders.append(np.array(render)) 113 | 114 | for k, v in stats.items(): 115 | stats[k] = np.mean(v) 116 | 117 | return stats, trajs, renders 118 | -------------------------------------------------------------------------------- /impls/utils/log_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | from datetime import datetime 4 | 5 | import absl.flags as flags 6 | import ml_collections 7 | import numpy as np 8 | import wandb 9 | from PIL import Image, ImageEnhance 10 | 11 | 12 | class CsvLogger: 13 | """CSV logger for logging metrics to a CSV file.""" 14 | 15 | def __init__(self, path): 16 | self.path = path 17 | self.header = None 18 | self.file = None 19 | self.disallowed_types = (wandb.Image, wandb.Video, wandb.Histogram) 20 | 21 | def log(self, row, step): 22 | row['step'] = step 23 | if self.file is None: 24 | self.file = open(self.path, 'w') 25 | if self.header is None: 26 | self.header = [k for k, v in row.items() if not isinstance(v, self.disallowed_types)] 27 | self.file.write(','.join(self.header) + '\n') 28 | filtered_row = {k: v for k, v in row.items() if not isinstance(v, self.disallowed_types)} 29 | self.file.write(','.join([str(filtered_row.get(k, '')) for k in self.header]) + '\n') 30 | else: 31 | filtered_row = {k: v for k, v in row.items() if not isinstance(v, self.disallowed_types)} 32 | self.file.write(','.join([str(filtered_row.get(k, '')) for k in self.header]) + '\n') 33 | self.file.flush() 34 | 35 | def close(self): 36 | if self.file is not None: 37 | self.file.close() 38 | 39 | 40 | def get_exp_name(seed): 41 | """Return the experiment name.""" 42 | exp_name = '' 43 | exp_name += f'sd{seed:03d}_' 44 | if 'SLURM_JOB_ID' 
in os.environ: 45 | exp_name += f's_{os.environ["SLURM_JOB_ID"]}.' 46 | if 'SLURM_PROCID' in os.environ: 47 | exp_name += f'{os.environ["SLURM_PROCID"]}.' 48 | exp_name += f'{datetime.now().strftime("%Y%m%d_%H%M%S")}' 49 | 50 | return exp_name 51 | 52 | 53 | def get_flag_dict(): 54 | """Return the dictionary of flags.""" 55 | flag_dict = {k: getattr(flags.FLAGS, k) for k in flags.FLAGS if '.' not in k} 56 | for k in flag_dict: 57 | if isinstance(flag_dict[k], ml_collections.ConfigDict): 58 | flag_dict[k] = flag_dict[k].to_dict() 59 | return flag_dict 60 | 61 | 62 | def setup_wandb( 63 | entity=None, 64 | project='project', 65 | group=None, 66 | name=None, 67 | mode='online', 68 | ): 69 | """Set up Weights & Biases for logging.""" 70 | wandb_output_dir = tempfile.mkdtemp() 71 | tags = [group] if group is not None else None 72 | 73 | init_kwargs = dict( 74 | config=get_flag_dict(), 75 | project=project, 76 | entity=entity, 77 | tags=tags, 78 | group=group, 79 | dir=wandb_output_dir, 80 | name=name, 81 | settings=wandb.Settings( 82 | start_method='thread', 83 | _disable_stats=False, 84 | ), 85 | mode=mode, 86 | save_code=True, 87 | ) 88 | 89 | run = wandb.init(**init_kwargs) 90 | 91 | return run 92 | 93 | 94 | def reshape_video(v, n_cols=None): 95 | """Helper function to reshape videos.""" 96 | if v.ndim == 4: 97 | v = v[None,] 98 | 99 | _, t, h, w, c = v.shape 100 | 101 | if n_cols is None: 102 | # Set n_cols to the square root of the number of videos. 103 | n_cols = np.ceil(np.sqrt(v.shape[0])).astype(int) 104 | if v.shape[0] % n_cols != 0: 105 | len_addition = n_cols - v.shape[0] % n_cols 106 | v = np.concatenate((v, np.zeros(shape=(len_addition, t, h, w, c))), axis=0) 107 | n_rows = v.shape[0] // n_cols 108 | 109 | v = np.reshape(v, newshape=(n_rows, n_cols, t, h, w, c)) 110 | v = np.transpose(v, axes=(2, 5, 0, 3, 1, 4)) 111 | v = np.reshape(v, newshape=(t, c, n_rows * h, n_cols * w)) 112 | 113 | return v 114 | 115 | 116 | def get_wandb_video(renders=None, n_cols=None, fps=15): 117 | """Return a Weights & Biases video. 118 | 119 | It takes a list of videos and reshapes them into a single video with the specified number of columns. 120 | 121 | Args: 122 | renders: List of videos. Each video should be a numpy array of shape (t, h, w, c). 123 | n_cols: Number of columns for the reshaped video. If None, it is set to the square root of the number of videos. 124 | """ 125 | # Pad videos to the same length. 126 | max_length = max([len(render) for render in renders]) 127 | for i, render in enumerate(renders): 128 | assert render.dtype == np.uint8 129 | 130 | # Decrease brightness of the padded frames. 131 | final_frame = render[-1] 132 | final_image = Image.fromarray(final_frame) 133 | enhancer = ImageEnhance.Brightness(final_image) 134 | final_image = enhancer.enhance(0.5) 135 | final_frame = np.array(final_image) 136 | 137 | pad = np.repeat(final_frame[np.newaxis, ...], max_length - len(render), axis=0) 138 | renders[i] = np.concatenate([render, pad], axis=0) 139 | 140 | # Add borders. 
141 | renders[i] = np.pad(renders[i], ((0, 0), (1, 1), (1, 1), (0, 0)), mode='constant', constant_values=0) 142 | renders = np.array(renders) # (n, t, h, w, c) 143 | 144 | renders = reshape_video(renders, n_cols) # (t, c, nr * h, nc * w) 145 | 146 | return wandb.Video(renders, fps=fps, format='mp4') 147 | -------------------------------------------------------------------------------- /ogbench/__init__.py: -------------------------------------------------------------------------------- 1 | """OGBench: Benchmarking Offline Goal-Conditioned RL""" 2 | 3 | import ogbench.locomaze 4 | import ogbench.manipspace 5 | import ogbench.powderworld 6 | from ogbench.utils import download_datasets, load_dataset, make_env_and_datasets 7 | 8 | __all__ = ( 9 | 'locomaze', 10 | 'manipspace', 11 | 'powderworld', 12 | 'download_datasets', 13 | 'load_dataset', 14 | 'make_env_and_datasets', 15 | ) 16 | -------------------------------------------------------------------------------- /ogbench/locomaze/ant.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import gymnasium 4 | import numpy as np 5 | from gymnasium import utils 6 | from gymnasium.envs.mujoco import MujocoEnv 7 | from gymnasium.spaces import Box 8 | 9 | 10 | class AntEnv(MujocoEnv, utils.EzPickle): 11 | """Gymnasium Ant environment. 12 | 13 | Unlike the original Ant environment, this environment uses a restricted joint range for the actuators, as typically 14 | done in previous works in hierarchical reinforcement learning. It also uses a control frequency of 10Hz instead of 15 | 20Hz, which is the default in the original environment. 16 | """ 17 | 18 | xml_file = os.path.join(os.path.dirname(__file__), 'assets', 'ant.xml') 19 | metadata = { 20 | 'render_modes': ['human', 'rgb_array', 'depth_array'], 21 | 'render_fps': 10, 22 | } 23 | if gymnasium.__version__ >= '1.1.0': 24 | metadata['render_modes'] += ['rgbd_tuple'] 25 | 26 | def __init__( 27 | self, 28 | xml_file=None, 29 | reset_noise_scale=0.1, 30 | render_mode='rgb_array', 31 | width=200, 32 | height=200, 33 | **kwargs, 34 | ): 35 | """Initialize the Ant environment. 36 | 37 | Args: 38 | xml_file: Path to the XML description (optional). 39 | reset_noise_scale: Scale of the noise added to the initial state during reset. 40 | render_mode: Rendering mode. 41 | width: Width of the rendered image. 42 | height: Height of the rendered image. 43 | **kwargs: Additional keyword arguments. 
44 | """ 45 | if xml_file is None: 46 | xml_file = self.xml_file 47 | utils.EzPickle.__init__( 48 | self, 49 | xml_file, 50 | reset_noise_scale, 51 | **kwargs, 52 | ) 53 | 54 | self._reset_noise_scale = reset_noise_scale 55 | 56 | observation_space = Box(low=-np.inf, high=np.inf, shape=(29,), dtype=np.float64) 57 | 58 | MujocoEnv.__init__( 59 | self, 60 | xml_file, 61 | frame_skip=5, 62 | observation_space=observation_space, 63 | render_mode=render_mode, 64 | width=width, 65 | height=height, 66 | **kwargs, 67 | ) 68 | 69 | def step(self, action): 70 | prev_qpos = self.data.qpos.copy() 71 | prev_qvel = self.data.qvel.copy() 72 | 73 | self.do_simulation(action, self.frame_skip) 74 | 75 | qpos = self.data.qpos.copy() 76 | qvel = self.data.qvel.copy() 77 | 78 | observation = self.get_ob() 79 | 80 | if self.render_mode == 'human': 81 | self.render() 82 | 83 | return ( 84 | observation, 85 | 0.0, 86 | False, 87 | False, 88 | { 89 | 'xy': self.get_xy(), 90 | 'prev_qpos': prev_qpos, 91 | 'prev_qvel': prev_qvel, 92 | 'qpos': qpos, 93 | 'qvel': qvel, 94 | }, 95 | ) 96 | 97 | def get_ob(self): 98 | position = self.data.qpos.flat.copy() 99 | velocity = self.data.qvel.flat.copy() 100 | 101 | return np.concatenate([position, velocity]) 102 | 103 | def reset_model(self): 104 | noise_low = -self._reset_noise_scale 105 | noise_high = self._reset_noise_scale 106 | 107 | qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq) 108 | qvel = self.init_qvel + self._reset_noise_scale * self.np_random.standard_normal(self.model.nv) 109 | self.set_state(qpos, qvel) 110 | 111 | observation = self.get_ob() 112 | 113 | return observation 114 | 115 | def get_xy(self): 116 | return self.data.qpos[:2].copy() 117 | 118 | def set_xy(self, xy): 119 | qpos = self.data.qpos.copy() 120 | qvel = self.data.qvel.copy() 121 | qpos[:2] = xy 122 | self.set_state(qpos, qvel) 123 | -------------------------------------------------------------------------------- /ogbench/locomaze/assets/ant.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /ogbench/locomaze/assets/point.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 42 | -------------------------------------------------------------------------------- /ogbench/locomaze/humanoid.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import os 3 | 4 | import gymnasium 5 | import mujoco 6 | import numpy as np 7 | from gymnasium import utils 8 | from gymnasium.envs.mujoco import MujocoEnv 9 | from gymnasium.spaces import Box 10 | 11 | 12 | class HumanoidEnv(MujocoEnv, utils.EzPickle): 13 | """DMC Humanoid environment. 14 | 15 | Several methods are reimplemented to remove the dependency on the `dm_control` package. It is supposed to work 16 | identically to the original Humanoid environment. 17 | """ 18 | 19 | xml_file = os.path.join(os.path.dirname(__file__), 'assets', 'humanoid.xml') 20 | metadata = { 21 | 'render_modes': ['human', 'rgb_array', 'depth_array'], 22 | 'render_fps': 40, 23 | } 24 | if gymnasium.__version__ >= '1.1.0': 25 | metadata['render_modes'] += ['rgbd_tuple'] 26 | 27 | def __init__( 28 | self, 29 | xml_file=None, 30 | render_mode='rgb_array', 31 | width=200, 32 | height=200, 33 | **kwargs, 34 | ): 35 | """Initialize the Humanoid environment. 
36 | 37 | Args: 38 | xml_file: Path to the XML description (optional). 39 | render_mode: Rendering mode. 40 | width: Width of the rendered image. 41 | height: Height of the rendered image. 42 | **kwargs: Additional keyword arguments. 43 | """ 44 | if xml_file is None: 45 | xml_file = self.xml_file 46 | utils.EzPickle.__init__( 47 | self, 48 | xml_file, 49 | **kwargs, 50 | ) 51 | 52 | observation_space = Box(low=-np.inf, high=np.inf, shape=(69,), dtype=np.float64) 53 | 54 | MujocoEnv.__init__( 55 | self, 56 | xml_file, 57 | frame_skip=5, 58 | observation_space=observation_space, 59 | render_mode=render_mode, 60 | width=width, 61 | height=height, 62 | **kwargs, 63 | ) 64 | 65 | def step(self, action): 66 | prev_qpos = self.data.qpos.copy() 67 | prev_qvel = self.data.qvel.copy() 68 | 69 | self.do_simulation(action, self.frame_skip) 70 | 71 | qpos = self.data.qpos.copy() 72 | qvel = self.data.qvel.copy() 73 | 74 | observation = self.get_ob() 75 | 76 | if self.render_mode == 'human': 77 | self.render() 78 | 79 | return ( 80 | observation, 81 | 0.0, 82 | False, 83 | False, 84 | { 85 | 'xy': self.get_xy(), 86 | 'prev_qpos': prev_qpos, 87 | 'prev_qvel': prev_qvel, 88 | 'qpos': qpos, 89 | 'qvel': qvel, 90 | }, 91 | ) 92 | 93 | def _step_mujoco_simulation(self, ctrl, n_frames): 94 | self.data.ctrl[:] = ctrl 95 | 96 | # DMC-style stepping (see Page 6 of https://arxiv.org/abs/2006.12983). 97 | if self.model.opt.integrator != mujoco.mjtIntegrator.mjINT_RK4.value: 98 | mujoco.mj_step2(self.model, self.data) 99 | if n_frames > 1: 100 | mujoco.mj_step(self.model, self.data, n_frames - 1) 101 | else: 102 | mujoco.mj_step(self.model, self.data, n_frames) 103 | 104 | mujoco.mj_step1(self.model, self.data) 105 | 106 | def get_ob(self): 107 | xy = self.data.qpos[:2] 108 | joint_angles = self.data.qpos[7:] # Skip the 7 DoFs of the free root joint. 
109 | head_height = self.data.xpos[2, 2] # ['head', 'z'] 110 | torso_frame = self.data.xmat[1].reshape(3, 3) # ['torso'] 111 | torso_pos = self.data.xpos[1] # ['torso'] 112 | positions = [] 113 | for idx in [16, 10, 13, 7]: # ['left_hand', 'left_foot', 'right_hand', 'right_foot'] 114 | torso_to_limb = self.data.xpos[idx] - torso_pos 115 | positions.append(torso_to_limb.dot(torso_frame)) 116 | extremities = np.hstack(positions) 117 | torso_vertical_orientation = self.data.xmat[1, [6, 7, 8]] # ['torso', ['zx', 'zy', 'zz']] 118 | center_of_mass_velocity = self.data.sensordata[0:3] # ['torso_subtreelinvel'] 119 | velocity = self.data.qvel 120 | 121 | return np.concatenate( 122 | [ 123 | xy, 124 | joint_angles, 125 | [head_height], 126 | extremities, 127 | torso_vertical_orientation, 128 | center_of_mass_velocity, 129 | velocity, 130 | ] 131 | ) 132 | 133 | @contextlib.contextmanager 134 | def disable(self, *flags): 135 | old_bitmask = self.model.opt.disableflags 136 | new_bitmask = old_bitmask 137 | for flag in flags: 138 | if isinstance(flag, str): 139 | field_name = 'mjDSBL_' + flag.upper() 140 | flag = getattr(mujoco.mjtDisableBit, field_name) 141 | elif isinstance(flag, int): 142 | flag = mujoco.mjtDisableBit(flag) 143 | new_bitmask |= flag.value 144 | self.model.opt.disableflags = new_bitmask 145 | try: 146 | yield 147 | finally: 148 | self.model.opt.disableflags = old_bitmask 149 | 150 | def reset_model(self): 151 | penetrating = True 152 | while penetrating: 153 | quat = self.np_random.uniform(size=4) 154 | quat /= np.linalg.norm(quat) 155 | self.data.qpos[3:7] = quat 156 | self.data.qvel = 0.1 * self.np_random.standard_normal(self.model.nv) 157 | 158 | for joint_id in range(1, self.model.njnt): 159 | range_min, range_max = self.model.jnt_range[joint_id] 160 | self.data.qpos[6 + joint_id] = self.np_random.uniform(range_min, range_max) 161 | 162 | with self.disable('actuation'): 163 | mujoco.mj_forward(self.model, self.data) 164 | penetrating = self.data.ncon > 0 165 | 166 | observation = self.get_ob() 167 | 168 | return observation 169 | 170 | def get_xy(self): 171 | return self.data.qpos[:2].copy() 172 | 173 | def set_xy(self, xy): 174 | qpos = self.data.qpos.copy() 175 | qvel = self.data.qvel.copy() 176 | qpos[:2] = xy 177 | self.set_state(qpos, qvel) 178 | -------------------------------------------------------------------------------- /ogbench/locomaze/point.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import gymnasium 4 | import mujoco 5 | import numpy as np 6 | from gymnasium import utils 7 | from gymnasium.envs.mujoco import MujocoEnv 8 | from gymnasium.spaces import Box 9 | 10 | 11 | class PointEnv(MujocoEnv, utils.EzPickle): 12 | """PointMass environment. 13 | 14 | This is a simple 2-D point mass environment, where the agent is controlled by an x-y action vector that is added to 15 | the current position of the point mass. 16 | """ 17 | 18 | xml_file = os.path.join(os.path.dirname(__file__), 'assets', 'point.xml') 19 | metadata = { 20 | 'render_modes': ['human', 'rgb_array', 'depth_array'], 21 | 'render_fps': 10, 22 | } 23 | if gymnasium.__version__ >= '1.1.0': 24 | metadata['render_modes'] += ['rgbd_tuple'] 25 | 26 | def __init__( 27 | self, 28 | xml_file=None, 29 | render_mode='rgb_array', 30 | width=200, 31 | height=200, 32 | **kwargs, 33 | ): 34 | """Initialize the Point environment. 35 | 36 | Args: 37 | xml_file: Path to the XML description (optional). 38 | render_mode: Rendering mode.
39 | width: Width of the rendered image. 40 | height: Height of the rendered image. 41 | **kwargs: Additional keyword arguments. 42 | """ 43 | if xml_file is None: 44 | xml_file = self.xml_file 45 | utils.EzPickle.__init__( 46 | self, 47 | xml_file, 48 | **kwargs, 49 | ) 50 | 51 | observation_space = Box(low=-np.inf, high=np.inf, shape=(6,), dtype=np.float64) 52 | 53 | MujocoEnv.__init__( 54 | self, 55 | xml_file, 56 | frame_skip=5, 57 | observation_space=observation_space, 58 | render_mode=render_mode, 59 | width=width, 60 | height=height, 61 | **kwargs, 62 | ) 63 | 64 | def step(self, action): 65 | prev_qpos = self.data.qpos.copy() 66 | prev_qvel = self.data.qvel.copy() 67 | 68 | action = 0.2 * action 69 | 70 | self.data.qpos[:] = self.data.qpos + action 71 | self.data.qvel[:] = np.array([0.0, 0.0]) 72 | 73 | mujoco.mj_step(self.model, self.data, nstep=self.frame_skip) 74 | 75 | qpos = self.data.qpos.flat.copy() 76 | qvel = self.data.qvel.flat.copy() 77 | 78 | observation = self.get_ob() 79 | 80 | if self.render_mode == 'human': 81 | self.render() 82 | 83 | return ( 84 | observation, 85 | 0.0, 86 | False, 87 | False, 88 | { 89 | 'xy': self.get_xy(), 90 | 'prev_qpos': prev_qpos, 91 | 'prev_qvel': prev_qvel, 92 | 'qpos': qpos, 93 | 'qvel': qvel, 94 | }, 95 | ) 96 | 97 | def get_ob(self): 98 | return self.data.qpos.flat.copy() 99 | 100 | def reset_model(self): 101 | qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-0.1, high=0.1) 102 | qvel = self.init_qvel + self.np_random.standard_normal(self.model.nv) * 0.1 103 | 104 | self.set_state(qpos, qvel) 105 | 106 | return self.get_ob() 107 | 108 | def get_xy(self): 109 | return self.data.qpos.copy() 110 | 111 | def set_xy(self, xy): 112 | qpos = self.data.qpos.copy() 113 | qvel = self.data.qvel.copy() 114 | qpos[:] = xy 115 | self.set_state(qpos, qvel) 116 | -------------------------------------------------------------------------------- /ogbench/manipspace/controllers/__init__.py: -------------------------------------------------------------------------------- 1 | from ogbench.manipspace.controllers.diff_ik import DiffIKController 2 | 3 | __all__ = ('DiffIKController',) 4 | -------------------------------------------------------------------------------- /ogbench/manipspace/controllers/diff_ik.py: -------------------------------------------------------------------------------- 1 | import mujoco 2 | import numpy as np 3 | 4 | PI = np.pi 5 | PI_2 = 2 * np.pi 6 | 7 | 8 | def angle_diff(q1: np.ndarray, q2: np.ndarray) -> np.ndarray: 9 | return np.mod(q1 - q2 + PI, PI_2) - PI 10 | 11 | 12 | class DiffIKController: 13 | """Differential inverse kinematics controller.""" 14 | 15 | def __init__( 16 | self, 17 | model: mujoco.MjModel, 18 | sites: list, 19 | qpos0: np.ndarray = None, 20 | damping_coeff: float = 1e-12, 21 | max_angle_change: float = np.radians(45), 22 | ): 23 | self._model = model 24 | self._data = mujoco.MjData(self._model) 25 | self._qp0 = qpos0 26 | self._max_angle_change = max_angle_change 27 | 28 | # Cache references. 29 | self._ns = len(sites) # Number of sites. 30 | self._site_ids = np.asarray([self._model.site(s).id for s in sites]) 31 | 32 | # Preallocate arrays. 
33 | self._err = np.empty((self._ns, 6)) 34 | self._site_quat = np.empty((self._ns, 4)) 35 | self._site_quat_inv = np.empty((self._ns, 4)) 36 | self._err_quat = np.empty((self._ns, 4)) 37 | self._jac = np.empty((6 * self._ns, self._model.nv)) 38 | self._damping = damping_coeff * np.eye(6 * self._ns) 39 | self._eye = np.eye(self._model.nv) 40 | 41 | def _forward_kinematics(self) -> None: 42 | """Minimal computation required for forward kinematics.""" 43 | mujoco.mj_kinematics(self._model, self._data) 44 | mujoco.mj_comPos(self._model, self._data) # Required for mj_jacSite. 45 | 46 | def _integrate(self, update: np.ndarray) -> None: 47 | """Integrate the joint velocities in-place.""" 48 | mujoco.mj_integratePos(self._model, self._data.qpos, update, 1.0) 49 | 50 | def _compute_translational_error(self, pos: np.ndarray) -> None: 51 | """Compute the error between the desired and current site positions.""" 52 | self._err[:, :3] = pos - self._data.site_xpos[self._site_ids] 53 | 54 | def _compute_rotational_error(self, quat: np.ndarray) -> None: 55 | """Compute the error between the desired and current site orientations.""" 56 | for i, site_id in enumerate(self._site_ids): 57 | mujoco.mju_mat2Quat(self._site_quat[i], self._data.site_xmat[site_id]) 58 | mujoco.mju_negQuat(self._site_quat_inv[i], self._site_quat[i]) 59 | mujoco.mju_mulQuat(self._err_quat[i], quat[i], self._site_quat_inv[i]) 60 | mujoco.mju_quat2Vel(self._err[i, 3:], self._err_quat[i], 1.0) 61 | 62 | def _compute_jacobian(self) -> None: 63 | """Update site end-effector Jacobians.""" 64 | for i, site_id in enumerate(self._site_ids): 65 | jacp = self._jac[6 * i : 6 * i + 3] 66 | jacr = self._jac[6 * i + 3 : 6 * i + 6] 67 | mujoco.mj_jacSite(self._model, self._data, jacp, jacr, site_id) 68 | 69 | def _error_threshold_reached(self, pos_thresh: float, ori_thresh: float) -> bool: 70 | """Return True if position and rotation errors are below the thresholds.""" 71 | pos_achieved = np.linalg.norm(self._err[:, :3]) <= pos_thresh 72 | ori_achieved = np.linalg.norm(self._err[:, 3:]) <= ori_thresh 73 | return pos_achieved and ori_achieved 74 | 75 | def _solve(self) -> np.ndarray: 76 | """Solve for joint velocities using damped least squares.""" 77 | H = self._jac @ self._jac.T + self._damping 78 | x = self._jac.T @ np.linalg.solve(H, self._err.ravel()) 79 | if self._qp0 is not None: 80 | jac_pinv = np.linalg.pinv(H) 81 | q_err = angle_diff(self._qp0, self._data.qpos) 82 | x += (self._eye - (self._jac.T @ jac_pinv) @ self._jac) @ q_err 83 | return x 84 | 85 | def _scale_update(self, update: np.ndarray) -> np.ndarray: 86 | """Scale down update so that the max allowable angle change is not exceeded.""" 87 | update_max = np.max(np.abs(update)) 88 | if update_max > self._max_angle_change: 89 | update *= self._max_angle_change / update_max 90 | return update 91 | 92 | def solve( 93 | self, 94 | pos: np.ndarray, 95 | quat: np.ndarray, 96 | curr_qpos: np.ndarray, 97 | max_iters: int = 20, 98 | pos_thresh: float = 1e-4, 99 | ori_thresh: float = 1e-4, 100 | ) -> np.ndarray: 101 | self._data.qpos = curr_qpos 102 | 103 | for _ in range(max_iters): 104 | self._forward_kinematics() 105 | 106 | self._compute_translational_error(np.atleast_2d(pos)) 107 | self._compute_rotational_error(np.atleast_2d(quat)) 108 | if self._error_threshold_reached(pos_thresh, ori_thresh): 109 | break 110 | 111 | self._compute_jacobian() 112 | update = self._scale_update(self._solve()) 113 | self._integrate(update) 114 | 115 | return self._data.qpos.copy() 116 | 
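For orientation, here is a minimal sketch of how DiffIKController could be driven on its own; the XML path and site name are illustrative placeholders rather than files from this repository (the manipspace environments construct and call the controller internally):

    import mujoco
    import numpy as np

    from ogbench.manipspace.controllers import DiffIKController

    # Hypothetical arm-only model exposing an end-effector site.
    model = mujoco.MjModel.from_xml_path('arm_with_ee_site.xml')
    data = mujoco.MjData(model)

    controller = DiffIKController(model=model, sites=['end_effector'], qpos0=data.qpos.copy())

    # Desired end-effector pose: position (3,) and scalar-first (wxyz) quaternion (4,).
    target_pos = np.array([0.4, 0.0, 0.3])
    target_quat = np.array([1.0, 0.0, 0.0, 0.0])

    # Runs up to max_iters damped least-squares steps and returns a joint
    # configuration that approximately realizes the target pose.
    qpos_sol = controller.solve(pos=target_pos, quat=target_quat, curr_qpos=data.qpos.copy())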
-------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/button_inner.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/button_outer.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/buttons.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/cube.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/cube_inner.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/cube_outer.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/drawer.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/floor_wall.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/metaworld/button/button.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/button/button.stl -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/metaworld/button/buttonring.stl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/button/buttonring.stl -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/metaworld/button/metal1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/button/metal1.png -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/metaworld/button/stopbot.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/button/stopbot.stl -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/metaworld/button/stopbutton.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/button/stopbutton.stl -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/metaworld/button/stopbuttonrim.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/button/stopbuttonrim.stl -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/metaworld/button/stopbuttonrod.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/button/stopbuttonrod.stl -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/metaworld/button/stoptop.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/button/stoptop.stl -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/metaworld/drawer/drawer.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/drawer/drawer.stl -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/metaworld/drawer/drawercase.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/drawer/drawercase.stl -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/metaworld/drawer/drawerhandle.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/drawer/drawerhandle.stl -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/metaworld/window/window_base.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/window/window_base.stl -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/metaworld/window/window_frame.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/window/window_frame.stl -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/metaworld/window/window_h_base.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/window/window_h_base.stl -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/metaworld/window/window_h_frame.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/window/window_h_frame.stl -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/metaworld/window/windowa_frame.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/window/windowa_frame.stl -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/metaworld/window/windowa_glass.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/window/windowa_glass.stl -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/metaworld/window/windowa_h_frame.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/window/windowa_h_frame.stl -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/metaworld/window/windowa_h_glass.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/window/windowa_h_glass.stl -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/metaworld/window/windowb_frame.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/window/windowb_frame.stl -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/metaworld/window/windowb_glass.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/window/windowb_glass.stl -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/metaworld/window/windowb_h_frame.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/window/windowb_h_frame.stl -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/metaworld/window/windowb_h_glass.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/window/windowb_h_glass.stl -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/robotiq_2f85/2f85.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/robotiq_2f85/2f85.png -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/robotiq_2f85/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2013, ROS-Industrial 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, 5 | are permitted provided that the following conditions are met: 6 | 7 | Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | Redistributions in binary form must reproduce the above copyright notice, this 11 | list of conditions and the following disclaimer in the documentation and/or 12 | other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 15 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 16 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 18 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 19 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 20 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 21 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 22 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 23 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
24 | -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/robotiq_2f85/README.md: -------------------------------------------------------------------------------- 1 | # Robotiq 2F-85 Description (MJCF) 2 | 3 | Requires MuJoCo 2.2.2 or later. 4 | 5 | ## Overview 6 | 7 | This package contains a simplified robot description (MJCF) of the [Robotiq 85mm 8 | 2-Finger Adaptive 9 | Gripper](https://robotiq.com/products/2f85-140-adaptive-robot-gripper) developed 10 | by [Robotiq](https://robotiq.com/). It is derived from the [publicly available 11 | URDF 12 | description](https://github.com/ros-industrial/robotiq/tree/kinetic-devel/robotiq_2f_85_gripper_visualization). 13 | 14 |
15 | 16 |
17 | 18 | ## URDF → MJCF derivation steps 19 | 20 | 1. Added ` ` to the URDF's 21 | `` clause in order to preserve visual geometries. 22 | 2. Loaded the URDF into MuJoCo and saved a corresponding MJCF. 23 | 3. Manually edited the MJCF to extract common properties into the `` section. 24 | 4. Added `` clauses to prevent collisions between the linkage bodies. 25 | 5. Broke up collision pads into two pads for more contacts. 26 | 6. Increased pad friction and priority. 27 | 7. Added `impratio=10` for better noslip. 28 | 8. Added `scene.xml` which includes the robot, with a textured groundplane, skybox, and haze. 29 | 9. Added hanging box to `scene.xml`. 30 | 31 | ## License 32 | 33 | This model is released under a [BSD-2-Clause License](LICENSE). 34 | -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/robotiq_2f85/assets/base.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/robotiq_2f85/assets/base.stl -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/robotiq_2f85/assets/base_mount.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/robotiq_2f85/assets/base_mount.stl -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/robotiq_2f85/assets/coupler.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/robotiq_2f85/assets/coupler.stl -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/robotiq_2f85/assets/driver.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/robotiq_2f85/assets/driver.stl -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/robotiq_2f85/assets/follower.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/robotiq_2f85/assets/follower.stl -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/robotiq_2f85/assets/pad.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/robotiq_2f85/assets/pad.stl -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/robotiq_2f85/assets/silicone_pad.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/robotiq_2f85/assets/silicone_pad.stl -------------------------------------------------------------------------------- 
/ogbench/manipspace/descriptions/robotiq_2f85/assets/spring_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/robotiq_2f85/assets/spring_link.stl -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/robotiq_2f85/scene.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 41 | -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/universal_robots_ur5e/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2018 ROS Industrial Consortium 2 | 3 | Redistribution and use in source and binary forms, with or without modification, 4 | are permitted provided that the following conditions are met: 5 | 6 | 1. Redistributions of source code must retain the above copyright notice, this 7 | list of conditions and the following disclaimer. 8 | 9 | 2. Redistributions in binary form must reproduce the above copyright notice, 10 | this list of conditions and the following disclaimer in the documentation and/or 11 | other materials provided with the distribution. 12 | 13 | 3. Neither the name of the copyright holder nor the names of its contributors 14 | may be used to endorse or promote products derived from this software without 15 | specific prior written permission. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 18 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 19 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 20 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 21 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 22 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 23 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 24 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 25 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/universal_robots_ur5e/README.md: -------------------------------------------------------------------------------- 1 | # Universal Robots UR5e Description (MJCF) 2 | 3 | Requires MuJoCo 2.3.3 or later. 4 | 5 | ## Overview 6 | 7 | This package contains a simplified robot description (MJCF) of the 8 | [UR5e](https://www.universal-robots.com/products/ur5-robot/) developed by 9 | [Universal Robots](https://www.universal-robots.com/). It is derived from the 10 | [publicly available URDF 11 | description](https://github.com/ros-industrial/universal_robot/tree/kinetic-devel/ur_e_description). 12 | 13 |
14 | 15 |
16 | 17 | ### URDF → MJCF derivation steps 18 | 19 | 1. Converted the DAE [mesh 20 | files](https://github.com/ros-industrial/universal_robot/tree/kinetic-devel/ur_e_description/meshes/ur5e/visual) 21 | to OBJ format using [Blender](https://www.blender.org/). 22 | 2. Processed `.obj` files with [`obj2mjcf`](https://github.com/kevinzakka/obj2mjcf). 23 | 3. Added ` ` to the URDF's 24 | `` clause in order to preserve visual geometries. 25 | 4. Loaded the URDF into MuJoCo and saved a corresponding MJCF. 26 | 5. Added a tracking light to the base. 27 | 6. Manually edited the MJCF to extract common properties into the `` section. 28 | 7. Added position-controlled actuators. Max joint torque values were taken from 29 | [here](https://www.universal-robots.com/articles/ur/robot-care-maintenance/max-joint-torques/). 30 | 8. Added home joint configuration as a `keyframe`. 31 | 9. Manually designed collision geometries. 32 | 10. Added `scene.xml` which includes the robot, with a textured ground plane, skybox and haze. 33 | 34 | ## License 35 | 36 | This model is released under a [BSD-3-Clause License](LICENSE). 37 | -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/universal_robots_ur5e/scene.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/universal_robots_ur5e/ur5e.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/universal_robots_ur5e/ur5e.png -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/universal_robots_ur5e/ur5e.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 137 | -------------------------------------------------------------------------------- /ogbench/manipspace/descriptions/window.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | -------------------------------------------------------------------------------- /ogbench/manipspace/envs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/envs/__init__.py -------------------------------------------------------------------------------- /ogbench/manipspace/lie/__init__.py: -------------------------------------------------------------------------------- 1 | from ogbench.manipspace.lie.se3 import SE3 2 | from ogbench.manipspace.lie.so3 import SO3 3 | from ogbench.manipspace.lie.utils import get_epsilon, interpolate, mat2quat, skew 4 | 5 | __all__ = ( 6 | 'SE3', 7 | 'SO3', 8 | 'get_epsilon', 9 | 'interpolate', 10 | 
'mat2quat', 11 | 'skew', 12 | ) 13 | -------------------------------------------------------------------------------- /ogbench/manipspace/lie/se3.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import dataclass 4 | from typing import Any 5 | 6 | import numpy as np 7 | 8 | from ogbench.manipspace.lie.so3 import SO3 9 | from ogbench.manipspace.lie.utils import get_epsilon, skew 10 | 11 | _IDENTITY_WXYZ_XYZ = np.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=np.float64) 12 | 13 | 14 | @dataclass(frozen=True) 15 | class SE3: 16 | """Special Euclidean group for proper rigid transforms in 3D. 17 | 18 | Internal parameterization is (qw, qx, qy, qz, x, y, z). 19 | Tangent parameterization is (vx, vy, vz, omega_x, omega_y, omega_z). 20 | """ 21 | 22 | wxyz_xyz: np.ndarray 23 | matrix_dim: int = 4 24 | parameters_dim: int = 7 25 | tangent_dim: int = 6 26 | space_dim: int = 3 27 | 28 | def __repr__(self) -> str: 29 | quat = np.round(self.wxyz_xyz[:4], 5) 30 | xyz = np.round(self.wxyz_xyz[4:], 5) 31 | return f'{self.__class__.__name__}(wxyz={quat}, xyz={xyz})' 32 | 33 | @staticmethod 34 | def identity() -> SE3: 35 | return SE3(wxyz_xyz=_IDENTITY_WXYZ_XYZ) 36 | 37 | @staticmethod 38 | def from_rotation_and_translation( 39 | rotation: SO3, 40 | translation: np.ndarray, 41 | ) -> SE3: 42 | assert translation.shape == (SE3.space_dim,) 43 | return SE3(wxyz_xyz=np.concatenate([rotation.wxyz, translation])) 44 | 45 | @staticmethod 46 | def from_matrix(matrix: np.ndarray) -> SE3: 47 | assert matrix.shape == (SE3.matrix_dim, SE3.matrix_dim) 48 | return SE3.from_rotation_and_translation( 49 | rotation=SO3.from_matrix(matrix[:3, :3]), 50 | translation=matrix[:3, 3], 51 | ) 52 | 53 | @staticmethod 54 | def sample_uniform() -> SE3: 55 | return SE3.from_rotation_and_translation( 56 | rotation=SO3.sample_uniform(), 57 | translation=np.random.uniform(-1.0, 1.0, size=(SE3.space_dim,)), 58 | ) 59 | 60 | def rotation(self) -> SO3: 61 | return SO3(wxyz=self.wxyz_xyz[:4]) 62 | 63 | def translation(self) -> np.ndarray: 64 | return self.wxyz_xyz[4:] 65 | 66 | def as_matrix(self) -> np.ndarray: 67 | hmat = np.eye(self.matrix_dim, dtype=np.float64) 68 | hmat[:3, :3] = self.rotation().as_matrix() 69 | hmat[:3, 3] = self.translation() 70 | return hmat 71 | 72 | @staticmethod 73 | def exp(tangent: np.ndarray) -> SE3: 74 | assert tangent.shape == (SE3.tangent_dim,) 75 | rotation = SO3.exp(tangent[3:]) 76 | theta_squared = tangent[3:] @ tangent[3:] 77 | use_taylor = theta_squared < get_epsilon(theta_squared.dtype) 78 | theta_squared_safe = 1.0 if use_taylor else theta_squared 79 | theta_safe = np.sqrt(theta_squared_safe) 80 | skew_omega = skew(tangent[3:]) 81 | if use_taylor: 82 | V = rotation.as_matrix() 83 | else: 84 | V = ( 85 | np.eye(3, dtype=np.float64) 86 | + (1.0 - np.cos(theta_safe)) / (theta_squared_safe) * skew_omega 87 | + (theta_safe - np.sin(theta_safe)) / (theta_squared_safe * theta_safe) * (skew_omega @ skew_omega) 88 | ) 89 | return SE3.from_rotation_and_translation( 90 | rotation=rotation, 91 | translation=V @ tangent[:3], 92 | ) 93 | 94 | def log(self) -> np.ndarray: 95 | omega = self.rotation().log() 96 | theta_squared = omega @ omega 97 | use_taylor = theta_squared < get_epsilon(theta_squared.dtype) 98 | skew_omega = skew(omega) 99 | theta_squared_safe = 1.0 if use_taylor else theta_squared 100 | theta_safe = np.sqrt(theta_squared_safe) 101 | half_theta_safe = 0.5 * theta_safe 102 | if use_taylor: 103 | V_inv 
= np.eye(3, dtype=np.float64) - 0.5 * skew_omega + (skew_omega @ skew_omega) / 12.0 104 | else: 105 | V_inv = ( 106 | np.eye(3, dtype=np.float64) 107 | - 0.5 * skew_omega 108 | + (1.0 - theta_safe * np.cos(half_theta_safe) / (2.0 * np.sin(half_theta_safe))) 109 | / theta_squared_safe 110 | * (skew_omega @ skew_omega) 111 | ) 112 | return np.concatenate([V_inv @ self.translation(), omega]) 113 | 114 | def adjoint(self) -> np.ndarray: 115 | R = self.rotation().as_matrix() 116 | return np.block( 117 | [ 118 | [R, np.zeros((3, 3), dtype=np.float64)], 119 | [skew(self.translation()) @ R, R], 120 | ] 121 | ) 122 | 123 | def inverse(self) -> SE3: 124 | R_inv = self.rotation().inverse() 125 | return SE3.from_rotation_and_translation( 126 | rotation=R_inv, 127 | translation=-(R_inv @ self.translation()), 128 | ) 129 | 130 | def normalize(self) -> SE3: 131 | return SE3.from_rotation_and_translation( 132 | rotation=self.rotation().normalize(), 133 | translation=self.translation(), 134 | ) 135 | 136 | def apply(self, target: np.ndarray) -> np.ndarray: 137 | assert target.shape == (SE3.space_dim,) 138 | return self.rotation() @ target + self.translation() 139 | 140 | def multiply(self, other: SE3) -> SE3: 141 | return SE3.from_rotation_and_translation( 142 | rotation=self.rotation() @ other.rotation(), 143 | translation=(self.rotation() @ other.translation()) + self.translation(), 144 | ) 145 | 146 | def __matmul__(self, other: Any) -> Any: 147 | if isinstance(other, np.ndarray): 148 | return self.apply(target=other) 149 | elif isinstance(other, SE3): 150 | return self.multiply(other=other) 151 | else: 152 | raise ValueError(f'Unsupported argument type for @ operator: {type(other)}') 153 | -------------------------------------------------------------------------------- /ogbench/manipspace/lie/so3.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import dataclass 4 | from typing import Any 5 | 6 | import mujoco 7 | import numpy as np 8 | 9 | from ogbench.manipspace.lie.utils import get_epsilon 10 | 11 | _IDENTITIY_WXYZ = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64) 12 | _INVERT_QUAT_SIGN = np.array([1.0, -1.0, -1.0, -1.0], dtype=np.float64) 13 | 14 | 15 | @dataclass(frozen=True) 16 | class RollPitchYaw: 17 | roll: float 18 | pitch: float 19 | yaw: float 20 | 21 | 22 | @dataclass(frozen=True) 23 | class SO3: 24 | """Special orthogonal group for 3D rotations. 25 | 26 | Internal parameterization is (qw, qx, qy, qz). 27 | Tangent parameterization is (omega_x, omega_y, omega_z). 
28 | """ 29 | 30 | wxyz: np.ndarray 31 | matrix_dim: int = 3 32 | parameters_dim: int = 4 33 | tangent_dim: int = 3 34 | space_dim: int = 3 35 | 36 | def __post_init__(self) -> None: 37 | if self.wxyz.shape != (self.parameters_dim,): 38 | raise ValueError(f'Expeced wxyz to be a length 4 vector but got {self.wxyz.shape[0]}.') 39 | 40 | def __repr__(self) -> str: 41 | wxyz = np.round(self.wxyz, 5) 42 | return f'{self.__class__.__name__}(wxyz={wxyz})' 43 | 44 | def copy(self) -> SO3: 45 | return SO3(wxyz=self.wxyz.copy()) 46 | 47 | @staticmethod 48 | def from_x_radians(theta: float) -> SO3: 49 | return SO3.exp(np.array([theta, 0.0, 0.0], dtype=np.float64)) 50 | 51 | @staticmethod 52 | def from_y_radians(theta: float) -> SO3: 53 | return SO3.exp(np.array([0.0, theta, 0.0], dtype=np.float64)) 54 | 55 | @staticmethod 56 | def from_z_radians(theta: float) -> SO3: 57 | return SO3.exp(np.array([0.0, 0.0, theta], dtype=np.float64)) 58 | 59 | @staticmethod 60 | def from_rpy_radians( 61 | roll: float, 62 | pitch: float, 63 | yaw: float, 64 | ) -> SO3: 65 | return SO3.from_z_radians(yaw) @ SO3.from_y_radians(pitch) @ SO3.from_x_radians(roll) 66 | 67 | @staticmethod 68 | def from_matrix(matrix: np.ndarray) -> SO3: 69 | assert matrix.shape == (SO3.matrix_dim, SO3.matrix_dim) 70 | wxyz = np.zeros(SO3.parameters_dim, dtype=np.float64) 71 | mujoco.mju_mat2Quat(wxyz, matrix.ravel()) 72 | return SO3(wxyz=wxyz) 73 | 74 | @staticmethod 75 | def identity() -> SO3: 76 | return SO3(wxyz=_IDENTITIY_WXYZ) 77 | 78 | @staticmethod 79 | def sample_uniform() -> SO3: 80 | u1, u2, u3 = np.random.uniform( 81 | low=np.zeros(shape=(3,)), 82 | high=np.array([1.0, 2.0 * np.pi, 2.0 * np.pi]), 83 | ) 84 | a = np.sqrt(1.0 - u1) 85 | b = np.sqrt(u1) 86 | wxyz = np.array( 87 | [ 88 | a * np.sin(u2), 89 | a * np.cos(u2), 90 | b * np.sin(u3), 91 | b * np.cos(u3), 92 | ], 93 | dtype=np.float64, 94 | ) 95 | return SO3(wxyz=wxyz) 96 | 97 | def as_matrix(self) -> np.ndarray: 98 | mat = np.zeros(9, dtype=np.float64) 99 | mujoco.mju_quat2Mat(mat, self.wxyz) 100 | return mat.reshape(3, 3) 101 | 102 | def compute_roll_radians(self) -> float: 103 | q0, q1, q2, q3 = self.wxyz 104 | return np.arctan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1**2 + q2**2)) 105 | 106 | def compute_pitch_radians(self) -> float: 107 | q0, q1, q2, q3 = self.wxyz 108 | return np.arcsin(2 * (q0 * q2 - q3 * q1)) 109 | 110 | def compute_yaw_radians(self) -> float: 111 | q0, q1, q2, q3 = self.wxyz 112 | return np.arctan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2**2 + q3**2)) 113 | 114 | def as_rpy_radians(self) -> RollPitchYaw: 115 | return RollPitchYaw( 116 | roll=self.compute_roll_radians(), 117 | pitch=self.compute_pitch_radians(), 118 | yaw=self.compute_yaw_radians(), 119 | ) 120 | 121 | @staticmethod 122 | def exp(tangent: np.ndarray) -> SO3: 123 | assert tangent.shape == (SO3.tangent_dim,) 124 | theta_squared = tangent @ tangent 125 | theta_pow_4 = theta_squared * theta_squared 126 | use_taylor = theta_squared < get_epsilon(tangent.dtype) 127 | safe_theta = 1.0 if use_taylor else np.sqrt(theta_squared) 128 | safe_half_theta = 0.5 * safe_theta 129 | if use_taylor: 130 | real = 1.0 - theta_squared / 8.0 + theta_pow_4 / 384.0 131 | imaginary = 0.5 - theta_squared / 48.0 + theta_pow_4 / 3840.0 132 | else: 133 | real = np.cos(safe_half_theta) 134 | imaginary = np.sin(safe_half_theta) / safe_theta 135 | wxyz = np.concatenate([np.array([real]), imaginary * tangent]) 136 | return SO3(wxyz=wxyz) 137 | 138 | def log(self) -> np.ndarray: 139 | w = self.wxyz[0] 140 | norm_sq = 
self.wxyz[1:] @ self.wxyz[1:] 141 | use_taylor = norm_sq < get_epsilon(norm_sq.dtype) 142 | norm_safe = 1.0 if use_taylor else np.sqrt(norm_sq) 143 | w_safe = w if use_taylor else 1.0 144 | atan_n_over_w = np.arctan2(-norm_safe if w < 0 else norm_safe, abs(w)) 145 | if use_taylor: 146 | atan_factor = 2.0 / w_safe - 2.0 / 3.0 * norm_sq / w_safe**3 147 | else: 148 | if abs(w) < get_epsilon(w.dtype): 149 | scl = 1.0 if w > 0.0 else -1.0 150 | atan_factor = scl * np.pi / norm_safe 151 | else: 152 | atan_factor = 2.0 * atan_n_over_w / norm_safe 153 | return atan_factor * self.wxyz[1:] 154 | 155 | def adjoint(self) -> np.ndarray: 156 | return self.as_matrix() 157 | 158 | def inverse(self) -> SO3: 159 | return SO3(wxyz=self.wxyz * _INVERT_QUAT_SIGN) 160 | 161 | def normalize(self) -> SO3: 162 | return SO3(wxyz=self.wxyz / np.linalg.norm(self.wxyz)) 163 | 164 | def apply(self, target: np.ndarray) -> np.ndarray: 165 | assert target.shape == (SO3.space_dim,) 166 | padded_target = np.concatenate([np.zeros(1, dtype=np.float64), target]) 167 | return (self @ SO3(wxyz=padded_target) @ self.inverse()).wxyz[1:] 168 | 169 | def multiply(self, other: SO3) -> SO3: 170 | w0, x0, y0, z0 = self.wxyz 171 | w1, x1, y1, z1 = other.wxyz 172 | return SO3( 173 | wxyz=np.array( 174 | [ 175 | -x0 * x1 - y0 * y1 - z0 * z1 + w0 * w1, 176 | x0 * w1 + y0 * z1 - z0 * y1 + w0 * x1, 177 | -x0 * z1 + y0 * w1 + z0 * x1 + w0 * y1, 178 | x0 * y1 - y0 * x1 + z0 * w1 + w0 * z1, 179 | ], 180 | dtype=np.float64, 181 | ) 182 | ) 183 | 184 | def __matmul__(self, other: Any) -> Any: 185 | if isinstance(other, np.ndarray): 186 | return self.apply(target=other) 187 | elif isinstance(other, SO3): 188 | return self.multiply(other=other) 189 | else: 190 | raise ValueError(f'Unsupported argument type for @ operator: {type(other)}') 191 | -------------------------------------------------------------------------------- /ogbench/manipspace/lie/utils.py: -------------------------------------------------------------------------------- 1 | import mujoco 2 | import numpy as np 3 | 4 | 5 | def get_epsilon(dtype: np.dtype) -> float: 6 | return { 7 | np.dtype('float32'): 1e-5, 8 | np.dtype('float64'): 1e-10, 9 | }[dtype] 10 | 11 | 12 | def skew(x: np.ndarray) -> np.ndarray: 13 | assert x.shape == (3,) 14 | wx, wy, wz = x 15 | return np.array( 16 | [ 17 | [0.0, -wz, wy], 18 | [wz, 0.0, -wx], 19 | [-wy, wx, 0.0], 20 | ] 21 | ) 22 | 23 | 24 | def mat2quat(mat: np.ndarray): 25 | """Convert a MuJoCo matrix (9,) to a quaternion (4,).""" 26 | assert mat.shape == (9,) 27 | quat = np.empty(4, dtype=np.float64) 28 | mujoco.mju_mat2Quat(quat, mat) 29 | return quat 30 | 31 | 32 | def interpolate(p0, p1, alpha=0.5): 33 | """Interpolate between two points on a manifold.""" 34 | assert 0.0 <= alpha <= 1.0 35 | exp_func = getattr(type(p0), 'exp') 36 | return p0 @ exp_func(alpha * (p0.inverse() @ p1).log()) 37 | -------------------------------------------------------------------------------- /ogbench/manipspace/mjcf_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any 3 | 4 | import numpy as np 5 | from dm_control import mjcf 6 | from lxml import etree 7 | 8 | 9 | def attach( 10 | parent_xml_or_model: Any, 11 | child_xml_or_model: Any, 12 | attach_site: Any = None, 13 | remove_keyframes: bool = True, 14 | add_freejoint: bool = False, 15 | ) -> mjcf.Element: 16 | if isinstance(parent_xml_or_model, Path): 17 | assert parent_xml_or_model.exists() 18 | parent = 
mjcf.from_path(parent_xml_or_model.as_posix()) 19 | else: 20 | assert isinstance(parent_xml_or_model, mjcf.RootElement) 21 | parent = parent_xml_or_model 22 | 23 | if isinstance(child_xml_or_model, Path): 24 | assert child_xml_or_model.exists() 25 | child = mjcf.from_path(child_xml_or_model.as_posix()) 26 | else: 27 | assert isinstance(child_xml_or_model, mjcf.RootElement) 28 | child = child_xml_or_model 29 | 30 | if attach_site is not None: 31 | if isinstance(attach_site, str): 32 | attachment_site = parent.find('site', attach_site) 33 | assert attachment_site is not None 34 | else: 35 | assert isinstance(attach_site, mjcf.Element) 36 | attachment_site = attach_site 37 | frame = attachment_site.attach(child) 38 | else: 39 | frame = parent.attach(child) 40 | if add_freejoint: 41 | frame.add('freejoint') 42 | 43 | if remove_keyframes: 44 | keyframes = parent.find_all('key') 45 | if keyframes is not None: 46 | for key in keyframes: 47 | key.remove() 48 | 49 | return frame 50 | 51 | 52 | def to_string( 53 | root: mjcf.RootElement, 54 | precision: float = 17, 55 | zero_threshold: float = 0.0, 56 | pretty: bool = False, 57 | ) -> str: 58 | xml_string = root.to_xml_string(precision=precision, zero_threshold=zero_threshold) 59 | root = etree.XML(xml_string, etree.XMLParser(remove_blank_text=True)) 60 | 61 | # Remove hashes from asset filenames. 62 | tags = ['mesh', 'texture'] 63 | for tag in tags: 64 | assets = [asset for asset in root.find('asset').iter() if asset.tag == tag and 'file' in asset.attrib] 65 | for asset in assets: 66 | name, extension = asset.get('file').split('.') 67 | asset.set('file', '.'.join((name[:-41], extension))) # Remove hash. 68 | 69 | if not pretty: 70 | return etree.tostring(root, pretty_print=True).decode() 71 | 72 | # Remove auto-generated names. 73 | for elem in root.iter(): 74 | for key in elem.keys(): 75 | if key == 'name' and 'unnamed' in elem.get(key): 76 | elem.attrib.pop(key) 77 | 78 | # Get string from lxml. 79 | xml_string = etree.tostring(root, pretty_print=True) 80 | 81 | # Remove redundant attributes. 82 | xml_string = xml_string.replace(b' gravcomp="0"', b'') 83 | 84 | # Insert spaces between top-level elements. 85 | lines = xml_string.splitlines() 86 | newlines = [] 87 | for line in lines: 88 | newlines.append(line) 89 | if line.startswith(b' <'): 90 | if line.startswith(b' '): 91 | newlines.append(b'') 92 | newlines.append(b'') 93 | xml_string = b'\n'.join(newlines) 94 | 95 | return xml_string.decode() 96 | 97 | 98 | def get_assets(root: mjcf.RootElement) -> dict: 99 | assets = {} 100 | for file, payload in root.get_assets().items(): 101 | name, extension = file.split('.') 102 | assets['.'.join((name[:-41], extension))] = payload # Remove hash. 
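# As in to_string() above, the 41 stripped characters are the separator plus
# the 40-character content hash that dm_control appends to exported asset
# names; removing them restores the original filenames.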
103 | return assets 104 | 105 | 106 | def safe_find_all(root: mjcf.RootElement, namespace: str, *args, **kwargs): 107 | """Find all given elements or throw an error if none are found.""" 108 | features = root.find_all(namespace, *args, **kwargs) 109 | if not features: 110 | raise ValueError(f'{namespace} not found in the MJCF model.') 111 | return features 112 | 113 | 114 | def safe_find(root: mjcf.RootElement, namespace: str, identifier: str): 115 | """Find the given element or throw an error if not found.""" 116 | feature = root.find(namespace, identifier) 117 | if feature is None: 118 | raise ValueError(f'{namespace} {identifier} not found.') 119 | return feature 120 | 121 | 122 | def add_bounding_box_site(body: mjcf.Element, lower: np.ndarray, upper: np.ndarray, **kwargs) -> mjcf.Element: 123 | """Visualize a bounding box as a box site attached to the given body.""" 124 | pos = (lower + upper) / 2 125 | size = (upper - lower) / 2 126 | size += 1e-7 127 | return body.add('site', type='box', pos=pos, size=size, **kwargs) 128 | -------------------------------------------------------------------------------- /ogbench/manipspace/oracles/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/oracles/__init__.py -------------------------------------------------------------------------------- /ogbench/manipspace/oracles/markov/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/oracles/markov/__init__.py -------------------------------------------------------------------------------- /ogbench/manipspace/oracles/markov/button_markov.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ogbench.manipspace.oracles.markov.markov_oracle import MarkovOracle 4 | 5 | 6 | class ButtonMarkovOracle(MarkovOracle): 7 | def __init__(self, max_step=100, gripper_always_closed=False, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | self._max_step = max_step 10 | self._gripper_always_closed = gripper_always_closed 11 | 12 | def reset(self, ob, info): 13 | self._done = False 14 | self._step = 0 15 | self._final_pos = np.random.uniform(*self._env.unwrapped._arm_sampling_bounds) 16 | self._final_yaw = np.random.uniform(-np.pi, np.pi) 17 | 18 | def select_action(self, ob, info): 19 | effector_pos = info['proprio/effector_pos'] 20 | effector_yaw = info['proprio/effector_yaw'][0] 21 | 22 | target_button = info['privileged/target_button'] 23 | button_target_top_pos = info['privileged/target_button_top_pos'] + np.array([0, 0, 0.06]) 24 | button_target_bottom_pos = info['privileged/target_button_top_pos'] - np.array([0, 0, 0.022]) 25 | button_state = info[f'privileged/button_{target_button}_state'] 26 | target_state = info['privileged/target_button_state'] 27 | 28 | above_threshold = 0.16 29 | above = effector_pos[2] > above_threshold 30 | xy_aligned = np.linalg.norm(button_target_top_pos[:2] - effector_pos[:2]) <= 0.04 31 | target_achieved = button_state == target_state 32 | final_pos_aligned = np.linalg.norm(self._final_pos - effector_pos) <= 0.04 33 | 34 | gain_pos = 5 35 | gain_yaw = 3 36 | action = np.zeros(5) 37 | if not target_achieved: 38 | if not xy_aligned: 39 | self.print_phase('1: Move above the button') 40 | action = 
np.zeros(5) 41 | diff = button_target_top_pos - effector_pos 42 | diff = self.shape_diff(diff) 43 | action[:3] = diff[:3] * gain_pos 44 | action[4] = 1 45 | else: 46 | self.print_phase('2: Press the button') 47 | action = np.zeros(5) 48 | diff = button_target_bottom_pos - effector_pos 49 | diff = self.shape_diff(diff) 50 | action[:3] = diff[:3] * gain_pos 51 | action[4] = 1 52 | else: 53 | if not above: 54 | self.print_phase('3: Release the button') 55 | diff = ( 56 | np.array([button_target_top_pos[0], button_target_top_pos[1], above_threshold * 2]) - effector_pos 57 | ) 58 | diff = self.shape_diff(diff) 59 | action[:3] = diff[:3] * gain_pos 60 | action[3] = (self._final_yaw - effector_yaw) * gain_yaw 61 | action[4] = 1 if self._gripper_always_closed else -1 62 | else: 63 | self.print_phase('4: Move to the final position') 64 | diff = self._final_pos - effector_pos 65 | diff = self.shape_diff(diff) 66 | action[:3] = diff[:3] * gain_pos 67 | action[3] = (self._final_yaw - effector_yaw) * gain_yaw 68 | action[4] = 1 if self._gripper_always_closed else -1 69 | 70 | if final_pos_aligned: 71 | self._done = True 72 | 73 | action = np.clip(action, -1, 1) 74 | if self._debug: 75 | print(action) 76 | 77 | self._step += 1 78 | if self._step == self._max_step: 79 | self._done = True 80 | 81 | return action 82 | -------------------------------------------------------------------------------- /ogbench/manipspace/oracles/markov/cube_markov.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ogbench.manipspace.oracles.markov.markov_oracle import MarkovOracle 4 | 5 | 6 | class CubeMarkovOracle(MarkovOracle): 7 | def __init__(self, max_step=200, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | self._max_step = max_step 10 | 11 | def reset(self, ob, info): 12 | self._done = False 13 | self._step = 0 14 | self._max_step = 200 15 | self._final_pos = np.random.uniform(*self._env.unwrapped._arm_sampling_bounds) 16 | self._final_yaw = np.random.uniform(-np.pi, np.pi) 17 | 18 | def select_action(self, ob, info): 19 | effector_pos = info['proprio/effector_pos'] 20 | effector_yaw = info['proprio/effector_yaw'][0] 21 | gripper_opening = info['proprio/gripper_opening'] 22 | 23 | target_block = info['privileged/target_block'] 24 | block_pos = info[f'privileged/block_{target_block}_pos'] 25 | block_yaw = self.shortest_yaw(effector_yaw, info[f'privileged/block_{target_block}_yaw'][0]) 26 | target_pos = info['privileged/target_block_pos'] 27 | target_yaw = self.shortest_yaw(effector_yaw, info['privileged/target_block_yaw'][0]) 28 | 29 | block_above_offset = np.array([0, 0, 0.18]) 30 | above_threshold = 0.16 31 | gripper_closed = info['proprio/gripper_contact'] > 0.5 32 | gripper_open = info['proprio/gripper_contact'] < 0.1 33 | above = effector_pos[2] > above_threshold 34 | xy_aligned = np.linalg.norm(block_pos[:2] - effector_pos[:2]) <= 0.04 35 | pos_aligned = np.linalg.norm(block_pos - effector_pos) <= 0.02 36 | target_xy_aligned = np.linalg.norm(target_pos[:2] - block_pos[:2]) <= 0.04 37 | target_pos_aligned = np.linalg.norm(target_pos - block_pos) <= 0.02 38 | final_pos_aligned = np.linalg.norm(self._final_pos - effector_pos) <= 0.04 39 | 40 | gain_pos = 5 41 | gain_yaw = 3 42 | action = np.zeros(5) 43 | if not target_pos_aligned: 44 | if not xy_aligned: 45 | self.print_phase('1: Move above the block') 46 | action = np.zeros(5) 47 | diff = block_pos + block_above_offset - effector_pos 48 | diff = self.shape_diff(diff) 49 | action[:3] = 
diff[:3] * gain_pos 50 | action[3] = (block_yaw - effector_yaw) * gain_yaw 51 | action[4] = -1 52 | elif not pos_aligned: 53 | self.print_phase('2: Move to the block') 54 | diff = block_pos - effector_pos 55 | diff = self.shape_diff(diff) 56 | action[:3] = diff[:3] * gain_pos 57 | action[3] = (block_yaw - effector_yaw) * gain_yaw 58 | action[4] = -1 59 | elif pos_aligned and not gripper_closed: 60 | self.print_phase('3: Grasp') 61 | diff = block_pos - effector_pos 62 | diff = self.shape_diff(diff) 63 | action[:3] = diff[:3] * gain_pos 64 | action[3] = (block_yaw - effector_yaw) * gain_yaw 65 | action[4] = 1 66 | elif pos_aligned and gripper_closed and not above and not target_xy_aligned: 67 | self.print_phase('4: Move in the air') 68 | diff = np.array([block_pos[0], block_pos[1], block_above_offset[2] * 2]) - effector_pos 69 | diff = self.shape_diff(diff) 70 | action[:3] = diff[:3] * gain_pos 71 | action[3] = (target_yaw - block_yaw) * gain_yaw 72 | action[4] = 1 73 | elif pos_aligned and gripper_closed and above and not target_xy_aligned: 74 | self.print_phase('5: Move above the target') 75 | diff = target_pos + block_above_offset - effector_pos 76 | diff = self.shape_diff(diff) 77 | action[:3] = diff[:3] * gain_pos 78 | action[3] = (target_yaw - block_yaw) * gain_yaw 79 | action[4] = 1 80 | else: 81 | self.print_phase('6: Move to the target') 82 | diff = target_pos - effector_pos 83 | diff = self.shape_diff(diff) 84 | action[:3] = diff[:3] * gain_pos 85 | action[3] = (target_yaw - block_yaw) * gain_yaw 86 | action[4] = 1 87 | else: 88 | if not gripper_open: 89 | self.print_phase('7: Release') 90 | diff = target_pos - effector_pos 91 | diff = self.shape_diff(diff) 92 | action[:3] = diff[:3] * gain_pos 93 | action[3] = (target_yaw - block_yaw) * gain_yaw 94 | action[4] = -1 95 | elif gripper_open and not above: 96 | self.print_phase('8: Move in the air') 97 | diff = np.array([block_pos[0], block_pos[1], above_threshold * 2]) - effector_pos 98 | diff = self.shape_diff(diff) 99 | action[:3] = diff[:3] * gain_pos 100 | action[3] = (self._final_yaw - effector_yaw) * gain_yaw 101 | action[4] = -1 102 | else: 103 | self.print_phase('9: Move to the final position') 104 | diff = self._final_pos - effector_pos 105 | diff = self.shape_diff(diff) 106 | action[:3] = diff[:3] * gain_pos 107 | action[3] = (self._final_yaw - effector_yaw) * gain_yaw 108 | action[4] = -1 109 | 110 | if final_pos_aligned: 111 | self._done = True 112 | 113 | action = np.clip(action, -1, 1) 114 | if self._debug: 115 | print(action) 116 | 117 | self._step += 1 118 | if self._step == self._max_step: 119 | self._done = True 120 | 121 | return action 122 | -------------------------------------------------------------------------------- /ogbench/manipspace/oracles/markov/drawer_markov.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ogbench.manipspace.oracles.markov.markov_oracle import MarkovOracle 4 | 5 | 6 | class DrawerMarkovOracle(MarkovOracle): 7 | def __init__(self, max_step=75, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | self._max_step = max_step 10 | 11 | def reset(self, ob, info): 12 | self._done = False 13 | self._step = 0 14 | self._final_pos = np.random.uniform(*self._env.unwrapped._arm_sampling_bounds) 15 | self._final_yaw = np.random.uniform(-np.pi, np.pi) 16 | 17 | def select_action(self, ob, info): 18 | effector_pos = info['proprio/effector_pos'] 19 | effector_yaw = info['proprio/effector_yaw'][0] 20 | gripper_opening = 
info['proprio/gripper_opening'] 21 | 22 | drawer_pos = info['privileged/drawer_handle_pos'] 23 | drawer_yaw = self.shortest_yaw(effector_yaw, info['privileged/drawer_handle_yaw'][0], n=2) 24 | target_pos = info['privileged/target_drawer_handle_pos'] 25 | 26 | drawer_above_offset = np.array([0, 0, 0.12]) 27 | above_threshold = 0.18 28 | above = effector_pos[2] > above_threshold 29 | xy_aligned = np.linalg.norm(drawer_pos[:2] - effector_pos[:2]) <= 0.04 30 | pos_aligned = np.linalg.norm(drawer_pos - effector_pos) <= 0.03 31 | target_pos_aligned = np.linalg.norm(target_pos - drawer_pos) <= 0.01 32 | final_pos_aligned = np.linalg.norm(self._final_pos - effector_pos) <= 0.04 33 | 34 | gain_pos = 5 35 | gain_yaw = 3 36 | action = np.zeros(5) 37 | if not target_pos_aligned: 38 | if not xy_aligned: 39 | self.print_phase('1: Move above the drawer handle') 40 | action = np.zeros(5) 41 | diff = drawer_pos + drawer_above_offset - effector_pos 42 | diff = self.shape_diff(diff) 43 | action[:3] = diff[:3] * gain_pos 44 | action[3] = (drawer_yaw - effector_yaw) * gain_yaw 45 | action[4] = -1 46 | elif not pos_aligned: 47 | self.print_phase('2: Move to the drawer handle') 48 | diff = drawer_pos - effector_pos 49 | diff = self.shape_diff(diff) 50 | action[:3] = diff[:3] * gain_pos 51 | action[3] = (drawer_yaw - effector_yaw) * gain_yaw 52 | action[4] = -1 53 | else: 54 | self.print_phase('3: Move to the target') 55 | diff = target_pos - effector_pos 56 | diff = self.shape_diff(diff) 57 | action[:3] = diff[:3] * gain_pos 58 | action[3] = (drawer_yaw - effector_yaw) * gain_yaw 59 | action[4] = -1 60 | else: 61 | if not above: 62 | self.print_phase('4: Move in the air') 63 | diff = ( 64 | np.array( 65 | [ 66 | drawer_pos[0], 67 | drawer_pos[1], 68 | above_threshold * 2, 69 | ] 70 | ) 71 | - effector_pos 72 | ) 73 | diff = self.shape_diff(diff) 74 | action[:3] = diff[:3] * gain_pos 75 | action[3] = (self._final_yaw - effector_yaw) * gain_yaw 76 | action[4] = -1 77 | else: 78 | self.print_phase('5: Move to the final position') 79 | diff = self._final_pos - effector_pos 80 | diff = self.shape_diff(diff) 81 | action[:3] = diff[:3] * gain_pos 82 | action[3] = (self._final_yaw - effector_yaw) * gain_yaw 83 | action[4] = -1 84 | 85 | if final_pos_aligned: 86 | self._done = True 87 | 88 | action = np.clip(action, -1, 1) 89 | if self._debug: 90 | print(action) 91 | 92 | self._step += 1 93 | if self._step == self._max_step: 94 | self._done = True 95 | 96 | return action 97 | -------------------------------------------------------------------------------- /ogbench/manipspace/oracles/markov/markov_oracle.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class MarkovOracle: 5 | """Markovian oracle for manipulation tasks.""" 6 | 7 | def __init__(self, env, min_norm=0.4): 8 | """Initialize the oracle. 9 | 10 | Args: 11 | env: Environment. 12 | min_norm: Minimum norm for the relative position. Setting it to a non-zero value can help the agent to learn 13 | more robust policies. 14 | """ 15 | self._env = env 16 | self._min_norm = min_norm 17 | self._debug = False # Set to True to print debug information. 
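# Subclasses set this flag to True once the effector settles at its sampled
# final pose or once their step budget (_max_step) runs out; it is exposed
# read-only through the `done` property below.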
18 | self._done = False 19 | 20 | if self._debug: 21 | np.set_printoptions(suppress=True) 22 | 23 | def shape_diff(self, diff): 24 | """Shape the difference vector to have a minimum norm.""" 25 | diff_norm = np.linalg.norm(diff) 26 | if diff_norm >= self._min_norm: 27 | return diff 28 | else: 29 | return diff / (diff_norm + 1e-6) * self._min_norm 30 | 31 | def shortest_yaw(self, eff_yaw, obj_yaw, n=4): 32 | """Find the symmetry-aware shortest yaw angle to the object.""" 33 | symmetries = np.array([i * 2 * np.pi / n + obj_yaw for i in range(-n, n + 1)]) 34 | d = np.argmin(np.abs(eff_yaw - symmetries)) 35 | return symmetries[d] 36 | 37 | def print_phase(self, phase): 38 | """Print the current phase.""" 39 | if self._debug: 40 | print(f'Phase {phase:50}', end=' ') 41 | 42 | @property 43 | def done(self): 44 | return self._done 45 | 46 | def reset(self, ob, info): 47 | pass 48 | 49 | def select_action(self, ob, info): 50 | pass 51 | -------------------------------------------------------------------------------- /ogbench/manipspace/oracles/markov/window_markov.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ogbench.manipspace.oracles.markov.markov_oracle import MarkovOracle 4 | 5 | 6 | class WindowMarkovOracle(MarkovOracle): 7 | def __init__(self, max_step=75, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | self._max_step = max_step 10 | 11 | def reset(self, ob, info): 12 | self._done = False 13 | self._step = 0 14 | arm_sampling_bounds = self._env.unwrapped._arm_sampling_bounds.copy() 15 | arm_sampling_bounds[0, 2] = max(arm_sampling_bounds[0, 2], 0.3) 16 | self._final_pos = np.random.uniform(*arm_sampling_bounds) 17 | self._final_yaw = np.random.uniform(-np.pi, np.pi) 18 | 19 | def select_action(self, ob, info): 20 | effector_pos = info['proprio/effector_pos'] 21 | effector_yaw = info['proprio/effector_yaw'][0] 22 | gripper_opening = info['proprio/gripper_opening'] 23 | 24 | window_pos = info['privileged/window_handle_pos'] 25 | window_yaw = self.shortest_yaw(effector_yaw, info['privileged/window_handle_yaw'][0], n=2) 26 | target_pos = info['privileged/target_window_handle_pos'] 27 | 28 | window_above_offset = np.array([0, 0, 0.06]) 29 | window_handle_offset = np.array([0, 0, 0]) 30 | above_threshold = 0.28 31 | above = effector_pos[2] > above_threshold 32 | xy_aligned = np.linalg.norm(window_pos[:2] + window_handle_offset[:2] - effector_pos[:2]) <= 0.04 33 | pos_aligned = np.linalg.norm(window_pos + window_handle_offset - effector_pos) <= 0.03 34 | target_pos_aligned = np.linalg.norm(target_pos - window_pos) <= 0.01 35 | final_pos_aligned = np.linalg.norm(self._final_pos - effector_pos) <= 0.04 36 | 37 | gain_pos = 5 38 | gain_yaw = 3 39 | action = np.zeros(5) 40 | if not target_pos_aligned: 41 | if not xy_aligned: 42 | self.print_phase('1: Move above the window handle') 43 | action = np.zeros(5) 44 | diff = window_pos + window_handle_offset + window_above_offset - effector_pos 45 | diff = self.shape_diff(diff) 46 | action[:3] = diff[:3] * gain_pos 47 | action[3] = (window_yaw - effector_yaw) * gain_yaw 48 | action[4] = -1 49 | elif not pos_aligned: 50 | self.print_phase('2: Move to the window handle') 51 | diff = window_pos + window_handle_offset - effector_pos 52 | diff = self.shape_diff(diff) 53 | action[:3] = diff[:3] * gain_pos 54 | action[3] = (window_yaw - effector_yaw) * gain_yaw 55 | action[4] = -1 56 | else: 57 | self.print_phase('3: Move to the target') 58 | diff = target_pos + 
window_handle_offset - effector_pos 59 | diff = self.shape_diff(diff) 60 | action[:3] = diff[:3] * gain_pos 61 | action[3] = (window_yaw - effector_yaw) * gain_yaw 62 | action[4] = 1 63 | else: 64 | if not above: 65 | self.print_phase('4: Move in the air') 66 | diff = ( 67 | np.array( 68 | [ 69 | window_pos[0], 70 | window_pos[1], 71 | above_threshold * 2, 72 | ] 73 | ) 74 | - effector_pos 75 | ) 76 | diff = self.shape_diff(diff) 77 | action[:3] = diff[:3] * gain_pos 78 | action[3] = (self._final_yaw - effector_yaw) * gain_yaw 79 | action[4] = -1 80 | else: 81 | self.print_phase('5: Move to the final position') 82 | diff = self._final_pos - effector_pos 83 | diff = self.shape_diff(diff) 84 | action[:3] = diff[:3] * gain_pos 85 | action[3] = (self._final_yaw - effector_yaw) * gain_yaw 86 | action[4] = -1 87 | 88 | if final_pos_aligned: 89 | self._done = True 90 | 91 | action = np.clip(action, -1, 1) 92 | if self._debug: 93 | print(action) 94 | 95 | self._step += 1 96 | if self._step == self._max_step: 97 | self._done = True 98 | 99 | return action 100 | -------------------------------------------------------------------------------- /ogbench/manipspace/oracles/plan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/oracles/plan/__init__.py -------------------------------------------------------------------------------- /ogbench/manipspace/oracles/plan/button_plan.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ogbench.manipspace.oracles.plan.plan_oracle import PlanOracle 4 | 5 | 6 | class ButtonPlanOracle(PlanOracle): 7 | def __init__(self, gripper_always_closed=False, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | self._gripper_always_closed = gripper_always_closed 10 | 11 | def compute_keyframes(self, plan_input): 12 | # Poses. 13 | poses = {} 14 | poses['initial'] = plan_input['effector_initial'] 15 | poses['press_start'] = self.above(plan_input['button'], 0.06) 16 | poses['press'] = self.above(plan_input['button'], -0.025) 17 | poses['press_end'] = poses['press_start'] 18 | poses['final'] = plan_input['effector_goal'] 19 | 20 | # Times. 21 | times = {} 22 | distance = np.linalg.norm(poses['initial'].translation() - poses['press_start'].translation()) 23 | times['initial'] = 0.0 24 | times['press_start'] = times['initial'] + self._dt * (0.5 + distance * 4) 25 | times['press'] = times['press_start'] + self._dt * 0.8 26 | times['press_end'] = times['press'] + self._dt * 0.8 27 | times['final'] = times['press_end'] + self._dt * 1.25 28 | for time in times.keys(): 29 | if time != 'initial': 30 | times[time] += np.random.uniform(-1, 1) * self._dt * 0.1 31 | 32 | # Grasps. 
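        # Per-keyframe gripper targets (1.0 = closed, 0.0 = open, inferred from
        # the `gripper_always_closed` case above). Unless the gripper stays
        # closed, it closes at 'press_start' so the button is pressed with a
        # closed gripper and reopens at 'final'; compute_plan() interpolates
        # these values linearly between keyframes.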
33 | grasps = {} 34 | if self._gripper_always_closed: 35 | g = 1.0 36 | else: 37 | g = 0.0 38 | for name in times.keys(): 39 | if not self._gripper_always_closed: 40 | if name in {'press_start', 'final'}: 41 | g = 1.0 - g 42 | grasps[name] = g 43 | 44 | return times, poses, grasps 45 | 46 | def reset(self, ob, info): 47 | plan_input = { 48 | 'effector_initial': self.to_pose( 49 | pos=info['proprio/effector_pos'], 50 | yaw=info['proprio/effector_yaw'][0], 51 | ), 52 | 'effector_goal': self.to_pose( 53 | pos=np.random.uniform(*self._env.unwrapped._arm_sampling_bounds), 54 | yaw=np.random.uniform(-np.pi, np.pi), 55 | ), 56 | 'button': self.to_pose( 57 | pos=info['privileged/target_button_top_pos'], 58 | yaw=info['privileged/target_button_top_pos'][0], 59 | ), 60 | } 61 | 62 | times, poses, grasps = self.compute_keyframes(plan_input) 63 | poses = [poses[name] for name in times.keys()] 64 | grasps = [grasps[name] for name in times.keys()] 65 | times = list(times.values()) 66 | 67 | self._t_init = info['time'][0] 68 | self._t_max = times[-1] 69 | self._done = False 70 | self._plan = self.compute_plan(times, poses, grasps) 71 | -------------------------------------------------------------------------------- /ogbench/manipspace/oracles/plan/cube_plan.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ogbench.manipspace import lie 4 | from ogbench.manipspace.oracles.plan.plan_oracle import PlanOracle 5 | 6 | 7 | class CubePlanOracle(PlanOracle): 8 | def __init__( 9 | self, 10 | *args, 11 | **kwargs, 12 | ): 13 | super().__init__(*args, **kwargs) 14 | 15 | def compute_keyframes(self, plan_input): 16 | # Poses. 17 | poses = {} 18 | 19 | # Pick. 20 | block_initial = self.shortest_yaw( 21 | eff_yaw=self.get_yaw(plan_input['effector_initial']), 22 | obj_yaw=self.get_yaw(plan_input['block_initial']), 23 | translation=plan_input['block_initial'].translation(), 24 | ) 25 | poses['initial'] = plan_input['effector_initial'] 26 | poses['pick'] = self.above(block_initial, 0.1 + np.random.uniform(0, 0.1)) 27 | poses['pick_start'] = block_initial 28 | poses['pick_end'] = block_initial 29 | poses['postpick'] = poses['pick'] 30 | 31 | # Place. 32 | block_goal = self.shortest_yaw( 33 | eff_yaw=self.get_yaw(poses['postpick']), 34 | obj_yaw=self.get_yaw(plan_input['block_goal']), 35 | translation=plan_input['block_goal'].translation(), 36 | ) 37 | poses['place'] = self.above(block_goal, 0.1 + np.random.uniform(0, 0.1)) 38 | poses['place_start'] = block_goal 39 | poses['place_end'] = block_goal 40 | poses['postplace'] = poses['place'] 41 | poses['final'] = plan_input['effector_goal'] 42 | 43 | # Clearance. 44 | midway = lie.interpolate(poses['postpick'], poses['place']) 45 | poses['clearance'] = lie.SE3.from_rotation_and_translation( 46 | rotation=midway.rotation(), 47 | translation=np.array([*midway.translation()[:2], poses['initial'].translation()[-1]]) 48 | + np.random.uniform([-0.1, -0.1, 0], [0.1, 0.1, 0.2]), 49 | ) 50 | 51 | # Times. 
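        # Keyframe times advance by one segment duration (self._dt) per step,
        # with 1.5x segments for the slower descents onto the block and the
        # goal; all times except 'initial' then get a small uniform jitter so
        # generated plans are not identical across episodes.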
52 | times = {} 53 | times['initial'] = 0.0 54 | times['pick'] = times['initial'] + self._dt 55 | times['pick_start'] = times['pick'] + self._dt * 1.5 56 | times['pick_end'] = times['pick_start'] + self._dt 57 | times['postpick'] = times['pick_end'] + self._dt 58 | times['clearance'] = times['postpick'] + self._dt 59 | times['place'] = times['clearance'] + self._dt 60 | times['place_start'] = times['place'] + self._dt * 1.5 61 | times['place_end'] = times['place_start'] + self._dt 62 | times['postplace'] = times['place_end'] + self._dt 63 | times['final'] = times['postplace'] + self._dt 64 | for time in times.keys(): 65 | if time != 'initial': 66 | times[time] += np.random.uniform(-1, 1) * self._dt * 0.2 67 | 68 | # Grasps. 69 | g = 0.0 70 | grasps = {} 71 | for name in times.keys(): 72 | if name in {'pick_end', 'place_end'}: 73 | g = 1.0 - g 74 | grasps[name] = g 75 | 76 | return times, poses, grasps 77 | 78 | def reset(self, ob, info): 79 | target_block = info['privileged/target_block'] 80 | plan_input = { 81 | 'effector_initial': self.to_pose( 82 | pos=info['proprio/effector_pos'], 83 | yaw=info['proprio/effector_yaw'][0], 84 | ), 85 | 'effector_goal': self.to_pose( 86 | pos=np.random.uniform(*self._env.unwrapped._arm_sampling_bounds), 87 | yaw=np.random.uniform(-np.pi, np.pi), 88 | ), 89 | 'block_initial': self.to_pose( 90 | pos=info[f'privileged/block_{target_block}_pos'], 91 | yaw=info[f'privileged/block_{target_block}_yaw'][0], 92 | ), 93 | 'block_goal': self.to_pose( 94 | pos=info['privileged/target_block_pos'], 95 | yaw=info['privileged/target_block_yaw'][0], 96 | ), 97 | } 98 | 99 | times, poses, grasps = self.compute_keyframes(plan_input) 100 | poses = [poses[name] for name in times.keys()] 101 | grasps = [grasps[name] for name in times.keys()] 102 | times = list(times.values()) 103 | 104 | self._t_init = info['time'][0] 105 | self._t_max = times[-1] 106 | self._done = False 107 | self._plan = self.compute_plan(times, poses, grasps) 108 | -------------------------------------------------------------------------------- /ogbench/manipspace/oracles/plan/drawer_plan.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ogbench.manipspace.oracles.plan.plan_oracle import PlanOracle 4 | 5 | 6 | class DrawerPlanOracle(PlanOracle): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | 10 | def compute_keyframes(self, plan_input): 11 | # Poses. 12 | poses = {} 13 | drawer_initial = self.shortest_yaw( 14 | eff_yaw=self.get_yaw(plan_input['effector_initial']), 15 | obj_yaw=self.get_yaw(plan_input['drawer_initial']), 16 | translation=plan_input['drawer_initial'].translation(), 17 | n=2, 18 | ) 19 | drawer_goal = self.shortest_yaw( 20 | eff_yaw=self.get_yaw(plan_input['effector_initial']), 21 | obj_yaw=self.get_yaw(plan_input['drawer_initial']), 22 | translation=plan_input['drawer_goal'].translation(), 23 | n=2, 24 | ) 25 | poses['initial'] = plan_input['effector_initial'] 26 | poses['approach'] = self.above(drawer_initial, 0.12) 27 | poses['grasp_start'] = drawer_initial 28 | poses['grasp_end'] = drawer_initial 29 | poses['move'] = drawer_goal 30 | poses['release'] = drawer_goal 31 | poses['clearance'] = self.above(drawer_goal, 0.12) 32 | poses['final'] = plan_input['effector_goal'] 33 | 34 | # Times. 
35 | times = {} 36 | times['initial'] = 0.0 37 | times['approach'] = times['initial'] + self._dt 38 | times['grasp_start'] = times['approach'] + self._dt * 0.5 39 | times['grasp_end'] = times['grasp_start'] + self._dt * 0.5 40 | times['move'] = times['grasp_end'] + self._dt * 0.5 41 | times['release'] = times['move'] + self._dt * 0.5 42 | times['clearance'] = times['release'] + self._dt * 0.5 43 | times['final'] = times['clearance'] + self._dt 44 | for time in times.keys(): 45 | if time != 'initial': 46 | times[time] += np.random.uniform(-1, 1) * self._dt * 0.1 47 | 48 | # Grasps. 49 | grasps = {} 50 | g = 0.0 51 | for name in times.keys(): 52 | if name in {'grasp_end', 'release'}: 53 | g = 1.0 - g 54 | grasps[name] = g 55 | 56 | return times, poses, grasps 57 | 58 | def reset(self, ob, info): 59 | plan_input = { 60 | 'effector_initial': self.to_pose( 61 | pos=info['proprio/effector_pos'], 62 | yaw=info['proprio/effector_yaw'][0], 63 | ), 64 | 'effector_goal': self.to_pose( 65 | pos=np.random.uniform(*self._env.unwrapped._arm_sampling_bounds), 66 | yaw=np.random.uniform(-np.pi, np.pi), 67 | ), 68 | 'drawer_initial': self.to_pose( 69 | pos=info['privileged/drawer_handle_pos'], 70 | yaw=info['privileged/drawer_handle_yaw'][0], 71 | ), 72 | 'drawer_goal': self.to_pose( 73 | pos=info['privileged/target_drawer_handle_pos'], 74 | yaw=info['privileged/drawer_handle_yaw'][0], 75 | ), 76 | } 77 | 78 | times, poses, grasps = self.compute_keyframes(plan_input) 79 | poses = [poses[name] for name in times.keys()] 80 | grasps = [grasps[name] for name in times.keys()] 81 | times = list(times.values()) 82 | 83 | self._t_init = info['time'][0] 84 | self._t_max = times[-1] 85 | self._done = False 86 | self._plan = self.compute_plan(times, poses, grasps) 87 | -------------------------------------------------------------------------------- /ogbench/manipspace/oracles/plan/plan_oracle.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.interpolate import interp1d 3 | from scipy.ndimage import gaussian_filter1d 4 | 5 | from ogbench.manipspace import lie 6 | 7 | 8 | class PlanOracle: 9 | """Non-Markovian oracle that follows a pre-computed plan. 10 | 11 | It first generates a plan by interpolating the keyframes of the task and adds temporally correlated noise. Then, it 12 | computes the actions by computing the difference between the current state and the next state in the plan. 13 | """ 14 | 15 | def __init__(self, env, segment_dt=0.4, noise=0.1, noise_smoothing=0.5): 16 | """Initialize the oracle. 17 | 18 | Args: 19 | env: Environment. 20 | segment_dt: Default duration of each segment between keyframes in the plan. 21 | noise: Noise level to add to the plan. 22 | noise_smoothing: Noise smoothing level. 
23 | """ 24 | self._env = env 25 | self._env_dt = self._env.unwrapped._control_timestep 26 | self._dt = segment_dt 27 | self._noise = noise 28 | self._noise_smoothing = noise_smoothing 29 | 30 | self._done = False 31 | self._t_init = None 32 | self._t_max = None 33 | self._plan = None 34 | 35 | def above(self, pose, z): 36 | return ( 37 | lie.SE3.from_rotation_and_translation( 38 | rotation=lie.SO3.identity(), 39 | translation=np.array([0.0, 0.0, z]), 40 | ) 41 | @ pose 42 | ) 43 | 44 | def to_pose(self, pos, yaw): 45 | return lie.SE3.from_rotation_and_translation( 46 | rotation=lie.SO3.from_z_radians(yaw), 47 | translation=pos, 48 | ) 49 | 50 | def get_yaw(self, pose): 51 | yaw = pose.rotation().compute_yaw_radians() 52 | if yaw < 0.0: 53 | return yaw + 2 * np.pi 54 | return yaw 55 | 56 | def shortest_yaw(self, eff_yaw, obj_yaw, translation, n=4): 57 | """Find the symmetry-aware shortest yaw angle to the object.""" 58 | symmetries = np.array([i * 2 * np.pi / n + obj_yaw for i in range(-n, n + 1)]) 59 | d = np.argmin(np.abs(eff_yaw - symmetries)) 60 | return lie.SE3.from_rotation_and_translation( 61 | rotation=lie.SO3.from_z_radians(symmetries[d]), 62 | translation=translation, 63 | ) 64 | 65 | def compute_plan(self, times, poses, grasps): 66 | # Interpolate grasps. 67 | grasp_interp = interp1d(times, grasps, kind='linear', axis=0, assume_sorted=True) 68 | 69 | # Interpolate poses. 70 | xyzs = [p.translation() for p in poses] 71 | xyz_interp = interp1d(times, xyzs, kind='linear', axis=0, assume_sorted=True) 72 | 73 | # Interpolate orientations. 74 | quats = [p.rotation() for p in poses] 75 | 76 | def quat_interp(t): 77 | s = np.searchsorted(times, t, side='right') - 1 78 | interp_time = (t - times[s]) / (times[s + 1] - times[s]) 79 | interp_time = np.clip(interp_time, 0.0, 1.0) 80 | return lie.interpolate(quats[s], quats[s + 1], interp_time) 81 | 82 | # Generate the plan. 83 | plan = [] 84 | t = 0.0 85 | while t < self._t_max: 86 | action = np.zeros(5) 87 | action[:3] = xyz_interp(t) 88 | action[3] = quat_interp(t).compute_yaw_radians() 89 | action[4] = grasp_interp(t) 90 | plan.append(action) 91 | t += self._env_dt 92 | 93 | plan = np.array(plan) 94 | 95 | # Add temporally correlated noise to the plan. 96 | if self._noise > 0: 97 | noise = np.random.normal(0, 1, size=(len(plan), 5)) * np.array([0.05, 0.05, 0.05, 0.3, 1.0]) * self._noise 98 | noise = gaussian_filter1d(noise, axis=0, sigma=self._noise_smoothing) 99 | plan += noise 100 | 101 | return plan 102 | 103 | @property 104 | def done(self): 105 | return self._done 106 | 107 | def reset(self, ob, info): 108 | pass 109 | 110 | def select_action(self, ob, info): 111 | # Find the current plan index. 112 | cur_plan_idx = int((info['time'][0] - self._t_init + 1e-7) // self._env_dt) 113 | if cur_plan_idx >= len(self._plan) - 1: 114 | cur_plan_idx = len(self._plan) - 1 115 | self._done = True 116 | 117 | # Compute the difference between the current state and the current plan. 
118 | ab_action = self._plan[cur_plan_idx] 119 | action = np.zeros(5) 120 | action[:3] = ab_action[:3] - info['proprio/effector_pos'] 121 | action[3] = ab_action[3] - info['proprio/effector_yaw'][0] 122 | action[4] = ab_action[4] - info['proprio/gripper_opening'][0] 123 | action = self._env.unwrapped.normalize_action(action) 124 | 125 | return action 126 | -------------------------------------------------------------------------------- /ogbench/manipspace/oracles/plan/window_plan.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ogbench.manipspace.oracles.plan.plan_oracle import PlanOracle 4 | 5 | 6 | class WindowPlanOracle(PlanOracle): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | 10 | def compute_keyframes(self, plan_input): 11 | # Poses. 12 | poses = {} 13 | window_initial = self.shortest_yaw( 14 | eff_yaw=self.get_yaw(plan_input['effector_initial']), 15 | obj_yaw=self.get_yaw(plan_input['window_initial']), 16 | translation=plan_input['window_initial'].translation(), 17 | n=2, 18 | ) 19 | window_goal = self.shortest_yaw( 20 | eff_yaw=self.get_yaw(plan_input['effector_initial']), 21 | obj_yaw=self.get_yaw(plan_input['window_initial']), 22 | translation=plan_input['window_goal'].translation(), 23 | n=2, 24 | ) 25 | poses['initial'] = plan_input['effector_initial'] 26 | poses['approach'] = self.above(window_initial, 0.06) 27 | poses['grasp_start'] = window_initial 28 | poses['grasp_end'] = window_initial 29 | poses['move'] = window_goal 30 | poses['release'] = window_goal 31 | poses['clearance'] = self.above(window_goal, 0.06) 32 | poses['final'] = plan_input['effector_goal'] 33 | 34 | # Times. 35 | times = {} 36 | times['initial'] = 0.0 37 | times['approach'] = times['initial'] + self._dt 38 | times['grasp_start'] = times['approach'] + self._dt * 0.5 39 | times['grasp_end'] = times['grasp_start'] + self._dt * 0.5 40 | times['move'] = times['grasp_end'] + self._dt * 0.5 41 | times['release'] = times['move'] + self._dt * 0.5 42 | times['clearance'] = times['release'] + self._dt * 0.5 43 | times['final'] = times['clearance'] + self._dt 44 | for time in times.keys(): 45 | if time != 'initial': 46 | times[time] += np.random.uniform(-1, 1) * self._dt * 0.1 47 | 48 | # Grasps. 
49 | grasps = {} 50 | g = 0.0 51 | for name in times.keys(): 52 | if name in {'grasp_end', 'release'}: 53 | g = 1.0 - g 54 | grasps[name] = g 55 | 56 | return times, poses, grasps 57 | 58 | def reset(self, ob, info): 59 | plan_input = { 60 | 'effector_initial': self.to_pose( 61 | pos=info['proprio/effector_pos'], 62 | yaw=info['proprio/effector_yaw'][0], 63 | ), 64 | 'effector_goal': self.to_pose( 65 | pos=np.random.uniform(*self._env.unwrapped._arm_sampling_bounds), 66 | yaw=np.random.uniform(-np.pi, np.pi), 67 | ), 68 | 'window_initial': self.to_pose( 69 | pos=info['privileged/window_handle_pos'], 70 | yaw=info['privileged/window_handle_yaw'][0], 71 | ), 72 | 'window_goal': self.to_pose( 73 | pos=info['privileged/target_window_handle_pos'], 74 | yaw=info['privileged/window_handle_yaw'][0], 75 | ), 76 | } 77 | 78 | times, poses, grasps = self.compute_keyframes(plan_input) 79 | poses = [poses[name] for name in times.keys()] 80 | grasps = [grasps[name] for name in times.keys()] 81 | times = list(times.values()) 82 | 83 | self._t_init = info['time'][0] 84 | self._t_max = times[-1] 85 | self._done = False 86 | self._plan = self.compute_plan(times, poses, grasps) 87 | -------------------------------------------------------------------------------- /ogbench/manipspace/viewer_utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from dm_control.viewer import user_input 4 | 5 | 6 | @dataclass 7 | class KeyCallback: 8 | reset: bool = False 9 | pause: bool = False 10 | 11 | def __call__(self, key: int) -> None: 12 | if key == user_input.KEY_ENTER: 13 | self.reset = True 14 | elif key == user_input.KEY_SPACE: 15 | self.pause = not self.pause 16 | -------------------------------------------------------------------------------- /ogbench/online_locomotion/__init__.py: -------------------------------------------------------------------------------- 1 | from gymnasium.envs.registration import register 2 | 3 | register( 4 | id='online-ant-v0', 5 | entry_point='ogbench.online_locomotion.ant:AntEnv', 6 | max_episode_steps=1000, 7 | ) 8 | register( 9 | id='online-antball-v0', 10 | entry_point='ogbench.online_locomotion.ant_ball:AntBallEnv', 11 | max_episode_steps=200, 12 | ) 13 | register( 14 | id='online-humanoid-v0', 15 | entry_point='ogbench.online_locomotion.humanoid:HumanoidEnv', 16 | max_episode_steps=1000, 17 | ) 18 | -------------------------------------------------------------------------------- /ogbench/online_locomotion/ant.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import gymnasium 4 | import numpy as np 5 | from gymnasium import utils 6 | from gymnasium.envs.mujoco import MujocoEnv 7 | from gymnasium.spaces import Box 8 | 9 | DEFAULT_CAMERA_CONFIG = { 10 | 'distance': 4.0, 11 | } 12 | 13 | 14 | class AntEnv(MujocoEnv, utils.EzPickle): 15 | """Gymnasium Ant environment. 16 | 17 | Unlike the original Ant environment, this environment uses a restricted joint range for the actuators, as typically 18 | done in previous works in hierarchical reinforcement learning. It also uses a control frequency of 10Hz instead of 19 | 20Hz, which is the default in the original environment. 
20 | """ 21 | 22 | xml_file = os.path.join(os.path.dirname(__file__), 'assets', 'ant.xml') 23 | metadata = { 24 | 'render_modes': ['human', 'rgb_array', 'depth_array'], 25 | 'render_fps': 10, 26 | } 27 | if gymnasium.__version__ >= '1.1.0': 28 | metadata['render_modes'] += ['rgbd_tuple'] 29 | 30 | def __init__( 31 | self, 32 | xml_file=None, 33 | ctrl_cost_weight=0.5, 34 | use_contact_forces=False, 35 | contact_cost_weight=5e-4, 36 | healthy_reward=1.0, 37 | terminate_when_unhealthy=True, 38 | healthy_z_range=(0.2, 1.0), 39 | contact_force_range=(-1.0, 1.0), 40 | reset_noise_scale=0.1, 41 | exclude_current_positions_from_observation=True, 42 | **kwargs, 43 | ): 44 | if xml_file is None: 45 | xml_file = self.xml_file 46 | utils.EzPickle.__init__( 47 | self, 48 | xml_file, 49 | ctrl_cost_weight, 50 | use_contact_forces, 51 | contact_cost_weight, 52 | healthy_reward, 53 | terminate_when_unhealthy, 54 | healthy_z_range, 55 | contact_force_range, 56 | reset_noise_scale, 57 | exclude_current_positions_from_observation, 58 | **kwargs, 59 | ) 60 | 61 | self._ctrl_cost_weight = ctrl_cost_weight 62 | self._contact_cost_weight = contact_cost_weight 63 | 64 | self._healthy_reward = healthy_reward 65 | self._terminate_when_unhealthy = terminate_when_unhealthy 66 | self._healthy_z_range = healthy_z_range 67 | 68 | self._contact_force_range = contact_force_range 69 | 70 | self._reset_noise_scale = reset_noise_scale 71 | 72 | self._use_contact_forces = use_contact_forces 73 | 74 | self._exclude_current_positions_from_observation = exclude_current_positions_from_observation 75 | 76 | obs_shape = 27 77 | if not exclude_current_positions_from_observation: 78 | obs_shape += 2 79 | if use_contact_forces: 80 | obs_shape += 84 81 | 82 | observation_space = Box(low=-np.inf, high=np.inf, shape=(obs_shape,), dtype=np.float64) 83 | 84 | MujocoEnv.__init__( 85 | self, 86 | xml_file, 87 | 5, 88 | observation_space=observation_space, 89 | default_camera_config=DEFAULT_CAMERA_CONFIG, 90 | **kwargs, 91 | ) 92 | 93 | @property 94 | def healthy_reward(self): 95 | return float(self.is_healthy or self._terminate_when_unhealthy) * self._healthy_reward 96 | 97 | def control_cost(self, action): 98 | control_cost = self._ctrl_cost_weight * np.sum(np.square(action)) 99 | return control_cost 100 | 101 | @property 102 | def contact_forces(self): 103 | raw_contact_forces = self.data.cfrc_ext 104 | min_value, max_value = self._contact_force_range 105 | contact_forces = np.clip(raw_contact_forces, min_value, max_value) 106 | return contact_forces 107 | 108 | @property 109 | def contact_cost(self): 110 | contact_cost = self._contact_cost_weight * np.sum(np.square(self.contact_forces)) 111 | return contact_cost 112 | 113 | @property 114 | def is_healthy(self): 115 | state = self.state_vector() 116 | min_z, max_z = self._healthy_z_range 117 | is_healthy = np.isfinite(state).all() and min_z <= state[2] <= max_z 118 | return is_healthy 119 | 120 | @property 121 | def terminated(self): 122 | terminated = not self.is_healthy if self._terminate_when_unhealthy else False 123 | return terminated 124 | 125 | def step(self, action): 126 | xy_position_before = self.get_body_com('torso')[:2].copy() 127 | self.do_simulation(action, self.frame_skip) 128 | xy_position_after = self.get_body_com('torso')[:2].copy() 129 | 130 | xy_velocity = (xy_position_after - xy_position_before) / self.dt 131 | x_velocity, y_velocity = xy_velocity 132 | 133 | forward_reward = x_velocity 134 | healthy_reward = self.healthy_reward 135 | 136 | rewards = forward_reward + 
healthy_reward 137 | 138 | costs = ctrl_cost = self.control_cost(action) 139 | 140 | terminated = self.terminated 141 | observation = self._get_obs() 142 | info = { 143 | 'reward_forward': forward_reward, 144 | 'reward_ctrl': -ctrl_cost, 145 | 'reward_survive': healthy_reward, 146 | 'x_position': xy_position_after[0], 147 | 'y_position': xy_position_after[1], 148 | 'distance_from_origin': np.linalg.norm(xy_position_after, ord=2), 149 | 'x_velocity': x_velocity, 150 | 'y_velocity': y_velocity, 151 | 'forward_reward': forward_reward, 152 | } 153 | if self._use_contact_forces: 154 | contact_cost = self.contact_cost 155 | costs += contact_cost 156 | info['reward_ctrl'] = -contact_cost 157 | 158 | reward = rewards - costs 159 | 160 | if self.render_mode == 'human': 161 | self.render() 162 | return observation, reward, terminated, False, info 163 | 164 | def _get_obs(self): 165 | position = self.data.qpos.flat.copy() 166 | velocity = self.data.qvel.flat.copy() 167 | 168 | if self._exclude_current_positions_from_observation: 169 | position = position[2:] 170 | 171 | if self._use_contact_forces: 172 | contact_force = self.contact_forces.flat.copy() 173 | return np.concatenate((position, velocity, contact_force)) 174 | else: 175 | return np.concatenate((position, velocity)) 176 | 177 | def reset_model(self): 178 | noise_low = -self._reset_noise_scale 179 | noise_high = self._reset_noise_scale 180 | 181 | qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq) 182 | qvel = self.init_qvel + self._reset_noise_scale * self.np_random.standard_normal(self.model.nv) 183 | self.set_state(qpos, qvel) 184 | 185 | observation = self._get_obs() 186 | 187 | return observation 188 | -------------------------------------------------------------------------------- /ogbench/online_locomotion/ant_ball.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | import xml.etree.ElementTree as ET 3 | 4 | import numpy as np 5 | from gymnasium.spaces import Box 6 | 7 | from ogbench.online_locomotion.ant import AntEnv 8 | 9 | 10 | class AntBallEnv(AntEnv): 11 | """Gymnasium Ant environment with a ball.""" 12 | 13 | def __init__(self, xml_file=None, *args, **kwargs): 14 | if xml_file is None: 15 | xml_file = self.xml_file 16 | 17 | # Add a ball to the environment. 18 | tree = ET.parse(xml_file) 19 | worldbody = tree.find('.//worldbody') 20 | ET.SubElement( 21 | worldbody, 22 | 'geom', 23 | name='target', 24 | type='cylinder', 25 | size='.4 .05', 26 | pos='0 0 .05', 27 | material='target', 28 | contype='0', 29 | conaffinity='0', 30 | ) 31 | ball = ET.SubElement(worldbody, 'body', name='ball', pos='0 0 3') 32 | ET.SubElement(ball, 'freejoint', name='ball_root') 33 | ET.SubElement(ball, 'geom', name='ball', size='.25', material='ball', priority='1', conaffinity='1', condim='6') 34 | ET.SubElement(ball, 'light', name='ball_light', pos='0 0 4', mode='trackcom') 35 | 36 | # Rename the track camera to avoid automatic tracking. 37 | track_camera = tree.find('.//camera[@name="track"]') 38 | track_camera.set('name', 'back') 39 | _, xml_file = tempfile.mkstemp(text=True, suffix='.xml') 40 | tree.write(xml_file) 41 | 42 | super().__init__(xml_file=xml_file, *args, **kwargs) 43 | 44 | self.cur_goal_xy = np.zeros(2) 45 | self.observation_space = Box(low=-np.inf, high=np.inf, shape=(self._get_obs().shape[0],), dtype=np.float64) 46 | 47 | self.reset() 48 | self.render() 49 | 50 | # Adjust the camera. 
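        # Use a fixed top-down view centered at the origin so renders show the
        # agent, the ball, and the goal together (the default tracking camera
        # was renamed to 'back' above to disable automatic tracking).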
51 | self.mujoco_renderer.viewer.cam.lookat[0] = 0 52 | self.mujoco_renderer.viewer.cam.lookat[1] = 0 53 | self.mujoco_renderer.viewer.cam.distance = 30 54 | self.mujoco_renderer.viewer.cam.elevation = -90 55 | 56 | def reset(self, options=None, *args, **kwargs): 57 | ob, info = super().reset(*args, **kwargs) 58 | 59 | agent_init_xy = np.random.uniform(low=-1, high=1, size=2) 60 | ball_init_xy = np.random.uniform(low=-2, high=2, size=2) 61 | goal_xy = np.random.uniform(low=-12, high=12, size=2) 62 | 63 | self.set_agent_ball_xy(agent_init_xy, ball_init_xy) 64 | self.set_goal(goal_xy=goal_xy) 65 | ob = self._get_obs() 66 | 67 | return ob, info 68 | 69 | def step(self, action): 70 | prev_agent_xy, prev_ball_xy = self.get_agent_ball_xy() 71 | goal_xy = self.cur_goal_xy 72 | prev_agent_ball_dist = np.linalg.norm(prev_agent_xy - prev_ball_xy) 73 | prev_ball_goal_dist = np.linalg.norm(prev_ball_xy - goal_xy) 74 | 75 | ob, reward, terminated, truncated, info = super().step(action) 76 | 77 | if np.linalg.norm(self.get_agent_ball_xy()[1] - self.cur_goal_xy) <= 0.5: 78 | info['success'] = 1.0 79 | else: 80 | info['success'] = 0.0 81 | 82 | # Compute the distance between the agent and the ball, and the ball and the goal. 83 | agent_xy, ball_xy = self.get_agent_ball_xy() 84 | agent_ball_dist = np.linalg.norm(agent_xy - ball_xy) 85 | ball_goal_dist = np.linalg.norm(ball_xy - goal_xy) 86 | 87 | # Use the change in distances as the reward. 88 | reward = ((prev_ball_goal_dist - ball_goal_dist) * 2.5 + (prev_agent_ball_dist - agent_ball_dist)) * 10 89 | 90 | return ob, reward, terminated, truncated, info 91 | 92 | def set_goal(self, goal_xy): 93 | self.cur_goal_xy = goal_xy 94 | self.model.geom('target').pos[:2] = goal_xy 95 | 96 | def get_agent_ball_xy(self): 97 | agent_xy = self.data.qpos[:2].copy() 98 | ball_xy = self.data.qpos[-7:-5].copy() 99 | 100 | return agent_xy, ball_xy 101 | 102 | def set_agent_ball_xy(self, agent_xy, ball_xy): 103 | qpos = self.data.qpos.copy() 104 | qvel = self.data.qvel.copy() 105 | qpos[:2] = agent_xy 106 | qpos[-7:-5] = ball_xy 107 | self.set_state(qpos, qvel) 108 | 109 | def _get_obs(self): 110 | # Return the agent's position, velocity, the ball's relative position, and the goal's relative position. 111 | agent_xy, ball_xy = self.get_agent_ball_xy() 112 | qpos = self.data.qpos.flat.copy() 113 | qvel = self.data.qvel.flat.copy() 114 | return np.concatenate([qpos[2:-7], qpos[-5:], qvel, ball_xy - agent_xy, np.array(self.cur_goal_xy) - ball_xy]) 115 | -------------------------------------------------------------------------------- /ogbench/online_locomotion/assets/ant.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 91 | -------------------------------------------------------------------------------- /ogbench/online_locomotion/wrappers.py: -------------------------------------------------------------------------------- 1 | import gymnasium 2 | import numpy as np 3 | from gymnasium.spaces import Box 4 | 5 | 6 | class GymXYWrapper(gymnasium.Wrapper): 7 | """Wrapper for directional locomotion tasks.""" 8 | 9 | def __init__(self, env, resample_interval=100): 10 | """Initialize the wrapper. 11 | 12 | Args: 13 | env: Environment. 14 | resample_interval: Interval at which to resample the direction. 
15 | """ 16 | super().__init__(env) 17 | 18 | self.z = None 19 | self.num_steps = 0 20 | self.resample_interval = resample_interval 21 | 22 | ob, _ = self.reset() 23 | self.observation_space = Box(low=-np.inf, high=np.inf, shape=ob.shape, dtype=np.float64) 24 | 25 | def reset(self, *args, **kwargs): 26 | ob, info = self.env.reset(*args, **kwargs) 27 | self.z = np.random.randn(2) 28 | self.z = self.z / np.linalg.norm(self.z) 29 | self.num_steps = 0 30 | 31 | return np.concatenate([ob, self.z]), info 32 | 33 | def step(self, action): 34 | cur_xy = self.unwrapped.data.qpos[:2].copy() 35 | ob, reward, terminated, truncated, info = self.env.step(action) 36 | next_xy = self.unwrapped.data.qpos[:2].copy() 37 | self.num_steps += 1 38 | 39 | # Reward is the dot product of the direction and the change in xy. 40 | reward = (next_xy - cur_xy).dot(self.z) 41 | info['xy'] = next_xy 42 | info['direction'] = self.z 43 | 44 | if self.num_steps % self.resample_interval == 0: 45 | self.z = np.random.randn(2) 46 | self.z = self.z / np.linalg.norm(self.z) 47 | 48 | return np.concatenate([ob, self.z]), reward, terminated, truncated, info 49 | 50 | 51 | class DMCHumanoidXYWrapper(GymXYWrapper): 52 | """Wrapper for the directional Humanoid task.""" 53 | 54 | def step(self, action): 55 | from ogbench.online_locomotion.humanoid import tolerance 56 | 57 | cur_xy = self.unwrapped.data.qpos[:2].copy() 58 | ob, reward, terminated, truncated, info = self.env.step(action) 59 | next_xy = self.unwrapped.data.qpos[:2].copy() 60 | self.num_steps += 1 61 | 62 | head_height = self.unwrapped.data.xpos[2, 2] # ['head', 'z'] 63 | torso_upright = self.unwrapped.data.xmat[1, 8] # ['torso', 'zz'] 64 | 65 | standing = tolerance(head_height, bounds=(1.4, float('inf')), margin=1.4 / 4) 66 | upright = tolerance(torso_upright, bounds=(0.9, float('inf')), margin=1.9, sigmoid='linear', value_at_margin=0) 67 | stand_reward = standing * upright 68 | 69 | # Reward is the dot product of the direction and the change in xy, multiplied by the stand reward to encourage 70 | # the agent to stand. 
71 | reward = stand_reward * (1 + (next_xy - cur_xy).dot(self.z) * 100) 72 | 73 | info['xy'] = next_xy 74 | info['direction'] = self.z 75 | 76 | if self.num_steps % self.resample_interval == 0: 77 | self.z = np.random.randn(2) 78 | self.z = self.z / np.linalg.norm(self.z) 79 | 80 | return np.concatenate([ob, self.z]), reward, terminated, truncated, info 81 | -------------------------------------------------------------------------------- /ogbench/powderworld/__init__.py: -------------------------------------------------------------------------------- 1 | from gymnasium.envs.registration import register 2 | 3 | register( 4 | id='powderworld-easy-v0', 5 | entry_point='ogbench.powderworld.powderworld_env:PowderworldEnv', 6 | max_episode_steps=500, 7 | kwargs=dict(num_elems=2), 8 | ) 9 | 10 | register( 11 | id='powderworld-medium-v0', 12 | entry_point='ogbench.powderworld.powderworld_env:PowderworldEnv', 13 | max_episode_steps=500, 14 | kwargs=dict(num_elems=5), 15 | ) 16 | 17 | register( 18 | id='powderworld-hard-v0', 19 | entry_point='ogbench.powderworld.powderworld_env:PowderworldEnv', 20 | max_episode_steps=500, 21 | kwargs=dict(num_elems=8), 22 | ) 23 | -------------------------------------------------------------------------------- /ogbench/powderworld/behaviors.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Behavior: 5 | """Base class for action behaviors.""" 6 | 7 | def __init__(self, env): 8 | self._env = env 9 | self._done = False 10 | self._step = 0 11 | 12 | self._size = self._env.unwrapped._world_size // self._env.unwrapped._brush_size 13 | assert self._env.unwrapped._brush_size == self._env.unwrapped._grid_size 14 | self._elem_name = None 15 | self._sequence = None 16 | 17 | @property 18 | def done(self): 19 | return self._done 20 | 21 | def reset(self, ob, info): 22 | pass 23 | 24 | def select_action(self, ob, info): 25 | x, y = self._sequence[self._step] 26 | 27 | self._step += 1 28 | if self._step == len(self._sequence): 29 | self._done = True 30 | 31 | return self._elem_name, x, y 32 | 33 | 34 | class FillBehavior(Behavior): 35 | """Fill the entire grid with a single element.""" 36 | 37 | def reset(self, ob, info): 38 | self._done = False 39 | self._step = 0 40 | self._elem_name = np.random.choice(self._env.unwrapped._elem_names) 41 | 42 | # Randomly flip the fill directions. 43 | flip_x = np.random.randint(2) 44 | flip_y = np.random.randint(2) 45 | flip_xy = np.random.randint(2) 46 | 47 | self._sequence = [] 48 | for i in range(self._size * self._size): 49 | x = i % self._size 50 | if flip_x: 51 | x = self._size - x - 1 52 | y = i // self._size 53 | if flip_y: 54 | y = self._size - y - 1 55 | if flip_xy: 56 | x, y = y, x 57 | 58 | self._sequence.append((x, y)) 59 | 60 | 61 | class LineBehavior(Behavior): 62 | """Fill a single line with a single element.""" 63 | 64 | def reset(self, ob, info): 65 | self._done = False 66 | self._step = 0 67 | self._elem_name = np.random.choice(self._env.unwrapped._elem_names) 68 | 69 | # Randomly select the line direction. 
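        # target_idx selects which row (or column) to draw; flip_dir mirrors it
        # to the opposite side of the grid, and flip_xy transposes x and y so
        # both horizontal and vertical lines occur.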
70 | target_idx = np.random.randint(self._size) 71 | flip_dir = np.random.randint(2) 72 | flip_xy = np.random.randint(2) 73 | 74 | self._sequence = [] 75 | for i in range(self._size): 76 | x, y = i, target_idx 77 | if flip_dir: 78 | y = self._size - 1 - y 79 | if flip_xy: 80 | x, y = y, x 81 | 82 | self._sequence.append((x, y)) 83 | 84 | 85 | class SquareBehavior(Behavior): 86 | """Draw a square with a single element.""" 87 | 88 | def reset(self, ob, info): 89 | self._done = False 90 | self._step = 0 91 | self._elem_name = np.random.choice(self._env.unwrapped._elem_names) 92 | 93 | length = np.random.randint(1, self._size) 94 | x1 = np.random.randint(self._size - length) 95 | x2 = x1 + length 96 | y1 = np.random.randint(self._size - length) 97 | y2 = y1 + length 98 | 99 | sides = [] 100 | sides.append([(x1, y) for y in range(y1, y2 + 1)]) 101 | sides.append([(x2, y) for y in range(y1, y2 + 1)]) 102 | sides.append([(x, y1) for x in range(x1, x2 + 1)]) 103 | sides.append([(x, y2) for x in range(x1, x2 + 1)]) 104 | 105 | # Randomly reverse sides. 106 | for i in range(4): 107 | if np.random.randint(2): 108 | sides[i].reverse() 109 | 110 | # Randomly shuffle the order of sides. 111 | np.random.shuffle(sides) 112 | 113 | self._sequence = [] 114 | for side in sides: 115 | self._sequence.extend(side) 116 | -------------------------------------------------------------------------------- /ogbench/relabel_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def relabel_dataset(env_name, env, dataset): 5 | """Relabel the dataset with rewards and masks based on the fixed task of the environment. 6 | This is useful for single-task variants of the environments. 7 | 8 | Args: 9 | env_name: Name of the environment. 10 | env: Environment. 11 | dataset: Dataset dictionary. 12 | """ 13 | assert env.unwrapped._reward_task_id is not None, 'The environment is not in the single-task mode.' 14 | env.reset() # Set the task. 15 | 16 | if 'maze' in env_name or 'soccer' in env_name: 17 | # Locomotion environments. 18 | qpos_xy_start_idx = 0 19 | qpos_ball_start_idx = 15 20 | goal_xy = env.unwrapped.cur_goal_xy 21 | goal_tol = env.unwrapped._goal_tol 22 | 23 | # Compute successes. 24 | if 'maze' in env_name: 25 | dists = np.linalg.norm(dataset['qpos'][:, qpos_xy_start_idx : qpos_xy_start_idx + 2] - goal_xy, axis=-1) 26 | else: 27 | dists = np.linalg.norm(dataset['qpos'][:, qpos_ball_start_idx : qpos_ball_start_idx + 2] - goal_xy, axis=-1) 28 | successes = (dists <= goal_tol).astype(np.float32) 29 | 30 | rewards = successes - 1.0 31 | masks = 1.0 - successes 32 | elif 'cube' in env_name or 'scene' in env_name or 'puzzle' in env_name: 33 | # Manipulation environments. 34 | qpos_obj_start_idx = 14 35 | qpos_cube_length = 7 36 | 37 | if 'cube' in env_name: 38 | num_cubes = env.unwrapped._num_cubes 39 | target_cube_xyzs = env.unwrapped._data.mocap_pos.copy() 40 | 41 | # Compute successes. 
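            # Each cube occupies a 7-dim free-joint block in qpos (xyz position
            # followed by a quaternion); a cube counts as successful when its
            # xyz is within 4 cm (0.04 m) of its mocap target position.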
42 | cube_xyzs_list = [] 43 | for i in range(num_cubes): 44 | cube_xyzs_list.append( 45 | dataset['qpos'][ 46 | :, qpos_obj_start_idx + i * qpos_cube_length : qpos_obj_start_idx + i * qpos_cube_length + 3 47 | ] 48 | ) 49 | cube_xyzs = np.stack(cube_xyzs_list, axis=1) 50 | successes = np.linalg.norm(target_cube_xyzs - cube_xyzs, axis=-1) <= 0.04 51 | elif 'scene' in env_name: 52 | num_cubes = env.unwrapped._num_cubes 53 | num_buttons = env.unwrapped._num_buttons 54 | qpos_drawer_idx = qpos_obj_start_idx + num_cubes * qpos_cube_length + num_buttons 55 | qpos_window_idx = qpos_drawer_idx + 1 56 | target_cube_xyzs = env.unwrapped._data.mocap_pos.copy() 57 | target_button_states = env.unwrapped._target_button_states.copy() 58 | target_drawer_pos = env.unwrapped._target_drawer_pos 59 | target_window_pos = env.unwrapped._target_window_pos 60 | 61 | # Compute successes. 62 | cube_xyzs_list = [] 63 | for i in range(num_cubes): 64 | cube_xyzs_list.append( 65 | dataset['qpos'][ 66 | :, qpos_obj_start_idx + i * qpos_cube_length : qpos_obj_start_idx + i * qpos_cube_length + 3 67 | ] 68 | ) 69 | cube_xyzs = np.stack(cube_xyzs_list, axis=1) 70 | cube_successes = np.linalg.norm(target_cube_xyzs - cube_xyzs, axis=-1) <= 0.04 71 | button_successes = dataset['button_states'] == target_button_states 72 | drawer_success = np.abs(dataset['qpos'][:, qpos_drawer_idx] - target_drawer_pos) <= 0.04 73 | window_success = np.abs(dataset['qpos'][:, qpos_window_idx] - target_window_pos) <= 0.04 74 | successes = np.concatenate( 75 | [cube_successes, button_successes, drawer_success[:, None], window_success[:, None]], axis=-1 76 | ) 77 | elif 'puzzle' in env_name: 78 | num_buttons = env.unwrapped._num_buttons 79 | target_button_states = env.unwrapped._target_button_states.copy() 80 | 81 | # Compute successes. 82 | successes = dataset['button_states'] == target_button_states 83 | 84 | rewards = successes.sum(axis=-1) - successes.shape[-1] 85 | masks = 1.0 - np.all(successes, axis=-1) 86 | else: 87 | raise ValueError(f'Unsupported environment: {env_name}') 88 | 89 | dataset['rewards'] = rewards.astype(np.float32) 90 | dataset['masks'] = masks.astype(np.float32) 91 | 92 | 93 | def add_oracle_reps(env_name, env, dataset): 94 | """Add oracle goal representations to the dataset. 95 | 96 | Args: 97 | env_name: Name of the environment. 98 | env: Environment. 99 | dataset: Dataset dictionary. 100 | """ 101 | if 'maze' in env_name or 'soccer' in env_name: 102 | # Locomotion environments. 103 | qpos_xy_start_idx = 0 104 | qpos_ball_start_idx = 15 105 | 106 | if 'maze' in env_name: 107 | oracle_reps = dataset['qpos'][:, qpos_xy_start_idx : qpos_xy_start_idx + 2] 108 | else: 109 | oracle_reps = dataset['qpos'][:, qpos_ball_start_idx : qpos_ball_start_idx + 2] 110 | elif 'cube' in env_name or 'scene' in env_name or 'puzzle' in env_name: 111 | # Manipulation environments. 
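        # The oracle representations are object-centric features: cube
        # positions recentered at xyz_center and scaled by xyz_scaler, raw
        # button states, and drawer/window joint positions scaled by their
        # respective scalers to a comparable magnitude.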
112 | qpos_obj_start_idx = 14 113 | qpos_cube_length = 7 114 | xyz_center = np.array([0.425, 0.0, 0.0]) 115 | xyz_scaler = 10.0 116 | drawer_scaler = 18.0 117 | window_scaler = 15.0 118 | 119 | if 'cube' in env_name: 120 | num_cubes = env.unwrapped._num_cubes 121 | 122 | cube_xyzs_list = [] 123 | for i in range(num_cubes): 124 | cube_xyzs_list.append( 125 | dataset['qpos'][ 126 | :, qpos_obj_start_idx + i * qpos_cube_length : qpos_obj_start_idx + i * qpos_cube_length + 3 127 | ] 128 | ) 129 | cube_xyzs = np.stack(cube_xyzs_list, axis=1) 130 | oracle_reps = ((cube_xyzs - xyz_center) * xyz_scaler).reshape(-1, num_cubes * 3) 131 | elif 'scene' in env_name: 132 | num_cubes = env.unwrapped._num_cubes 133 | num_buttons = env.unwrapped._num_buttons 134 | qpos_drawer_idx = qpos_obj_start_idx + num_cubes * qpos_cube_length + num_buttons 135 | qpos_window_idx = qpos_drawer_idx + 1 136 | 137 | cube_xyzs_list = [] 138 | for i in range(num_cubes): 139 | cube_xyzs_list.append( 140 | dataset['qpos'][ 141 | :, qpos_obj_start_idx + i * qpos_cube_length : qpos_obj_start_idx + i * qpos_cube_length + 3 142 | ] 143 | ) 144 | cube_xyzs = np.stack(cube_xyzs_list, axis=1) 145 | cube_reps = ((cube_xyzs - xyz_center) * xyz_scaler).reshape(-1, num_cubes * 3) 146 | button_reps = dataset['button_states'].copy() 147 | drawer_reps = dataset['qpos'][:, [qpos_drawer_idx]] * drawer_scaler 148 | window_reps = dataset['qpos'][:, [qpos_window_idx]] * window_scaler 149 | oracle_reps = np.concatenate([cube_reps, button_reps, drawer_reps, window_reps], axis=-1) 150 | elif 'puzzle' in env_name: 151 | oracle_reps = dataset['button_states'].copy() 152 | else: 153 | raise ValueError(f'Unsupported environment: {env_name}') 154 | 155 | dataset['oracle_reps'] = oracle_reps.astype(np.float32) 156 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["flit_core >=3.2,<4"] 3 | build-backend = "flit_core.buildapi" 4 | 5 | [project] 6 | name = "ogbench" 7 | version = "1.1.3" 8 | requires-python = ">=3.8" 9 | dependencies = [ 10 | "mujoco >= 3.1.6", 11 | "dm_control >= 1.0.20", 12 | "gymnasium[mujoco]", 13 | ] 14 | authors = [ 15 | { name = "Seohong Park" }, 16 | { name = "Kevin Frans" }, 17 | { name = "Benjamin Eysenbach" }, 18 | { name = "Sergey Levine" }, 19 | ] 20 | maintainers = [ 21 | { name = "Seohong Park", email = "seohong@berkeley.edu" } 22 | ] 23 | license = { file = "LICENSE" } 24 | classifiers = ["License :: OSI Approved :: MIT License"] 25 | dynamic = ["description"] 26 | 27 | [project.optional-dependencies] 28 | train = [ 29 | "jax[cuda12] >= 0.4.26", 30 | "flax >= 0.8.4", 31 | "distrax >= 0.1.5", 32 | "ml_collections", 33 | "matplotlib", 34 | "moviepy", 35 | "wandb", 36 | ] 37 | dev = [ 38 | "ruff", 39 | ] 40 | all = [ 41 | "ogbench[train,dev]", 42 | ] 43 | 44 | [project.urls] 45 | Home = "https://github.com/seohongpark/ogbench" 46 | 47 | [tool.ruff] 48 | target-version = "py310" 49 | line-length = 120 50 | 51 | [tool.ruff.format] 52 | quote-style = "single" 53 | --------------------------------------------------------------------------------
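A minimal data-collection sketch for the oracles above, assuming `env` is one of the manipulation environments constructed elsewhere with a gymnasium-style API whose `info` dict exposes the `proprio/...` and `privileged/...` keys the oracles read. `DrawerMarkovOracle` comes from the file shown earlier; `collect_episode` and the rest of the snippet are illustrative, not part of the repository.

from ogbench.manipspace.oracles.markov.drawer_markov import DrawerMarkovOracle


def collect_episode(env, seed=0):
    """Roll out one scripted episode with a Markov oracle and return its transitions."""
    oracle = DrawerMarkovOracle(env=env, min_norm=0.4)
    ob, info = env.reset(seed=seed)
    oracle.reset(ob, info)

    transitions = []
    done = False
    while not done:
        # The oracle returns an action already clipped to [-1, 1].
        action = oracle.select_action(ob, info)
        next_ob, reward, terminated, truncated, info = env.step(action)
        transitions.append((ob, action, next_ob))
        ob = next_ob
        done = oracle.done or terminated or truncated
    return transitions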