├── .gitignore
├── CHANGELOG.md
├── LICENSE
├── README.md
├── assets
│   ├── env_teaser.png
│   └── ogbench.svg
├── data_gen_scripts
│   ├── commands.sh
│   ├── generate_antsoccer.py
│   ├── generate_locomaze.py
│   ├── generate_manipspace.py
│   ├── generate_powderworld.py
│   ├── main_sac.py
│   ├── online_env_utils.py
│   └── viz_utils.py
├── impls
│   ├── agents
│   │   ├── __init__.py
│   │   ├── crl.py
│   │   ├── gcbc.py
│   │   ├── gciql.py
│   │   ├── gcivl.py
│   │   ├── hiql.py
│   │   ├── qrl.py
│   │   └── sac.py
│   ├── hyperparameters.sh
│   ├── main.py
│   ├── requirements.txt
│   └── utils
│       ├── __init__.py
│       ├── datasets.py
│       ├── encoders.py
│       ├── env_utils.py
│       ├── evaluation.py
│       ├── flax_utils.py
│       ├── log_utils.py
│       └── networks.py
├── ogbench
│   ├── __init__.py
│   ├── locomaze
│   │   ├── __init__.py
│   │   ├── ant.py
│   │   ├── assets
│   │   │   ├── ant.xml
│   │   │   ├── humanoid.xml
│   │   │   └── point.xml
│   │   ├── humanoid.py
│   │   ├── maze.py
│   │   └── point.py
│   ├── manipspace
│   │   ├── __init__.py
│   │   ├── controllers
│   │   │   ├── __init__.py
│   │   │   └── diff_ik.py
│   │   ├── descriptions
│   │   │   ├── button_inner.xml
│   │   │   ├── button_outer.xml
│   │   │   ├── buttons.xml
│   │   │   ├── cube.xml
│   │   │   ├── cube_inner.xml
│   │   │   ├── cube_outer.xml
│   │   │   ├── drawer.xml
│   │   │   ├── floor_wall.xml
│   │   │   ├── metaworld
│   │   │   │   ├── button
│   │   │   │   │   ├── button.stl
│   │   │   │   │   ├── buttonring.stl
│   │   │   │   │   ├── metal1.png
│   │   │   │   │   ├── stopbot.stl
│   │   │   │   │   ├── stopbutton.stl
│   │   │   │   │   ├── stopbuttonrim.stl
│   │   │   │   │   ├── stopbuttonrod.stl
│   │   │   │   │   └── stoptop.stl
│   │   │   │   ├── drawer
│   │   │   │   │   ├── drawer.stl
│   │   │   │   │   ├── drawercase.stl
│   │   │   │   │   └── drawerhandle.stl
│   │   │   │   └── window
│   │   │   │       ├── window_base.stl
│   │   │   │       ├── window_frame.stl
│   │   │   │       ├── window_h_base.stl
│   │   │   │       ├── window_h_frame.stl
│   │   │   │       ├── windowa_frame.stl
│   │   │   │       ├── windowa_glass.stl
│   │   │   │       ├── windowa_h_frame.stl
│   │   │   │       ├── windowa_h_glass.stl
│   │   │   │       ├── windowb_frame.stl
│   │   │   │       ├── windowb_glass.stl
│   │   │   │       ├── windowb_h_frame.stl
│   │   │   │       └── windowb_h_glass.stl
│   │   │   ├── robotiq_2f85
│   │   │   │   ├── 2f85.png
│   │   │   │   ├── 2f85.xml
│   │   │   │   ├── LICENSE
│   │   │   │   ├── README.md
│   │   │   │   ├── assets
│   │   │   │   │   ├── base.stl
│   │   │   │   │   ├── base_mount.stl
│   │   │   │   │   ├── coupler.stl
│   │   │   │   │   ├── driver.stl
│   │   │   │   │   ├── follower.stl
│   │   │   │   │   ├── pad.stl
│   │   │   │   │   ├── silicone_pad.stl
│   │   │   │   │   └── spring_link.stl
│   │   │   │   └── scene.xml
│   │   │   ├── universal_robots_ur5e
│   │   │   │   ├── LICENSE
│   │   │   │   ├── README.md
│   │   │   │   ├── assets
│   │   │   │   │   ├── base_0.obj
│   │   │   │   │   ├── base_1.obj
│   │   │   │   │   ├── forearm_0.obj
│   │   │   │   │   ├── forearm_1.obj
│   │   │   │   │   ├── forearm_2.obj
│   │   │   │   │   ├── forearm_3.obj
│   │   │   │   │   ├── shoulder_0.obj
│   │   │   │   │   ├── shoulder_1.obj
│   │   │   │   │   ├── shoulder_2.obj
│   │   │   │   │   ├── upperarm_0.obj
│   │   │   │   │   ├── upperarm_1.obj
│   │   │   │   │   ├── upperarm_2.obj
│   │   │   │   │   ├── upperarm_3.obj
│   │   │   │   │   ├── wrist1_0.obj
│   │   │   │   │   ├── wrist1_1.obj
│   │   │   │   │   ├── wrist1_2.obj
│   │   │   │   │   ├── wrist2_0.obj
│   │   │   │   │   ├── wrist2_1.obj
│   │   │   │   │   ├── wrist2_2.obj
│   │   │   │   │   └── wrist3.obj
│   │   │   │   ├── scene.xml
│   │   │   │   ├── ur5e.png
│   │   │   │   └── ur5e.xml
│   │   │   └── window.xml
│   │   ├── envs
│   │   │   ├── __init__.py
│   │   │   ├── cube_env.py
│   │   │   ├── env.py
│   │   │   ├── manipspace_env.py
│   │   │   ├── puzzle_env.py
│   │   │   └── scene_env.py
│   │   ├── lie
│   │   │   ├── __init__.py
│   │   │   ├── se3.py
│   │   │   ├── so3.py
│   │   │   └── utils.py
│   │   ├── mjcf_utils.py
│   │   ├── oracles
│   │   │   ├── __init__.py
│   │   │   ├── markov
│   │   │   │   ├── __init__.py
│   │   │   │   ├── button_markov.py
│   │   │   │   ├── cube_markov.py
│   │   │   │   ├── drawer_markov.py
│   │   │   │   ├── markov_oracle.py
│   │   │   │   └── window_markov.py
│   │   │   └── plan
│   │   │       ├── __init__.py
│   │   │       ├── button_plan.py
│   │   │       ├── cube_plan.py
│   │   │       ├── drawer_plan.py
│   │   │       ├── plan_oracle.py
│   │   │       └── window_plan.py
│   │   └── viewer_utils.py
│   ├── online_locomotion
│   │   ├── __init__.py
│   │   ├── ant.py
│   │   ├── ant_ball.py
│   │   ├── assets
│   │   │   ├── ant.xml
│   │   │   └── humanoid.xml
│   │   ├── humanoid.py
│   │   └── wrappers.py
│   ├── powderworld
│   │   ├── __init__.py
│   │   ├── behaviors.py
│   │   ├── powderworld_env.py
│   │   └── sim.py
│   ├── relabel_utils.py
│   └── utils.py
└── pyproject.toml
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | dist/
3 | *.py[cod]
4 | *$py.class
5 | *.egg-info/
6 | .DS_Store
7 | .idea/
8 | .ruff_cache/
9 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Change log
2 |
3 | ## ogbench 1.1.3 (2025-06-03)
4 | - Add the `cube-octuple` task.
5 |
6 | ## ogbench 1.1.2 (2025-03-30)
7 | - Improve compatibility with `gymnasium`.
8 |
9 | ## ogbench 1.1.1 (2025-03-02)
10 | - Make it compatible with the latest version of `gymnasium` (1.1.0).
11 |
12 | ## ogbench 1.1.0 (2025-02-13)
13 | - Add `-singletask` environments for standard (i.e., non-goal-conditioned) offline RL.
14 | - Add `-oraclerep` environments for offline goal-conditioned RL with oracle goal representations.
15 |
16 | ## ogbench 1.0.1 (2024-10-28)
17 | - Fix a bug in the reward function of manipulation tasks.
18 |
19 | ## ogbench 1.0.0 (2024-10-25)
20 | - Initial release.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2024 OGBench Authors
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in
13 | all copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 | THE SOFTWARE.
22 |
--------------------------------------------------------------------------------
/assets/env_teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/assets/env_teaser.png
--------------------------------------------------------------------------------
/assets/ogbench.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/data_gen_scripts/generate_powderworld.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | from collections import defaultdict
3 |
4 | import gymnasium
5 | import numpy as np
6 | from absl import app, flags
7 | from tqdm import trange
8 |
9 | import ogbench.powderworld # noqa
10 | from ogbench.powderworld.behaviors import FillBehavior, LineBehavior, SquareBehavior
11 |
12 | FLAGS = flags.FLAGS
13 |
14 | flags.DEFINE_integer('seed', 0, 'Random seed.')
15 | flags.DEFINE_string('env_name', 'powderworld-v0', 'Environment name.')
16 | flags.DEFINE_string('dataset_type', 'play', 'Dataset type.')
17 | flags.DEFINE_string('save_path', None, 'Save path.')
18 | flags.DEFINE_integer('num_episodes', 1000, 'Number of episodes.')
19 | flags.DEFINE_integer('max_episode_steps', 1001, 'Maximum number of steps in an episode.')
20 | flags.DEFINE_float('p_random_action', 0.5, 'Probability of selecting a random action.')
21 |
22 |
23 | def main(_):
24 | assert FLAGS.dataset_type in ['play']
25 |
26 | # Initialize environment.
27 | env = gymnasium.make(
28 | FLAGS.env_name,
29 | mode='data_collection',
30 | max_episode_steps=FLAGS.max_episode_steps,
31 | )
32 | env.reset()
33 |
34 | # Initialize agents.
35 | agents = [
36 | FillBehavior(env=env),
37 | LineBehavior(env=env),
38 | SquareBehavior(env=env),
39 | ]
40 | probs = np.array([1, 3, 3]) # Agent selection probabilities.
41 | probs = probs / probs.sum()
42 |
43 | # Collect data.
44 | dataset = defaultdict(list)
45 | total_steps = 0
46 | total_train_steps = 0
47 | num_train_episodes = FLAGS.num_episodes
48 | num_val_episodes = FLAGS.num_episodes // 10
49 | for ep_idx in trange(num_train_episodes + num_val_episodes):
50 | ob, info = env.reset()
51 | agent = np.random.choice(agents, p=probs)
52 | agent.reset(ob, info)
53 |
54 | done = False
55 | step = 0
56 |
57 | action_step = 0 # Action cycle counter (0, 1, 2).
58 | while not done:
59 | if action_step == 0:
60 | # Select an action every 3 steps.
61 | if np.random.rand() < FLAGS.p_random_action:
62 | # Sample a random action.
63 | semantic_action = env.unwrapped.sample_semantic_action()
64 | else:
65 | # Get an action from the agent.
66 | semantic_action = agent.select_action(ob, info)
67 | action = env.unwrapped.semantic_action_to_action(*semantic_action)
68 | next_ob, reward, terminated, truncated, info = env.step(action)
69 | done = terminated or truncated
70 |
71 | if agent.done and FLAGS.dataset_type == 'play':
72 | agent = np.random.choice(agents, p=probs)
73 | agent.reset(ob, info)
74 |
75 | dataset['observations'].append(ob)
76 | dataset['actions'].append(action)
77 | dataset['terminals'].append(done)
78 |
79 | ob = next_ob
80 | step += 1
81 | action_step = (action_step + 1) % 3
82 |
83 | total_steps += step
84 | if ep_idx < num_train_episodes:
85 | total_train_steps += step
86 |
87 | print('Total steps:', total_steps)
88 |
89 | train_path = FLAGS.save_path
90 | val_path = FLAGS.save_path.replace('.npz', '-val.npz')
91 | pathlib.Path(train_path).parent.mkdir(parents=True, exist_ok=True)
92 |
93 | # Split the dataset into training and validation sets.
94 | train_dataset = {}
95 | val_dataset = {}
96 | for k, v in dataset.items():
97 | if 'observations' in k and v[0].dtype == np.uint8:
98 | dtype = np.uint8
99 | elif k == 'actions':
100 | dtype = np.int32
101 | elif k == 'terminals':
102 | dtype = bool
103 | else:
104 | dtype = np.float32
105 | train_dataset[k] = np.array(v[:total_train_steps], dtype=dtype)
106 | val_dataset[k] = np.array(v[total_train_steps:], dtype=dtype)
107 |
108 | for path, dataset in [(train_path, train_dataset), (val_path, val_dataset)]:
109 | np.savez_compressed(path, **dataset)
110 |
111 |
112 | if __name__ == '__main__':
113 | app.run(main)
114 |
--------------------------------------------------------------------------------
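
For reference, the files written by `np.savez_compressed` above can be inspected with `numpy.load`. A minimal sketch, assuming the script was run with a hypothetical `--save_path=data/powderworld-play-v0.npz`:

```python
import numpy as np

# Load the training split written by the script above.
# The path is a hypothetical example; use whatever --save_path you passed.
data = np.load('data/powderworld-play-v0.npz')

observations = data['observations']  # uint8 observations
actions = data['actions']            # int32 discrete actions
terminals = data['terminals']        # bool flags marking episode boundaries

print(observations.shape, actions.shape, int(terminals.sum()))  # terminals.sum() == number of episodes
```

The validation split is saved next to it with a `-val.npz` suffix, following the `save_path.replace('.npz', '-val.npz')` convention in `main`.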
/data_gen_scripts/online_env_utils.py:
--------------------------------------------------------------------------------
1 | import gymnasium
2 | from utils.env_utils import EpisodeMonitor
3 |
4 |
5 | def make_online_env(env_name):
6 | """Make online environment.
7 |
8 | If the environment name contains the '-xy' suffix, the environment will be wrapped with a directional locomotion
9 | wrapper. For example, 'online-ant-xy-v0' will return an 'online-ant-v0' environment wrapped with GymXYWrapper.
10 |
11 | Args:
12 | env_name: Name of the environment.
13 | """
14 | import ogbench.online_locomotion # noqa
15 |
16 | # Manually recognize the '-xy' suffix, which indicates that the environment should be wrapped with a directional
17 | # locomotion wrapper.
18 | if '-xy' in env_name:
19 | env_name = env_name.replace('-xy', '')
20 | apply_xy_wrapper = True
21 | else:
22 | apply_xy_wrapper = False
23 |
24 | # Set camera.
25 | if 'humanoid' in env_name:
26 | extra_kwargs = dict(camera_id=0)
27 | else:
28 | extra_kwargs = dict()
29 |
30 | # Make environment.
31 | env = gymnasium.make(env_name, render_mode='rgb_array', height=200, width=200, **extra_kwargs)
32 |
33 | if apply_xy_wrapper:
34 | # Apply the directional locomotion wrapper.
35 | from ogbench.online_locomotion.wrappers import DMCHumanoidXYWrapper, GymXYWrapper
36 |
37 | if 'humanoid' in env_name:
38 | env = DMCHumanoidXYWrapper(env, resample_interval=200)
39 | else:
40 | env = GymXYWrapper(env, resample_interval=100)
41 |
42 | env = EpisodeMonitor(env)
43 |
44 | return env
45 |
--------------------------------------------------------------------------------
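
A minimal usage sketch of `make_online_env`, using the `'online-ant-xy-v0'` example from the docstring above (it assumes the code runs from `data_gen_scripts/` so that `online_env_utils` is importable):

```python
from online_env_utils import make_online_env

# 'online-ant-xy-v0' has the '-xy' suffix stripped, so the underlying
# 'online-ant-v0' environment is built and wrapped with GymXYWrapper.
env = make_online_env('online-ant-xy-v0')

ob, info = env.reset(seed=0)
ob, reward, terminated, truncated, info = env.step(env.action_space.sample())
```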
/data_gen_scripts/viz_utils.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | import numpy as np
3 | from matplotlib import figure
4 | from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
5 |
6 |
7 | def get_2d_colors(points, min_point, max_point):
8 | """Get colors corresponding to 2-D points."""
9 | points = np.array(points)
10 | min_point = np.array(min_point)
11 | max_point = np.array(max_point)
12 |
13 | colors = (points - min_point) / (max_point - min_point)
14 | colors = np.hstack((colors, (2 - np.sum(colors, axis=1, keepdims=True)) / 2))
15 | colors = np.clip(colors, 0, 1)
16 | colors = np.c_[colors, np.full(len(colors), 0.8)]
17 |
18 | return colors
19 |
20 |
21 | def visualize_trajs(env_name, trajs):
22 | """Visualize x-y trajectories in locomotion environments.
23 |
24 | It reads 'xy' and 'direction' from the 'info' field of the trajectories.
25 | """
26 | matplotlib.use('Agg')
27 |
28 | fig = figure.Figure(tight_layout=True)
29 | canvas = FigureCanvas(fig)
30 | if 'xy' in trajs[0]['info'][0]:
31 | ax = fig.add_subplot()
32 |
33 | max_xy = 0.0
34 | for traj in trajs:
35 | xy = np.array([info['xy'] for info in traj['info']])
36 | direction = np.array([info['direction'] for info in traj['info']])
37 | color = get_2d_colors(direction, [-1, -1], [1, 1])
38 | for i in range(len(xy) - 1):
39 | ax.plot(xy[i : i + 2, 0], xy[i : i + 2, 1], color=color[i], linewidth=0.7)
40 | max_xy = max(max_xy, np.abs(xy).max() * 1.2)
41 |
42 | plot_axis = [-max_xy, max_xy, -max_xy, max_xy]
43 | ax.axis(plot_axis)
44 | ax.set_aspect('equal')
45 | else:
46 | return None
47 |
48 | fig.tight_layout()
49 | canvas.draw()
50 | out_image = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')
51 | out_image = out_image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
52 | return out_image
53 |
--------------------------------------------------------------------------------
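
A small sketch of `get_2d_colors`, which is how `visualize_trajs` turns per-step direction vectors into trajectory colors:

```python
import numpy as np
from viz_utils import get_2d_colors

# Map four unit directions in [-1, 1]^2 to RGBA colors.
directions = np.array([[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0]])
colors = get_2d_colors(directions, min_point=[-1, -1], max_point=[1, 1])
print(colors.shape)  # (4, 4): RGB derived from the normalized position, alpha fixed at 0.8
```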
/impls/agents/__init__.py:
--------------------------------------------------------------------------------
1 | from agents.crl import CRLAgent
2 | from agents.gcbc import GCBCAgent
3 | from agents.gciql import GCIQLAgent
4 | from agents.gcivl import GCIVLAgent
5 | from agents.hiql import HIQLAgent
6 | from agents.qrl import QRLAgent
7 | from agents.sac import SACAgent
8 |
9 | agents = dict(
10 | crl=CRLAgent,
11 | gcbc=GCBCAgent,
12 | gciql=GCIQLAgent,
13 | gcivl=GCIVLAgent,
14 | hiql=HIQLAgent,
15 | qrl=QRLAgent,
16 | sac=SACAgent,
17 | )
18 |
--------------------------------------------------------------------------------
/impls/agents/gcbc.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import flax
4 | import jax
5 | import jax.numpy as jnp
6 | import ml_collections
7 | import optax
8 | from utils.encoders import GCEncoder, encoder_modules
9 | from utils.flax_utils import ModuleDict, TrainState, nonpytree_field
10 | from utils.networks import GCActor, GCDiscreteActor
11 |
12 |
13 | class GCBCAgent(flax.struct.PyTreeNode):
14 | """Goal-conditioned behavioral cloning (GCBC) agent."""
15 |
16 | rng: Any
17 | network: Any
18 | config: Any = nonpytree_field()
19 |
20 | def actor_loss(self, batch, grad_params, rng=None):
21 | """Compute the BC actor loss."""
22 | dist = self.network.select('actor')(batch['observations'], batch['actor_goals'], params=grad_params)
23 | log_prob = dist.log_prob(batch['actions'])
24 |
25 | actor_loss = -log_prob.mean()
26 |
27 | actor_info = {
28 | 'actor_loss': actor_loss,
29 | 'bc_log_prob': log_prob.mean(),
30 | }
31 | if not self.config['discrete']:
32 | actor_info.update(
33 | {
34 | 'mse': jnp.mean((dist.mode() - batch['actions']) ** 2),
35 | 'std': jnp.mean(dist.scale_diag),
36 | }
37 | )
38 |
39 | return actor_loss, actor_info
40 |
41 | @jax.jit
42 | def total_loss(self, batch, grad_params, rng=None):
43 | """Compute the total loss."""
44 | info = {}
45 | rng = rng if rng is not None else self.rng
46 |
47 | rng, actor_rng = jax.random.split(rng)
48 | actor_loss, actor_info = self.actor_loss(batch, grad_params, actor_rng)
49 | for k, v in actor_info.items():
50 | info[f'actor/{k}'] = v
51 |
52 | loss = actor_loss
53 | return loss, info
54 |
55 | @jax.jit
56 | def update(self, batch):
57 | """Update the agent and return a new agent with information dictionary."""
58 | new_rng, rng = jax.random.split(self.rng)
59 |
60 | def loss_fn(grad_params):
61 | return self.total_loss(batch, grad_params, rng=rng)
62 |
63 | new_network, info = self.network.apply_loss_fn(loss_fn=loss_fn)
64 |
65 | return self.replace(network=new_network, rng=new_rng), info
66 |
67 | @jax.jit
68 | def sample_actions(
69 | self,
70 | observations,
71 | goals=None,
72 | seed=None,
73 | temperature=1.0,
74 | ):
75 | """Sample actions from the actor."""
76 | dist = self.network.select('actor')(observations, goals, temperature=temperature)
77 | actions = dist.sample(seed=seed)
78 | if not self.config['discrete']:
79 | actions = jnp.clip(actions, -1, 1)
80 | return actions
81 |
82 | @classmethod
83 | def create(
84 | cls,
85 | seed,
86 | ex_observations,
87 | ex_actions,
88 | config,
89 | ):
90 | """Create a new agent.
91 |
92 | Args:
93 | seed: Random seed.
94 | ex_observations: Example batch of observations.
95 | ex_actions: Example batch of actions. In discrete-action MDPs, this should contain the maximum action value.
96 | config: Configuration dictionary.
97 | """
98 | rng = jax.random.PRNGKey(seed)
99 | rng, init_rng = jax.random.split(rng, 2)
100 |
101 | ex_goals = ex_observations
102 | if config['discrete']:
103 | action_dim = ex_actions.max() + 1
104 | else:
105 | action_dim = ex_actions.shape[-1]
106 |
107 | # Define encoder.
108 | encoders = dict()
109 | if config['encoder'] is not None:
110 | encoder_module = encoder_modules[config['encoder']]
111 | encoders['actor'] = GCEncoder(concat_encoder=encoder_module())
112 |
113 | # Define actor network.
114 | if config['discrete']:
115 | actor_def = GCDiscreteActor(
116 | hidden_dims=config['actor_hidden_dims'],
117 | action_dim=action_dim,
118 | gc_encoder=encoders.get('actor'),
119 | )
120 | else:
121 | actor_def = GCActor(
122 | hidden_dims=config['actor_hidden_dims'],
123 | action_dim=action_dim,
124 | state_dependent_std=False,
125 | const_std=config['const_std'],
126 | gc_encoder=encoders.get('actor'),
127 | )
128 |
129 | network_info = dict(
130 | actor=(actor_def, (ex_observations, ex_goals)),
131 | )
132 | networks = {k: v[0] for k, v in network_info.items()}
133 | network_args = {k: v[1] for k, v in network_info.items()}
134 |
135 | network_def = ModuleDict(networks)
136 | network_tx = optax.adam(learning_rate=config['lr'])
137 | network_params = network_def.init(init_rng, **network_args)['params']
138 | network = TrainState.create(network_def, network_params, tx=network_tx)
139 |
140 | return cls(rng, network=network, config=flax.core.FrozenDict(**config))
141 |
142 |
143 | def get_config():
144 | config = ml_collections.ConfigDict(
145 | dict(
146 | # Agent hyperparameters.
147 | agent_name='gcbc', # Agent name.
148 | lr=3e-4, # Learning rate.
149 | batch_size=1024, # Batch size.
150 | actor_hidden_dims=(512, 512, 512), # Actor network hidden dimensions.
151 | discount=0.99, # Discount factor (unused by default; can be used for geometric goal sampling in GCDataset).
152 | const_std=True, # Whether to use constant standard deviation for the actor.
153 | discrete=False, # Whether the action space is discrete.
154 | encoder=ml_collections.config_dict.placeholder(str), # Visual encoder name (None, 'impala_small', etc.).
155 | # Dataset hyperparameters.
156 | dataset_class='GCDataset', # Dataset class name.
157 | value_p_curgoal=0.0, # Unused (defined for compatibility with GCDataset).
158 | value_p_trajgoal=1.0, # Unused (defined for compatibility with GCDataset).
159 | value_p_randomgoal=0.0, # Unused (defined for compatibility with GCDataset).
160 | value_geom_sample=False, # Unused (defined for compatibility with GCDataset).
161 | actor_p_curgoal=0.0, # Probability of using the current state as the actor goal.
162 | actor_p_trajgoal=1.0, # Probability of using a future state in the same trajectory as the actor goal.
163 | actor_p_randomgoal=0.0, # Probability of using a random state as the actor goal.
164 | actor_geom_sample=False, # Whether to use geometric sampling for future actor goals.
165 | gc_negative=True, # Unused (defined for compatibility with GCDataset).
166 | p_aug=0.0, # Probability of applying image augmentation.
167 | frame_stack=ml_collections.config_dict.placeholder(int), # Number of frames to stack.
168 | )
169 | )
170 | return config
171 |
--------------------------------------------------------------------------------
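
A minimal end-to-end sketch of the `GCBCAgent` API (`create`, `update`, `sample_actions`). The array shapes are arbitrary placeholders standing in for a real OGBench batch, the batch keys follow `actor_loss` above, and it assumes running from `impls/` so the `agents` and `utils` modules resolve:

```python
import jax
import numpy as np
from agents.gcbc import GCBCAgent, get_config

config = get_config()

# Example arrays only fix shapes and dtypes; the values are arbitrary.
ex_observations = np.random.randn(1, 29).astype(np.float32)
ex_actions = np.random.uniform(-1, 1, size=(1, 8)).astype(np.float32)
agent = GCBCAgent.create(0, ex_observations, ex_actions, config)

batch = {
    'observations': np.random.randn(256, 29).astype(np.float32),
    'actor_goals': np.random.randn(256, 29).astype(np.float32),
    'actions': np.random.uniform(-1, 1, size=(256, 8)).astype(np.float32),
}
agent, info = agent.update(batch)  # One gradient step; info contains 'actor/actor_loss', etc.
actions = agent.sample_actions(batch['observations'], goals=batch['actor_goals'], seed=jax.random.PRNGKey(0))
```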
/impls/main.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import random
4 | import time
5 | from collections import defaultdict
6 |
7 | import jax
8 | import numpy as np
9 | import tqdm
10 | import wandb
11 | from absl import app, flags
12 | from agents import agents
13 | from ml_collections import config_flags
14 | from utils.datasets import Dataset, GCDataset, HGCDataset
15 | from utils.env_utils import make_env_and_datasets
16 | from utils.evaluation import evaluate
17 | from utils.flax_utils import restore_agent, save_agent
18 | from utils.log_utils import CsvLogger, get_exp_name, get_flag_dict, get_wandb_video, setup_wandb
19 |
20 | FLAGS = flags.FLAGS
21 |
22 | flags.DEFINE_string('run_group', 'Debug', 'Run group.')
23 | flags.DEFINE_integer('seed', 0, 'Random seed.')
24 | flags.DEFINE_string('env_name', 'antmaze-large-navigate-v0', 'Environment (dataset) name.')
25 | flags.DEFINE_string('save_dir', 'exp/', 'Save directory.')
26 | flags.DEFINE_string('restore_path', None, 'Restore path.')
27 | flags.DEFINE_integer('restore_epoch', None, 'Restore epoch.')
28 |
29 | flags.DEFINE_integer('train_steps', 1000000, 'Number of training steps.')
30 | flags.DEFINE_integer('log_interval', 5000, 'Logging interval.')
31 | flags.DEFINE_integer('eval_interval', 100000, 'Evaluation interval.')
32 | flags.DEFINE_integer('save_interval', 1000000, 'Saving interval.')
33 |
34 | flags.DEFINE_integer('eval_tasks', None, 'Number of tasks to evaluate (None for all).')
35 | flags.DEFINE_integer('eval_episodes', 20, 'Number of episodes for each task.')
36 | flags.DEFINE_float('eval_temperature', 0, 'Actor temperature for evaluation.')
37 | flags.DEFINE_float('eval_gaussian', None, 'Action Gaussian noise for evaluation.')
38 | flags.DEFINE_integer('video_episodes', 1, 'Number of video episodes for each task.')
39 | flags.DEFINE_integer('video_frame_skip', 3, 'Frame skip for videos.')
40 | flags.DEFINE_integer('eval_on_cpu', 1, 'Whether to evaluate on CPU.')
41 |
42 | config_flags.DEFINE_config_file('agent', 'agents/gciql.py', lock_config=False)
43 |
44 |
45 | def main(_):
46 | # Set up logger.
47 | exp_name = get_exp_name(FLAGS.seed)
48 | setup_wandb(project='OGBench', group=FLAGS.run_group, name=exp_name)
49 |
50 | FLAGS.save_dir = os.path.join(FLAGS.save_dir, wandb.run.project, FLAGS.run_group, exp_name)
51 | os.makedirs(FLAGS.save_dir, exist_ok=True)
52 | flag_dict = get_flag_dict()
53 | with open(os.path.join(FLAGS.save_dir, 'flags.json'), 'w') as f:
54 | json.dump(flag_dict, f)
55 |
56 | # Set up environment and dataset.
57 | config = FLAGS.agent
58 | env, train_dataset, val_dataset = make_env_and_datasets(FLAGS.env_name, frame_stack=config['frame_stack'])
59 |
60 | dataset_class = {
61 | 'GCDataset': GCDataset,
62 | 'HGCDataset': HGCDataset,
63 | }[config['dataset_class']]
64 | train_dataset = dataset_class(Dataset.create(**train_dataset), config)
65 | if val_dataset is not None:
66 | val_dataset = dataset_class(Dataset.create(**val_dataset), config)
67 |
68 | # Initialize agent.
69 | random.seed(FLAGS.seed)
70 | np.random.seed(FLAGS.seed)
71 |
72 | example_batch = train_dataset.sample(1)
73 | if config['discrete']:
74 | # Fill with the maximum action to let the agent know the action space size.
75 | example_batch['actions'] = np.full_like(example_batch['actions'], env.action_space.n - 1)
76 |
77 | agent_class = agents[config['agent_name']]
78 | agent = agent_class.create(
79 | FLAGS.seed,
80 | example_batch['observations'],
81 | example_batch['actions'],
82 | config,
83 | )
84 |
85 | # Restore agent.
86 | if FLAGS.restore_path is not None:
87 | agent = restore_agent(agent, FLAGS.restore_path, FLAGS.restore_epoch)
88 |
89 | # Train agent.
90 | train_logger = CsvLogger(os.path.join(FLAGS.save_dir, 'train.csv'))
91 | eval_logger = CsvLogger(os.path.join(FLAGS.save_dir, 'eval.csv'))
92 | first_time = time.time()
93 | last_time = time.time()
94 | for i in tqdm.tqdm(range(1, FLAGS.train_steps + 1), smoothing=0.1, dynamic_ncols=True):
95 | # Update agent.
96 | batch = train_dataset.sample(config['batch_size'])
97 | agent, update_info = agent.update(batch)
98 |
99 | # Log metrics.
100 | if i % FLAGS.log_interval == 0:
101 | train_metrics = {f'training/{k}': v for k, v in update_info.items()}
102 | if val_dataset is not None:
103 | val_batch = val_dataset.sample(config['batch_size'])
104 | _, val_info = agent.total_loss(val_batch, grad_params=None)
105 | train_metrics.update({f'validation/{k}': v for k, v in val_info.items()})
106 | train_metrics['time/epoch_time'] = (time.time() - last_time) / FLAGS.log_interval
107 | train_metrics['time/total_time'] = time.time() - first_time
108 | last_time = time.time()
109 | wandb.log(train_metrics, step=i)
110 | train_logger.log(train_metrics, step=i)
111 |
112 | # Evaluate agent.
113 | if i == 1 or i % FLAGS.eval_interval == 0:
114 | if FLAGS.eval_on_cpu:
115 | eval_agent = jax.device_put(agent, device=jax.devices('cpu')[0])
116 | else:
117 | eval_agent = agent
118 | renders = []
119 | eval_metrics = {}
120 | overall_metrics = defaultdict(list)
121 | task_infos = env.unwrapped.task_infos if hasattr(env.unwrapped, 'task_infos') else env.task_infos
122 | num_tasks = FLAGS.eval_tasks if FLAGS.eval_tasks is not None else len(task_infos)
123 | for task_id in tqdm.trange(1, num_tasks + 1):
124 | task_name = task_infos[task_id - 1]['task_name']
125 | eval_info, trajs, cur_renders = evaluate(
126 | agent=eval_agent,
127 | env=env,
128 | task_id=task_id,
129 | config=config,
130 | num_eval_episodes=FLAGS.eval_episodes,
131 | num_video_episodes=FLAGS.video_episodes,
132 | video_frame_skip=FLAGS.video_frame_skip,
133 | eval_temperature=FLAGS.eval_temperature,
134 | eval_gaussian=FLAGS.eval_gaussian,
135 | )
136 | renders.extend(cur_renders)
137 | metric_names = ['success']
138 | eval_metrics.update(
139 | {f'evaluation/{task_name}_{k}': v for k, v in eval_info.items() if k in metric_names}
140 | )
141 | for k, v in eval_info.items():
142 | if k in metric_names:
143 | overall_metrics[k].append(v)
144 | for k, v in overall_metrics.items():
145 | eval_metrics[f'evaluation/overall_{k}'] = np.mean(v)
146 |
147 | if FLAGS.video_episodes > 0:
148 | video = get_wandb_video(renders=renders, n_cols=num_tasks)
149 | eval_metrics['video'] = video
150 |
151 | wandb.log(eval_metrics, step=i)
152 | eval_logger.log(eval_metrics, step=i)
153 |
154 | # Save agent.
155 | if i % FLAGS.save_interval == 0:
156 | save_agent(agent, FLAGS.save_dir, i)
157 |
158 | train_logger.close()
159 | eval_logger.close()
160 |
161 |
162 | if __name__ == '__main__':
163 | app.run(main)
164 |
--------------------------------------------------------------------------------
/impls/requirements.txt:
--------------------------------------------------------------------------------
1 | ogbench # Use the PyPI version of OGBench. Replace this line with `-e .` to use the local version instead.
2 | jax[cuda12] >= 0.4.26
3 | flax >= 0.8.4
4 | distrax >= 0.1.5
5 | ml_collections
6 | matplotlib
7 | moviepy
8 | wandb
--------------------------------------------------------------------------------
/impls/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/impls/utils/__init__.py
--------------------------------------------------------------------------------
/impls/utils/encoders.py:
--------------------------------------------------------------------------------
1 | import functools
2 | from typing import Sequence
3 |
4 | import flax.linen as nn
5 | import jax.numpy as jnp
6 |
7 | from utils.networks import MLP
8 |
9 |
10 | class ResnetStack(nn.Module):
11 | """ResNet stack module."""
12 |
13 | num_features: int
14 | num_blocks: int
15 | max_pooling: bool = True
16 |
17 | @nn.compact
18 | def __call__(self, x):
19 | initializer = nn.initializers.xavier_uniform()
20 | conv_out = nn.Conv(
21 | features=self.num_features,
22 | kernel_size=(3, 3),
23 | strides=1,
24 | kernel_init=initializer,
25 | padding='SAME',
26 | )(x)
27 |
28 | if self.max_pooling:
29 | conv_out = nn.max_pool(
30 | conv_out,
31 | window_shape=(3, 3),
32 | padding='SAME',
33 | strides=(2, 2),
34 | )
35 |
36 | for _ in range(self.num_blocks):
37 | block_input = conv_out
38 | conv_out = nn.relu(conv_out)
39 | conv_out = nn.Conv(
40 | features=self.num_features,
41 | kernel_size=(3, 3),
42 | strides=1,
43 | padding='SAME',
44 | kernel_init=initializer,
45 | )(conv_out)
46 |
47 | conv_out = nn.relu(conv_out)
48 | conv_out = nn.Conv(
49 | features=self.num_features,
50 | kernel_size=(3, 3),
51 | strides=1,
52 | padding='SAME',
53 | kernel_init=initializer,
54 | )(conv_out)
55 | conv_out += block_input
56 |
57 | return conv_out
58 |
59 |
60 | class ImpalaEncoder(nn.Module):
61 | """IMPALA encoder."""
62 |
63 | width: int = 1
64 | stack_sizes: tuple = (16, 32, 32)
65 | num_blocks: int = 2
66 | dropout_rate: float = None
67 | mlp_hidden_dims: Sequence[int] = (512,)
68 | layer_norm: bool = False
69 |
70 | def setup(self):
71 | stack_sizes = self.stack_sizes
72 | self.stack_blocks = [
73 | ResnetStack(
74 | num_features=stack_sizes[i] * self.width,
75 | num_blocks=self.num_blocks,
76 | )
77 | for i in range(len(stack_sizes))
78 | ]
79 | if self.dropout_rate is not None:
80 | self.dropout = nn.Dropout(rate=self.dropout_rate)
81 |
82 | @nn.compact
83 | def __call__(self, x, train=True, cond_var=None):
84 | x = x.astype(jnp.float32) / 255.0
85 |
86 | conv_out = x
87 |
88 | for idx in range(len(self.stack_blocks)):
89 | conv_out = self.stack_blocks[idx](conv_out)
90 | if self.dropout_rate is not None:
91 | conv_out = self.dropout(conv_out, deterministic=not train)
92 |
93 | conv_out = nn.relu(conv_out)
94 | if self.layer_norm:
95 | conv_out = nn.LayerNorm()(conv_out)
96 | out = conv_out.reshape((*x.shape[:-3], -1))
97 |
98 | out = MLP(self.mlp_hidden_dims, activate_final=True, layer_norm=self.layer_norm)(out)
99 |
100 | return out
101 |
102 |
103 | class GCEncoder(nn.Module):
104 | """Helper module to handle inputs to goal-conditioned networks.
105 |
106 | It takes in observations (s) and goals (g) and returns the concatenation of `state_encoder(s)`, `goal_encoder(g)`,
107 | and `concat_encoder([s, g])`. It ignores the encoders that are not provided. This way, the module can handle both
108 | early and late fusion (or their variants) of state and goal information.
109 | """
110 |
111 | state_encoder: nn.Module = None
112 | goal_encoder: nn.Module = None
113 | concat_encoder: nn.Module = None
114 |
115 | @nn.compact
116 | def __call__(self, observations, goals=None, goal_encoded=False):
117 | """Returns the representations of observations and goals.
118 |
119 | If `goal_encoded` is True, `goals` is assumed to be already encoded representations. In this case, either
120 | `goal_encoder` or `concat_encoder` must be None.
121 | """
122 | reps = []
123 | if self.state_encoder is not None:
124 | reps.append(self.state_encoder(observations))
125 | if goals is not None:
126 | if goal_encoded:
127 | # Can't have both goal_encoder and concat_encoder in this case.
128 | assert self.goal_encoder is None or self.concat_encoder is None
129 | reps.append(goals)
130 | else:
131 | if self.goal_encoder is not None:
132 | reps.append(self.goal_encoder(goals))
133 | if self.concat_encoder is not None:
134 | reps.append(self.concat_encoder(jnp.concatenate([observations, goals], axis=-1)))
135 | reps = jnp.concatenate(reps, axis=-1)
136 | return reps
137 |
138 |
139 | encoder_modules = {
140 | 'impala': ImpalaEncoder,
141 | 'impala_debug': functools.partial(ImpalaEncoder, num_blocks=1, stack_sizes=(4, 4)),
142 | 'impala_small': functools.partial(ImpalaEncoder, num_blocks=1),
143 | 'impala_large': functools.partial(ImpalaEncoder, stack_sizes=(64, 128, 128), mlp_hidden_dims=(1024,)),
144 | }
145 |
--------------------------------------------------------------------------------
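
A short sketch of the fusion modes `GCEncoder` supports, following its docstring: late fusion encodes states and goals with separate encoders, while early fusion runs a single encoder on their channel-wise concatenation (the pattern `GCBCAgent.create` uses):

```python
# Late fusion: separate visual encoders for observations and goals.
late_fusion_encoder = GCEncoder(
    state_encoder=encoder_modules['impala_small'](),
    goal_encoder=encoder_modules['impala_small'](),
)

# Early fusion: one encoder over the channel-wise concatenation [s, g].
early_fusion_encoder = GCEncoder(concat_encoder=encoder_modules['impala_small']())
```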
/impls/utils/env_utils.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import os
3 | import platform
4 | import time
5 |
6 | import gymnasium
7 | import numpy as np
8 | from gymnasium.spaces import Box
9 |
10 | import ogbench
11 | from utils.datasets import Dataset
12 |
13 |
14 | class EpisodeMonitor(gymnasium.Wrapper):
15 | """Environment wrapper to monitor episode statistics."""
16 |
17 | def __init__(self, env):
18 | super().__init__(env)
19 | self._reset_stats()
20 | self.total_timesteps = 0
21 |
22 | def _reset_stats(self):
23 | self.reward_sum = 0.0
24 | self.episode_length = 0
25 | self.start_time = time.time()
26 |
27 | def step(self, action):
28 | observation, reward, terminated, truncated, info = self.env.step(action)
29 |
30 | self.reward_sum += reward
31 | self.episode_length += 1
32 | self.total_timesteps += 1
33 | info['total'] = {'timesteps': self.total_timesteps}
34 |
35 | if terminated or truncated:
36 | info['episode'] = {}
37 | info['episode']['return'] = self.reward_sum
38 | info['episode']['length'] = self.episode_length
39 | info['episode']['duration'] = time.time() - self.start_time
40 |
41 | return observation, reward, terminated, truncated, info
42 |
43 | def reset(self, *args, **kwargs):
44 | self._reset_stats()
45 | return self.env.reset(*args, **kwargs)
46 |
47 |
48 | class FrameStackWrapper(gymnasium.Wrapper):
49 | """Environment wrapper to stack observations."""
50 |
51 | def __init__(self, env, num_stack):
52 | super().__init__(env)
53 |
54 | self.num_stack = num_stack
55 | self.frames = collections.deque(maxlen=num_stack)
56 |
57 | low = np.concatenate([self.observation_space.low] * num_stack, axis=-1)
58 | high = np.concatenate([self.observation_space.high] * num_stack, axis=-1)
59 | self.observation_space = Box(low=low, high=high, dtype=self.observation_space.dtype)
60 |
61 | def get_observation(self):
62 | assert len(self.frames) == self.num_stack
63 | return np.concatenate(list(self.frames), axis=-1)
64 |
65 | def reset(self, **kwargs):
66 | ob, info = self.env.reset(**kwargs)
67 | for _ in range(self.num_stack):
68 | self.frames.append(ob)
69 | if 'goal' in info:
70 | info['goal'] = np.concatenate([info['goal']] * self.num_stack, axis=-1)
71 | return self.get_observation(), info
72 |
73 | def step(self, action):
74 | ob, reward, terminated, truncated, info = self.env.step(action)
75 | self.frames.append(ob)
76 | return self.get_observation(), reward, terminated, truncated, info
77 |
78 |
79 | def make_env_and_datasets(dataset_name, frame_stack=None):
80 | """Make OGBench environment and datasets.
81 |
82 | Args:
83 | dataset_name: Name of the dataset.
84 | frame_stack: Number of frames to stack.
85 |
86 | Returns:
87 | A tuple of the environment, training dataset, and validation dataset.
88 | """
89 | # Use compact dataset to save memory.
90 | env, train_dataset, val_dataset = ogbench.make_env_and_datasets(dataset_name, compact_dataset=True)
91 | train_dataset = Dataset.create(**train_dataset)
92 | val_dataset = Dataset.create(**val_dataset)
93 |
94 | if frame_stack is not None:
95 | env = FrameStackWrapper(env, frame_stack)
96 |
97 | env.reset()
98 |
99 | return env, train_dataset, val_dataset
100 |
--------------------------------------------------------------------------------
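
A minimal usage sketch of `make_env_and_datasets` (the dataset name is the default from `impls/main.py`; note that `frame_stack` only wraps the environment, while dataset-side frame stacking is configured through the dataset hyperparameters, cf. `frame_stack` in `agents/gcbc.py`):

```python
from utils.env_utils import make_env_and_datasets

env, train_dataset, val_dataset = make_env_and_datasets('antmaze-large-navigate-v0')
batch = train_dataset.sample(32)  # Dict of arrays with batch dimension 32.
```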
/impls/utils/evaluation.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 |
3 | import jax
4 | import numpy as np
5 | from tqdm import trange
6 |
7 |
8 | def supply_rng(f, rng=jax.random.PRNGKey(0)):
9 | """Helper function to split the random number generator key before each call to the function."""
10 |
11 | def wrapped(*args, **kwargs):
12 | nonlocal rng
13 | rng, key = jax.random.split(rng)
14 | return f(*args, seed=key, **kwargs)
15 |
16 | return wrapped
17 |
18 |
19 | def flatten(d, parent_key='', sep='.'):
20 | """Flatten a dictionary."""
21 | items = []
22 | for k, v in d.items():
23 | new_key = parent_key + sep + k if parent_key else k
24 | if hasattr(v, 'items'):
25 | items.extend(flatten(v, new_key, sep=sep).items())
26 | else:
27 | items.append((new_key, v))
28 | return dict(items)
29 |
30 |
31 | def add_to(dict_of_lists, single_dict):
32 | """Append values to the corresponding lists in the dictionary."""
33 | for k, v in single_dict.items():
34 | dict_of_lists[k].append(v)
35 |
36 |
37 | def evaluate(
38 | agent,
39 | env,
40 | task_id=None,
41 | config=None,
42 | num_eval_episodes=50,
43 | num_video_episodes=0,
44 | video_frame_skip=3,
45 | eval_temperature=0,
46 | eval_gaussian=None,
47 | ):
48 | """Evaluate the agent in the environment.
49 |
50 | Args:
51 | agent: Agent.
52 | env: Environment.
53 | task_id: Task ID to be passed to the environment.
54 | config: Configuration dictionary.
55 | num_eval_episodes: Number of episodes to evaluate the agent.
56 | num_video_episodes: Number of episodes to render. These episodes are not included in the statistics.
57 | video_frame_skip: Number of frames to skip between renders.
58 | eval_temperature: Action sampling temperature.
59 | eval_gaussian: Standard deviation of the Gaussian noise to add to the actions.
60 |
61 | Returns:
62 | A tuple containing the statistics, trajectories, and rendered videos.
63 | """
64 | actor_fn = supply_rng(agent.sample_actions, rng=jax.random.PRNGKey(np.random.randint(0, 2**32)))
65 | trajs = []
66 | stats = defaultdict(list)
67 |
68 | renders = []
69 | for i in trange(num_eval_episodes + num_video_episodes):
70 | traj = defaultdict(list)
71 | should_render = i >= num_eval_episodes
72 |
73 | observation, info = env.reset(options=dict(task_id=task_id, render_goal=should_render))
74 | goal = info.get('goal')
75 | goal_frame = info.get('goal_rendered')
76 | done = False
77 | step = 0
78 | render = []
79 | while not done:
80 | action = actor_fn(observations=observation, goals=goal, temperature=eval_temperature)
81 | action = np.array(action)
82 | if not config.get('discrete'):
83 | if eval_gaussian is not None:
84 | action = np.random.normal(action, eval_gaussian)
85 | action = np.clip(action, -1, 1)
86 |
87 | next_observation, reward, terminated, truncated, info = env.step(action)
88 | done = terminated or truncated
89 | step += 1
90 |
91 | if should_render and (step % video_frame_skip == 0 or done):
92 | frame = env.render().copy()
93 | if goal_frame is not None:
94 | render.append(np.concatenate([goal_frame, frame], axis=0))
95 | else:
96 | render.append(frame)
97 |
98 | transition = dict(
99 | observation=observation,
100 | next_observation=next_observation,
101 | action=action,
102 | reward=reward,
103 | done=done,
104 | info=info,
105 | )
106 | add_to(traj, transition)
107 | observation = next_observation
108 | if i < num_eval_episodes:
109 | add_to(stats, flatten(info))
110 | trajs.append(traj)
111 | else:
112 | renders.append(np.array(render))
113 |
114 | for k, v in stats.items():
115 | stats[k] = np.mean(v)
116 |
117 | return stats, trajs, renders
118 |
--------------------------------------------------------------------------------
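
A small sketch of `supply_rng`, which hides JAX's explicit PRNG-key threading behind a stateful closure so that `evaluate` can call `agent.sample_actions` like an ordinary stochastic function:

```python
import jax
import jax.numpy as jnp
from utils.evaluation import supply_rng

def sample(x, seed=None):
    # Toy stand-in for agent.sample_actions: one Gaussian draw per call.
    return jax.random.normal(seed, shape=x.shape)

sample_fn = supply_rng(sample, rng=jax.random.PRNGKey(0))
a1 = sample_fn(jnp.zeros(4))  # Each call consumes a fresh split of the key,
a2 = sample_fn(jnp.zeros(4))  # so repeated calls yield different samples.
```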
/impls/utils/log_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 | from datetime import datetime
4 |
5 | import absl.flags as flags
6 | import ml_collections
7 | import numpy as np
8 | import wandb
9 | from PIL import Image, ImageEnhance
10 |
11 |
12 | class CsvLogger:
13 | """CSV logger for logging metrics to a CSV file."""
14 |
15 | def __init__(self, path):
16 | self.path = path
17 | self.header = None
18 | self.file = None
19 | self.disallowed_types = (wandb.Image, wandb.Video, wandb.Histogram)
20 |
21 | def log(self, row, step):
22 | row['step'] = step
23 | if self.file is None:
24 | self.file = open(self.path, 'w')
25 | if self.header is None:
26 | self.header = [k for k, v in row.items() if not isinstance(v, self.disallowed_types)]
27 | self.file.write(','.join(self.header) + '\n')
28 | filtered_row = {k: v for k, v in row.items() if not isinstance(v, self.disallowed_types)}
29 | self.file.write(','.join([str(filtered_row.get(k, '')) for k in self.header]) + '\n')
30 | else:
31 | filtered_row = {k: v for k, v in row.items() if not isinstance(v, self.disallowed_types)}
32 | self.file.write(','.join([str(filtered_row.get(k, '')) for k in self.header]) + '\n')
33 | self.file.flush()
34 |
35 | def close(self):
36 | if self.file is not None:
37 | self.file.close()
38 |
39 |
40 | def get_exp_name(seed):
41 | """Return the experiment name."""
42 | exp_name = ''
43 | exp_name += f'sd{seed:03d}_'
44 | if 'SLURM_JOB_ID' in os.environ:
45 | exp_name += f's_{os.environ["SLURM_JOB_ID"]}.'
46 | if 'SLURM_PROCID' in os.environ:
47 | exp_name += f'{os.environ["SLURM_PROCID"]}.'
48 | exp_name += f'{datetime.now().strftime("%Y%m%d_%H%M%S")}'
49 |
50 | return exp_name
51 |
52 |
53 | def get_flag_dict():
54 | """Return the dictionary of flags."""
55 | flag_dict = {k: getattr(flags.FLAGS, k) for k in flags.FLAGS if '.' not in k}
56 | for k in flag_dict:
57 | if isinstance(flag_dict[k], ml_collections.ConfigDict):
58 | flag_dict[k] = flag_dict[k].to_dict()
59 | return flag_dict
60 |
61 |
62 | def setup_wandb(
63 | entity=None,
64 | project='project',
65 | group=None,
66 | name=None,
67 | mode='online',
68 | ):
69 | """Set up Weights & Biases for logging."""
70 | wandb_output_dir = tempfile.mkdtemp()
71 | tags = [group] if group is not None else None
72 |
73 | init_kwargs = dict(
74 | config=get_flag_dict(),
75 | project=project,
76 | entity=entity,
77 | tags=tags,
78 | group=group,
79 | dir=wandb_output_dir,
80 | name=name,
81 | settings=wandb.Settings(
82 | start_method='thread',
83 | _disable_stats=False,
84 | ),
85 | mode=mode,
86 | save_code=True,
87 | )
88 |
89 | run = wandb.init(**init_kwargs)
90 |
91 | return run
92 |
93 |
94 | def reshape_video(v, n_cols=None):
95 | """Helper function to reshape videos."""
96 | if v.ndim == 4:
97 | v = v[None,]
98 |
99 | _, t, h, w, c = v.shape
100 |
101 | if n_cols is None:
102 | # Set n_cols to the square root of the number of videos.
103 | n_cols = np.ceil(np.sqrt(v.shape[0])).astype(int)
104 | if v.shape[0] % n_cols != 0:
105 | len_addition = n_cols - v.shape[0] % n_cols
106 | v = np.concatenate((v, np.zeros(shape=(len_addition, t, h, w, c))), axis=0)
107 | n_rows = v.shape[0] // n_cols
108 |
109 | v = np.reshape(v, newshape=(n_rows, n_cols, t, h, w, c))
110 | v = np.transpose(v, axes=(2, 5, 0, 3, 1, 4))
111 | v = np.reshape(v, newshape=(t, c, n_rows * h, n_cols * w))
112 |
113 | return v
114 |
115 |
116 | def get_wandb_video(renders=None, n_cols=None, fps=15):
117 | """Return a Weights & Biases video.
118 |
119 | It takes a list of videos and reshapes them into a single video with the specified number of columns.
120 |
121 | Args:
122 | renders: List of videos. Each video should be a numpy array of shape (t, h, w, c).
123 | n_cols: Number of columns for the reshaped video. If None, it is set to the square root of the number of videos.
124 | """
125 | # Pad videos to the same length.
126 | max_length = max([len(render) for render in renders])
127 | for i, render in enumerate(renders):
128 | assert render.dtype == np.uint8
129 |
130 | # Decrease brightness of the padded frames.
131 | final_frame = render[-1]
132 | final_image = Image.fromarray(final_frame)
133 | enhancer = ImageEnhance.Brightness(final_image)
134 | final_image = enhancer.enhance(0.5)
135 | final_frame = np.array(final_image)
136 |
137 | pad = np.repeat(final_frame[np.newaxis, ...], max_length - len(render), axis=0)
138 | renders[i] = np.concatenate([render, pad], axis=0)
139 |
140 | # Add borders.
141 | renders[i] = np.pad(renders[i], ((0, 0), (1, 1), (1, 1), (0, 0)), mode='constant', constant_values=0)
142 | renders = np.array(renders) # (n, t, h, w, c)
143 |
144 | renders = reshape_video(renders, n_cols) # (t, c, nr * h, nc * w)
145 |
146 | return wandb.Video(renders, fps=fps, format='mp4')
147 |
--------------------------------------------------------------------------------
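
A minimal sketch of the `CsvLogger` behavior: the header is fixed by the first `log` call, and later rows are aligned to it, with missing keys left as empty cells:

```python
from utils.log_utils import CsvLogger

logger = CsvLogger('metrics.csv')
logger.log({'loss': 0.5, 'accuracy': 0.9}, step=100)  # Writes header 'loss,accuracy,step'.
logger.log({'loss': 0.4}, step=200)                   # Writes '0.4,,200'; 'accuracy' is empty.
logger.close()
```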
/ogbench/__init__.py:
--------------------------------------------------------------------------------
1 | """OGBench: Benchmarking Offline Goal-Conditioned RL"""
2 |
3 | import ogbench.locomaze
4 | import ogbench.manipspace
5 | import ogbench.powderworld
6 | from ogbench.utils import download_datasets, load_dataset, make_env_and_datasets
7 |
8 | __all__ = (
9 | 'locomaze',
10 | 'manipspace',
11 | 'powderworld',
12 | 'download_datasets',
13 | 'load_dataset',
14 | 'make_env_and_datasets',
15 | )
16 |
--------------------------------------------------------------------------------
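
The re-exported functions above form the package's public dataset API; `make_env_and_datasets` is the same entry point that `impls/utils/env_utils.py` wraps:

```python
import ogbench

# Downloads the dataset on first use, then returns the environment together
# with the training and validation splits as dictionaries of numpy arrays.
env, train_dataset, val_dataset = ogbench.make_env_and_datasets('antmaze-large-navigate-v0')
print(train_dataset['observations'].shape)
```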
/ogbench/locomaze/ant.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import gymnasium
4 | import numpy as np
5 | from gymnasium import utils
6 | from gymnasium.envs.mujoco import MujocoEnv
7 | from gymnasium.spaces import Box
8 |
9 |
10 | class AntEnv(MujocoEnv, utils.EzPickle):
11 | """Gymnasium Ant environment.
12 |
13 | Unlike the original Ant environment, this environment uses a restricted joint range for the actuators, as typically
14 | done in previous works in hierarchical reinforcement learning. It also uses a control frequency of 10Hz instead of
15 | 20Hz, which is the default in the original environment.
16 | """
17 |
18 | xml_file = os.path.join(os.path.dirname(__file__), 'assets', 'ant.xml')
19 | metadata = {
20 | 'render_modes': ['human', 'rgb_array', 'depth_array'],
21 | 'render_fps': 10,
22 | }
23 | if gymnasium.__version__ >= '1.1.0':
24 | metadata['render_modes'] += ['rgbd_tuple']
25 |
26 | def __init__(
27 | self,
28 | xml_file=None,
29 | reset_noise_scale=0.1,
30 | render_mode='rgb_array',
31 | width=200,
32 | height=200,
33 | **kwargs,
34 | ):
35 | """Initialize the Ant environment.
36 |
37 | Args:
38 | xml_file: Path to the XML description (optional).
39 | reset_noise_scale: Scale of the noise added to the initial state during reset.
40 | render_mode: Rendering mode.
41 | width: Width of the rendered image.
42 | height: Height of the rendered image.
43 | **kwargs: Additional keyword arguments.
44 | """
45 | if xml_file is None:
46 | xml_file = self.xml_file
47 | utils.EzPickle.__init__(
48 | self,
49 | xml_file,
50 | reset_noise_scale,
51 | **kwargs,
52 | )
53 |
54 | self._reset_noise_scale = reset_noise_scale
55 |
56 | observation_space = Box(low=-np.inf, high=np.inf, shape=(29,), dtype=np.float64)
57 |
58 | MujocoEnv.__init__(
59 | self,
60 | xml_file,
61 | frame_skip=5,
62 | observation_space=observation_space,
63 | render_mode=render_mode,
64 | width=width,
65 | height=height,
66 | **kwargs,
67 | )
68 |
69 | def step(self, action):
70 | prev_qpos = self.data.qpos.copy()
71 | prev_qvel = self.data.qvel.copy()
72 |
73 | self.do_simulation(action, self.frame_skip)
74 |
75 | qpos = self.data.qpos.copy()
76 | qvel = self.data.qvel.copy()
77 |
78 | observation = self.get_ob()
79 |
80 | if self.render_mode == 'human':
81 | self.render()
82 |
83 | return (
84 | observation,
85 | 0.0,
86 | False,
87 | False,
88 | {
89 | 'xy': self.get_xy(),
90 | 'prev_qpos': prev_qpos,
91 | 'prev_qvel': prev_qvel,
92 | 'qpos': qpos,
93 | 'qvel': qvel,
94 | },
95 | )
96 |
97 | def get_ob(self):
98 | position = self.data.qpos.flat.copy()
99 | velocity = self.data.qvel.flat.copy()
100 |
101 | return np.concatenate([position, velocity])
102 |
103 | def reset_model(self):
104 | noise_low = -self._reset_noise_scale
105 | noise_high = self._reset_noise_scale
106 |
107 | qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq)
108 | qvel = self.init_qvel + self._reset_noise_scale * self.np_random.standard_normal(self.model.nv)
109 | self.set_state(qpos, qvel)
110 |
111 | observation = self.get_ob()
112 |
113 | return observation
114 |
115 | def get_xy(self):
116 | return self.data.qpos[:2].copy()
117 |
118 | def set_xy(self, xy):
119 | qpos = self.data.qpos.copy()
120 | qvel = self.data.qvel.copy()
121 | qpos[:2] = xy
122 | self.set_state(qpos, qvel)
123 |
--------------------------------------------------------------------------------
/ogbench/locomaze/assets/ant.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
--------------------------------------------------------------------------------
/ogbench/locomaze/assets/point.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
--------------------------------------------------------------------------------
/ogbench/locomaze/humanoid.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import os
3 |
4 | import gymnasium
5 | import mujoco
6 | import numpy as np
7 | from gymnasium import utils
8 | from gymnasium.envs.mujoco import MujocoEnv
9 | from gymnasium.spaces import Box
10 |
11 |
12 | class HumanoidEnv(MujocoEnv, utils.EzPickle):
13 | """DMC Humanoid environment.
14 |
15 | Several methods are reimplemented to remove the dependency on the `dm_control` package. It is supposed to work
16 | identically to the original Humanoid environment.
17 | """
18 |
19 | xml_file = os.path.join(os.path.dirname(__file__), 'assets', 'humanoid.xml')
20 | metadata = {
21 | 'render_modes': ['human', 'rgb_array', 'depth_array'],
22 | 'render_fps': 40,
23 | }
24 | if gymnasium.__version__ >= '1.1.0':
25 | metadata['render_modes'] += ['rgbd_tuple']
26 |
27 | def __init__(
28 | self,
29 | xml_file=None,
30 | render_mode='rgb_array',
31 | width=200,
32 | height=200,
33 | **kwargs,
34 | ):
35 | """Initialize the Humanoid environment.
36 |
37 | Args:
38 | xml_file: Path to the XML description (optional).
39 | render_mode: Rendering mode.
40 | width: Width of the rendered image.
41 | height: Height of the rendered image.
42 | **kwargs: Additional keyword arguments.
43 | """
44 | if xml_file is None:
45 | xml_file = self.xml_file
46 | utils.EzPickle.__init__(
47 | self,
48 | xml_file,
49 | **kwargs,
50 | )
51 |
52 | observation_space = Box(low=-np.inf, high=np.inf, shape=(69,), dtype=np.float64)
53 |
54 | MujocoEnv.__init__(
55 | self,
56 | xml_file,
57 | frame_skip=5,
58 | observation_space=observation_space,
59 | render_mode=render_mode,
60 | width=width,
61 | height=height,
62 | **kwargs,
63 | )
64 |
65 | def step(self, action):
66 | prev_qpos = self.data.qpos.copy()
67 | prev_qvel = self.data.qvel.copy()
68 |
69 | self.do_simulation(action, self.frame_skip)
70 |
71 | qpos = self.data.qpos.copy()
72 | qvel = self.data.qvel.copy()
73 |
74 | observation = self.get_ob()
75 |
76 | if self.render_mode == 'human':
77 | self.render()
78 |
79 | return (
80 | observation,
81 | 0.0,
82 | False,
83 | False,
84 | {
85 | 'xy': self.get_xy(),
86 | 'prev_qpos': prev_qpos,
87 | 'prev_qvel': prev_qvel,
88 | 'qpos': qpos,
89 | 'qvel': qvel,
90 | },
91 | )
92 |
93 | def _step_mujoco_simulation(self, ctrl, n_frames):
94 | self.data.ctrl[:] = ctrl
95 |
96 | # DMC-style stepping (see Page 6 of https://arxiv.org/abs/2006.12983).
97 | if self.model.opt.integrator != mujoco.mjtIntegrator.mjINT_RK4.value:
98 | mujoco.mj_step2(self.model, self.data)
99 | if n_frames > 1:
100 | mujoco.mj_step(self.model, self.data, n_frames - 1)
101 | else:
102 | mujoco.mj_step(self.model, self.data, n_frames)
103 |
104 | mujoco.mj_step1(self.model, self.data)
105 |
106 | def get_ob(self):
107 | xy = self.data.qpos[:2]
108 | joint_angles = self.data.qpos[7:] # Skip the 7 DoFs of the free root joint.
109 | head_height = self.data.xpos[2, 2] # ['head', 'z']
110 | torso_frame = self.data.xmat[1].reshape(3, 3) # ['torso']
111 | torso_pos = self.data.xpos[1] # ['torso']
112 | positions = []
113 | for idx in [16, 10, 13, 7]: # ['left_hand', 'left_foot', 'right_hand', 'right_foot']
114 | torso_to_limb = self.data.xpos[idx] - torso_pos
115 | positions.append(torso_to_limb.dot(torso_frame))
116 | extremities = np.hstack(positions)
117 | torso_vertical_orientation = self.data.xmat[1, [6, 7, 8]] # ['torso', ['zx', 'zy', 'zz']]
118 | center_of_mass_velocity = self.data.sensordata[0:3] # ['torso_subtreelinvel']
119 | velocity = self.data.qvel
120 |
121 | return np.concatenate(
122 | [
123 | xy,
124 | joint_angles,
125 | [head_height],
126 | extremities,
127 | torso_vertical_orientation,
128 | center_of_mass_velocity,
129 | velocity,
130 | ]
131 | )
132 |
133 | @contextlib.contextmanager
134 | def disable(self, *flags):
135 | old_bitmask = self.model.opt.disableflags
136 | new_bitmask = old_bitmask
137 | for flag in flags:
138 | if isinstance(flag, str):
139 | field_name = 'mjDSBL_' + flag.upper()
140 | flag = getattr(mujoco.mjtDisableBit, field_name)
141 | elif isinstance(flag, int):
142 | flag = mujoco.mjtDisableBit(flag)
143 | new_bitmask |= flag.value
144 | self.model.opt.disableflags = new_bitmask
145 | try:
146 | yield
147 | finally:
148 | self.model.opt.disableflags = old_bitmask
149 |
150 | def reset_model(self):
151 | penetrating = True
152 | while penetrating:
153 | quat = self.np_random.uniform(size=4)
154 | quat /= np.linalg.norm(quat)
155 | self.data.qpos[3:7] = quat
156 | self.data.qvel = 0.1 * self.np_random.standard_normal(self.model.nv)
157 |
158 | for joint_id in range(1, self.model.njnt):
159 | range_min, range_max = self.model.jnt_range[joint_id]
160 | self.data.qpos[6 + joint_id] = self.np_random.uniform(range_min, range_max)
161 |
162 | with self.disable('actuation'):
163 | mujoco.mj_forward(self.model, self.data)
164 | penetrating = self.data.ncon > 0
165 |
166 | observation = self.get_ob()
167 |
168 | return observation
169 |
170 | def get_xy(self):
171 | return self.data.qpos[:2].copy()
172 |
173 | def set_xy(self, xy):
174 | qpos = self.data.qpos.copy()
175 | qvel = self.data.qvel.copy()
176 | qpos[:2] = xy
177 | self.set_state(qpos, qvel)
178 |
--------------------------------------------------------------------------------
/ogbench/locomaze/point.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import gymnasium
4 | import mujoco
5 | import numpy as np
6 | from gymnasium import utils
7 | from gymnasium.envs.mujoco import MujocoEnv
8 | from gymnasium.spaces import Box
9 |
10 |
11 | class PointEnv(MujocoEnv, utils.EzPickle):
12 | """PointMass environment.
13 |
14 | This is a simple 2-D point mass environment, where the agent is controlled by an x-y action vector that is added to
15 | the current position of the point mass.
16 | """
17 |
18 | xml_file = os.path.join(os.path.dirname(__file__), 'assets', 'point.xml')
19 | metadata = {
20 | 'render_modes': ['human', 'rgb_array', 'depth_array'],
21 | 'render_fps': 10,
22 | }
23 | if gymnasium.__version__ >= '1.1.0':
24 | metadata['render_modes'] += ['rgbd_tuple']
25 |
26 | def __init__(
27 | self,
28 | xml_file=None,
29 | render_mode='rgb_array',
30 | width=200,
31 | height=200,
32 | **kwargs,
33 | ):
34 | """Initialize the Humanoid environment.
35 |
36 | Args:
37 | xml_file: Path to the XML description (optional).
38 | render_mode: Rendering mode.
39 | width: Width of the rendered image.
40 | height: Height of the rendered image.
41 | **kwargs: Additional keyword arguments.
42 | """
43 | if xml_file is None:
44 | xml_file = self.xml_file
45 | utils.EzPickle.__init__(
46 | self,
47 | xml_file,
48 | **kwargs,
49 | )
50 |
51 |         observation_space = Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.float64)  # Matches the 2-D qpos returned by get_ob().
52 |
53 | MujocoEnv.__init__(
54 | self,
55 | xml_file,
56 | frame_skip=5,
57 | observation_space=observation_space,
58 | render_mode=render_mode,
59 | width=width,
60 | height=height,
61 | **kwargs,
62 | )
63 |
64 | def step(self, action):
65 | prev_qpos = self.data.qpos.copy()
66 | prev_qvel = self.data.qvel.copy()
67 |
68 | action = 0.2 * action
69 |
70 | self.data.qpos[:] = self.data.qpos + action
71 | self.data.qvel[:] = np.array([0.0, 0.0])
72 |
73 | mujoco.mj_step(self.model, self.data, nstep=self.frame_skip)
74 |
75 | qpos = self.data.qpos.flat.copy()
76 | qvel = self.data.qvel.flat.copy()
77 |
78 | observation = self.get_ob()
79 |
80 | if self.render_mode == 'human':
81 | self.render()
82 |
83 | return (
84 | observation,
85 | 0.0,
86 | False,
87 | False,
88 | {
89 | 'xy': self.get_xy(),
90 | 'prev_qpos': prev_qpos,
91 | 'prev_qvel': prev_qvel,
92 | 'qpos': qpos,
93 | 'qvel': qvel,
94 | },
95 | )
96 |
97 | def get_ob(self):
98 | return self.data.qpos.flat.copy()
99 |
100 | def reset_model(self):
101 | qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-0.1, high=0.1)
102 | qvel = self.init_qvel + self.np_random.standard_normal(self.model.nv) * 0.1
103 |
104 | self.set_state(qpos, qvel)
105 |
106 | return self.get_ob()
107 |
108 | def get_xy(self):
109 | return self.data.qpos.copy()
110 |
111 | def set_xy(self, xy):
112 | qpos = self.data.qpos.copy()
113 | qvel = self.data.qvel.copy()
114 | qpos[:] = xy
115 | self.set_state(qpos, qvel)
116 |
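A minimal usage sketch of PointEnv (editorial, not part of the repository; assumes the `ogbench` package above is importable). Each `step` scales the x-y action by 0.2, adds it to the current position, zeroes the velocity, and advances the simulator by `frame_skip` substeps:

import numpy as np

from ogbench.locomaze.point import PointEnv

env = PointEnv(render_mode='rgb_array')
ob, info = env.reset(seed=0)
for _ in range(10):
    action = np.random.uniform(-1.0, 1.0, size=2)  # x-y displacement command.
    ob, reward, terminated, truncated, info = env.step(action)
print(ob)          # 2-D position of the point mass.
print(info['xy'])  # Same position, as returned by get_xy().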
--------------------------------------------------------------------------------
/ogbench/manipspace/controllers/__init__.py:
--------------------------------------------------------------------------------
1 | from ogbench.manipspace.controllers.diff_ik import DiffIKController
2 |
3 | __all__ = ('DiffIKController',)
4 |
--------------------------------------------------------------------------------
/ogbench/manipspace/controllers/diff_ik.py:
--------------------------------------------------------------------------------
1 | import mujoco
2 | import numpy as np
3 |
4 | PI = np.pi
5 | TWO_PI = 2 * np.pi
6 |
7 |
8 | def angle_diff(q1: np.ndarray, q2: np.ndarray) -> np.ndarray:
9 |     return np.mod(q1 - q2 + PI, TWO_PI) - PI  # Wrap the elementwise difference to [-pi, pi).
10 |
11 |
12 | class DiffIKController:
13 | """Differential inverse kinematics controller."""
14 |
15 | def __init__(
16 | self,
17 | model: mujoco.MjModel,
18 | sites: list,
19 | qpos0: np.ndarray = None,
20 | damping_coeff: float = 1e-12,
21 | max_angle_change: float = np.radians(45),
22 | ):
23 | self._model = model
24 | self._data = mujoco.MjData(self._model)
25 | self._qp0 = qpos0
26 | self._max_angle_change = max_angle_change
27 |
28 | # Cache references.
29 | self._ns = len(sites) # Number of sites.
30 | self._site_ids = np.asarray([self._model.site(s).id for s in sites])
31 |
32 | # Preallocate arrays.
33 | self._err = np.empty((self._ns, 6))
34 | self._site_quat = np.empty((self._ns, 4))
35 | self._site_quat_inv = np.empty((self._ns, 4))
36 | self._err_quat = np.empty((self._ns, 4))
37 | self._jac = np.empty((6 * self._ns, self._model.nv))
38 | self._damping = damping_coeff * np.eye(6 * self._ns)
39 | self._eye = np.eye(self._model.nv)
40 |
41 | def _forward_kinematics(self) -> None:
42 | """Minimal computation required for forward kinematics."""
43 | mujoco.mj_kinematics(self._model, self._data)
44 | mujoco.mj_comPos(self._model, self._data) # Required for mj_jacSite.
45 |
46 | def _integrate(self, update: np.ndarray) -> None:
47 | """Integrate the joint velocities in-place."""
48 | mujoco.mj_integratePos(self._model, self._data.qpos, update, 1.0)
49 |
50 | def _compute_translational_error(self, pos: np.ndarray) -> None:
51 | """Compute the error between the desired and current site positions."""
52 | self._err[:, :3] = pos - self._data.site_xpos[self._site_ids]
53 |
54 | def _compute_rotational_error(self, quat: np.ndarray) -> None:
55 | """Compute the error between the desired and current site orientations."""
56 | for i, site_id in enumerate(self._site_ids):
57 | mujoco.mju_mat2Quat(self._site_quat[i], self._data.site_xmat[site_id])
58 | mujoco.mju_negQuat(self._site_quat_inv[i], self._site_quat[i])
59 | mujoco.mju_mulQuat(self._err_quat[i], quat[i], self._site_quat_inv[i])
60 | mujoco.mju_quat2Vel(self._err[i, 3:], self._err_quat[i], 1.0)
61 |
62 | def _compute_jacobian(self) -> None:
63 | """Update site end-effector Jacobians."""
64 | for i, site_id in enumerate(self._site_ids):
65 | jacp = self._jac[6 * i : 6 * i + 3]
66 | jacr = self._jac[6 * i + 3 : 6 * i + 6]
67 | mujoco.mj_jacSite(self._model, self._data, jacp, jacr, site_id)
68 |
69 | def _error_threshold_reached(self, pos_thresh: float, ori_thresh: float) -> bool:
70 | """Return True if position and rotation errors are below the thresholds."""
71 | pos_achieved = np.linalg.norm(self._err[:, :3]) <= pos_thresh
72 | ori_achieved = np.linalg.norm(self._err[:, 3:]) <= ori_thresh
73 | return pos_achieved and ori_achieved
74 |
75 | def _solve(self) -> np.ndarray:
76 | """Solve for joint velocities using damped least squares."""
77 | H = self._jac @ self._jac.T + self._damping
78 | x = self._jac.T @ np.linalg.solve(H, self._err.ravel())
79 | if self._qp0 is not None:
80 | jac_pinv = np.linalg.pinv(H)
81 | q_err = angle_diff(self._qp0, self._data.qpos)
82 |             x += (self._eye - (self._jac.T @ jac_pinv) @ self._jac) @ q_err  # Nullspace term: bias toward qpos0 without disturbing the task.
83 | return x
84 |
85 | def _scale_update(self, update: np.ndarray) -> np.ndarray:
86 | """Scale down update so that the max allowable angle change is not exceeded."""
87 | update_max = np.max(np.abs(update))
88 | if update_max > self._max_angle_change:
89 | update *= self._max_angle_change / update_max
90 | return update
91 |
92 | def solve(
93 | self,
94 | pos: np.ndarray,
95 | quat: np.ndarray,
96 | curr_qpos: np.ndarray,
97 | max_iters: int = 20,
98 | pos_thresh: float = 1e-4,
99 | ori_thresh: float = 1e-4,
100 | ) -> np.ndarray:
101 | self._data.qpos = curr_qpos
102 |
103 | for _ in range(max_iters):
104 | self._forward_kinematics()
105 |
106 | self._compute_translational_error(np.atleast_2d(pos))
107 | self._compute_rotational_error(np.atleast_2d(quat))
108 | if self._error_threshold_reached(pos_thresh, ori_thresh):
109 | break
110 |
111 | self._compute_jacobian()
112 | update = self._scale_update(self._solve())
113 | self._integrate(update)
114 |
115 | return self._data.qpos.copy()
116 |
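The solver above iterates damped least squares: each iteration computes the 6-D site error e, forms the update dq = J^T (J J^T + lambda*I)^{-1} e, optionally adds a nullspace term that biases the solution toward `qpos0`, clips the largest joint change to `max_angle_change`, and integrates. A hypothetical usage sketch (editorial; the XML path and site name are placeholders, not from this repository):

import mujoco
import numpy as np

from ogbench.manipspace.controllers import DiffIKController

model = mujoco.MjModel.from_xml_path('scene.xml')  # Placeholder path.
controller = DiffIKController(model=model, sites=['attachment_site'])  # Placeholder site name.

# Desired pose for the single site: position (x, y, z) and quaternion (w, x, y, z).
target_pos = np.array([[0.4, 0.0, 0.3]])
target_quat = np.array([[1.0, 0.0, 0.0, 0.0]])

qpos = controller.solve(pos=target_pos, quat=target_quat, curr_qpos=np.zeros(model.nq))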
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/button_inner.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/button_inner.xml
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/button_outer.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/button_outer.xml
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/buttons.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/buttons.xml
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/cube.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/cube.xml
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/cube_inner.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/cube_inner.xml
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/cube_outer.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/cube_outer.xml
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/drawer.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/drawer.xml
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/floor_wall.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/floor_wall.xml
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/metaworld/button/button.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/button/button.stl
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/metaworld/button/buttonring.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/button/buttonring.stl
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/metaworld/button/metal1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/button/metal1.png
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/metaworld/button/stopbot.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/button/stopbot.stl
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/metaworld/button/stopbutton.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/button/stopbutton.stl
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/metaworld/button/stopbuttonrim.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/button/stopbuttonrim.stl
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/metaworld/button/stopbuttonrod.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/button/stopbuttonrod.stl
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/metaworld/button/stoptop.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/button/stoptop.stl
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/metaworld/drawer/drawer.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/drawer/drawer.stl
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/metaworld/drawer/drawercase.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/drawer/drawercase.stl
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/metaworld/drawer/drawerhandle.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/drawer/drawerhandle.stl
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/metaworld/window/window_base.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/window/window_base.stl
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/metaworld/window/window_frame.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/window/window_frame.stl
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/metaworld/window/window_h_base.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/window/window_h_base.stl
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/metaworld/window/window_h_frame.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/window/window_h_frame.stl
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/metaworld/window/windowa_frame.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/window/windowa_frame.stl
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/metaworld/window/windowa_glass.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/window/windowa_glass.stl
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/metaworld/window/windowa_h_frame.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/window/windowa_h_frame.stl
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/metaworld/window/windowa_h_glass.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/window/windowa_h_glass.stl
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/metaworld/window/windowb_frame.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/window/windowb_frame.stl
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/metaworld/window/windowb_glass.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/window/windowb_glass.stl
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/metaworld/window/windowb_h_frame.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/window/windowb_h_frame.stl
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/metaworld/window/windowb_h_glass.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/metaworld/window/windowb_h_glass.stl
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/robotiq_2f85/2f85.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/robotiq_2f85/2f85.png
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/robotiq_2f85/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2013, ROS-Industrial
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without modification,
5 | are permitted provided that the following conditions are met:
6 |
7 | Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | Redistributions in binary form must reproduce the above copyright notice, this
11 | list of conditions and the following disclaimer in the documentation and/or
12 | other materials provided with the distribution.
13 |
14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
15 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
16 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
18 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
19 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
20 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
21 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
23 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 |
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/robotiq_2f85/README.md:
--------------------------------------------------------------------------------
1 | # Robotiq 2F-85 Description (MJCF)
2 |
3 | Requires MuJoCo 2.2.2 or later.
4 |
5 | ## Overview
6 |
7 | This package contains a simplified robot description (MJCF) of the [Robotiq 85mm
8 | 2-Finger Adaptive
9 | Gripper](https://robotiq.com/products/2f85-140-adaptive-robot-gripper) developed
10 | by [Robotiq](https://robotiq.com/). It is derived from the [publicly available
11 | URDF
12 | description](https://github.com/ros-industrial/robotiq/tree/kinetic-devel/robotiq_2f_85_gripper_visualization).
13 |
14 | <p float="left">
15 |   <img src="2f85.png" width="400">
16 | </p>
17 |
18 | ## URDF → MJCF derivation steps
19 |
20 | 1. Added `<mujoco> <compiler discardvisual="false"/> </mujoco>` to the URDF's
21 |    `<robot>` clause in order to preserve visual geometries.
22 | 2. Loaded the URDF into MuJoCo and saved a corresponding MJCF.
23 | 3. Manually edited the MJCF to extract common properties into the `<default>` section.
24 | 4. Added `<exclude>` clauses to prevent collisions between the linkage bodies.
25 | 5. Broke up collision pads into two pads for more contacts.
26 | 6. Increased pad friction and priority.
27 | 7. Added `impratio=10` for better noslip.
28 | 8. Added `scene.xml` which includes the robot, with a textured groundplane, skybox, and haze.
29 | 9. Added hanging box to `scene.xml`.
30 |
31 | ## License
32 |
33 | This model is released under a [BSD-2-Clause License](LICENSE).
34 |
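As an editorial aside on step 7 above: `impratio` is a solver option set in the MJCF `<option>` element. A minimal, self-contained sketch (not from this repository) showing where it lives:

import mujoco

# Minimal model carrying the solver option from step 7; everything else is a stub.
xml = """
<mujoco>
  <option impratio="10"/>
  <worldbody/>
</mujoco>
"""
model = mujoco.MjModel.from_xml_string(xml)
print(model.opt.impratio)  # 10.0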
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/robotiq_2f85/assets/base.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/robotiq_2f85/assets/base.stl
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/robotiq_2f85/assets/base_mount.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/robotiq_2f85/assets/base_mount.stl
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/robotiq_2f85/assets/coupler.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/robotiq_2f85/assets/coupler.stl
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/robotiq_2f85/assets/driver.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/robotiq_2f85/assets/driver.stl
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/robotiq_2f85/assets/follower.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/robotiq_2f85/assets/follower.stl
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/robotiq_2f85/assets/pad.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/robotiq_2f85/assets/pad.stl
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/robotiq_2f85/assets/silicone_pad.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/robotiq_2f85/assets/silicone_pad.stl
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/robotiq_2f85/assets/spring_link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/robotiq_2f85/assets/spring_link.stl
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/robotiq_2f85/scene.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seohongpark/ogbench/9c0200a05e728c0b81f76ca9889c2cc83e819a94/ogbench/manipspace/descriptions/robotiq_2f85/scene.xml
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/universal_robots_ur5e/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2018 ROS Industrial Consortium
2 |
3 | Redistribution and use in source and binary forms, with or without modification,
4 | are permitted provided that the following conditions are met:
5 |
6 | 1. Redistributions of source code must retain the above copyright notice, this
7 | list of conditions and the following disclaimer.
8 |
9 | 2. Redistributions in binary form must reproduce the above copyright notice,
10 | this list of conditions and the following disclaimer in the documentation and/or
11 | other materials provided with the distribution.
12 |
13 | 3. Neither the name of the copyright holder nor the names of its contributors
14 | may be used to endorse or promote products derived from this software without
15 | specific prior written permission.
16 |
17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
18 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
19 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
20 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
21 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
22 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
23 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
24 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27 |
--------------------------------------------------------------------------------
/ogbench/manipspace/descriptions/universal_robots_ur5e/README.md:
--------------------------------------------------------------------------------
1 | # Universal Robots UR5e Description (MJCF)
2 |
3 | Requires MuJoCo 2.3.3 or later.
4 |
5 | ## Overview
6 |
7 | This package contains a simplified robot description (MJCF) of the
8 | [UR5e](https://www.universal-robots.com/products/ur5-robot/) developed by
9 | [Universal Robots](https://www.universal-robots.com/). It is derived from the
10 | [publicly available URDF
11 | description](https://github.com/ros-industrial/universal_robot/tree/kinetic-devel/ur_e_description).
12 |
13 |