├── .gitignore ├── README.md ├── algs ├── bc.py └── iql.py ├── common ├── dataset.py ├── envs │ ├── bandit │ │ └── bandit.py │ ├── d4rl │ │ ├── antmaze_actions.npy │ │ ├── d4rl_ant.py │ │ └── d4rl_utils.py │ ├── data_transforms.py │ ├── dmc │ │ ├── __init__.py │ │ ├── jaco.py │ │ └── wrappers.py │ ├── env_helper.py │ ├── exorl │ │ ├── custom_dmc_tasks │ │ │ ├── __init__.py │ │ │ ├── cheetah.py │ │ │ ├── cheetah.xml │ │ │ ├── hopper.py │ │ │ ├── hopper.xml │ │ │ ├── jaco.py │ │ │ ├── quadruped.py │ │ │ ├── quadruped.xml │ │ │ ├── walker.py │ │ │ └── walker.xml │ │ ├── dmc.py │ │ └── exorl_utils.py │ ├── gc_utils.py │ └── wrappers.py ├── evaluation.py ├── networks │ ├── basic.py │ └── transformer.py ├── train_state.py ├── typing.py ├── utils.py └── wandb.py ├── deps ├── base_container.def ├── environment.yml └── requirements.txt └── experiment ├── ant_helper.py ├── rewards_eval.py ├── rewards_unsupervised.py └── run_fre.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | .idea/ 161 | *.sh 162 | make_sbatch/ 163 | experiment_output/ 164 | ipynbs/ 165 | .DS_Store 166 | data/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unsupervised Zero-Shot RL via Functional Reward Representations 2 | Code for "Unsupervised Zero-Shot RL via Functional Reward Representations" 3 | 4 | Kevin Frans, Seohong Park, Pieter Abbeel, Sergey Levine 5 | 6 | [Link to Paper](https://arxiv.org/abs/2402.17135) 7 | 8 | ### Abstract 9 | Can we pre-train a generalist agent from a large amount of unlabeled offline trajectories such that it can be immediately adapted to any new downstream tasks in a zero-shot manner? 10 | In this work, we present a *functional* reward encoding (FRE) as a general, scalable solution to this *zero-shot RL* problem. 11 | Our main idea is to learn functional representations of any arbitrary tasks by encoding their state-reward samples using a transformer-based variational auto-encoder. 12 | This functional encoding not only enables the pre-training of an agent from a wide diversity of general unsupervised reward functions, but also provides a way to solve any new downstream tasks in a zero-shot manner, given a small number of reward-annotated samples. 13 | We empirically show that FRE agents trained on diverse random unsupervised reward functions can generalize to solve novel tasks in a range of simulated robotic benchmarks, often outperforming previous zero-shot RL and offline RL methods. 14 | 15 | ### Code Instructions 16 | First, install the dependencies in the `deps` folder. 17 | ``` 18 | cd deps 19 | conda env create -f environment.yml 20 | ``` 21 | 22 | For the ExORL experiments, you will need to first download the data using [these instructions](https://github.com/denisyarats/exorl). 23 | Then, download the [auxiliary offline data](https://drive.google.com/drive/folders/1HDkCP6-eLKuyQRPcyO3ei-vhubMRlGct?usp=sharing) and place it in the `data/` folder. 24 | 25 | To run the experiments, use the following commands.
26 | 27 | ``` 28 | # AntMaze 29 | python experiment/run_fre.py --env_name antmaze-large-diverse-v2 30 | # ExORL 31 | python experiment/run_fre.py --env_name dmc_walker_walk --agent.warmup_steps 1000000 --max_steps 2000000 32 | python experiment/run_fre.py --env_name dmc_cheetah_run --agent.warmup_steps 1000000 --max_steps 2000000 33 | # Kitchen 34 | python experiment/run_fre.py --env_name kitchen-mixed-v0 --agent.warmup_steps 1000000 --max_steps 2000000 35 | ``` 36 | -------------------------------------------------------------------------------- /algs/bc.py: -------------------------------------------------------------------------------- 1 | import os 2 | from absl import app, flags 3 | from functools import partial 4 | import numpy as np 5 | import tqdm 6 | import jax 7 | import jax.numpy as jnp 8 | import optax 9 | import flax 10 | import wandb 11 | from ml_collections import config_flags 12 | import pickle 13 | from flax.training import checkpoints 14 | import ml_collections 15 | 16 | import fre.common.envs.d4rl.d4rl_utils as d4rl_utils 17 | from fre.common.envs.gc_utils import GCDataset 18 | from fre.common.envs.env_helper import make_env 19 | from fre.common.wandb import setup_wandb, default_wandb_config, get_flag_dict 20 | from fre.common.evaluation import evaluate 21 | from fre.common.utils import supply_rng 22 | from fre.common.typing import * 23 | from fre.common.train_state import TrainState, target_update 24 | from fre.common.networks.basic import Policy, ValueCritic, Critic, ensemblize 25 | 26 | 27 | ############################### 28 | # Configs 29 | ############################### 30 | 31 | 32 | FLAGS = flags.FLAGS 33 | flags.DEFINE_string('env_name', 'gc-antmaze-large-diverse-v2', 'Environment name.') 34 | flags.DEFINE_string('name', 'default', '') 35 | 36 | flags.DEFINE_string('save_dir', None, 'Logging dir (if not None, save params).') 37 | 38 | flags.DEFINE_integer('seed', np.random.choice(1000000), 'Random seed.') 39 | flags.DEFINE_integer('eval_episodes', 20, 40 | 'Number of episodes used for evaluation.') 41 | flags.DEFINE_integer('log_interval', 1000, 'Logging interval.') 42 | flags.DEFINE_integer('eval_interval', 50000, 'Eval interval.') 43 | flags.DEFINE_integer('save_interval', 250000, 'Save interval.') 44 | flags.DEFINE_integer('video_interval', 50000, 'Video interval.') 45 | flags.DEFINE_integer('batch_size', 1024, 'Mini batch size.') 46 | flags.DEFINE_integer('max_steps', int(1e6), 'Number of training steps.') 47 | flags.DEFINE_integer('goal_conditioned', 0, 'Whether to use goal conditioned relabelling or not.') 48 | 49 | # These variables are passed to the BCAgent class. 50 | agent_config = ml_collections.ConfigDict({ 51 | 'actor_lr': 3e-4, 52 | 'hidden_dims': (512, 512, 512), 53 | 'opt_decay_schedule': 'none', 54 | 'use_tanh': 0, 55 | 'state_dependent_std': 0, 56 | 'use_layer_norm': 1, 57 | }) 58 | 59 | wandb_config = default_wandb_config() 60 | wandb_config.update({ 61 | 'project': 'mujoco_rlalgs', 62 | 'name': 'bc_{env_name}', 63 | }) 64 | 65 | 66 | config_flags.DEFINE_config_dict('wandb', wandb_config, lock_config=False) 67 | config_flags.DEFINE_config_dict('agent', agent_config, lock_config=False) 68 | config_flags.DEFINE_config_dict('gcdataset', GCDataset.get_default_config(), lock_config=False) 69 | 70 | ############################### 71 | # Agent. Contains the neural networks, training logic, and sampling.
72 | ############################### 73 | 74 | class BCAgent(flax.struct.PyTreeNode): 75 | rng: PRNGKey 76 | actor: TrainState 77 | config: dict = flax.struct.field(pytree_node=False) 78 | 79 | @jax.jit 80 | def update(agent, batch: Batch) -> InfoDict: 81 | observations = batch['observations'] 82 | actions = batch['actions'] 83 | 84 | def actor_loss_fn(actor_params): 85 | dist = agent.actor(observations, params=actor_params) 86 | log_probs = dist.log_prob(actions) 87 | actor_loss = -(log_probs).mean() 88 | 89 | mse_error = jnp.square(dist.loc - actions).mean() 90 | 91 | return actor_loss, { 92 | 'actor_loss': actor_loss, 93 | 'action_std': dist.stddev().mean(), 94 | 'mse_error': mse_error, 95 | } 96 | 97 | new_actor, actor_info = agent.actor.apply_loss_fn(loss_fn=actor_loss_fn, has_aux=True) 98 | 99 | return agent.replace(actor=new_actor), { 100 | **actor_info 101 | } 102 | 103 | @jax.jit 104 | def sample_actions(agent, 105 | observations: np.ndarray, 106 | *, 107 | seed: PRNGKey, 108 | temperature: float = 1.0) -> jnp.ndarray: 109 | if type(observations) is dict: 110 | observations = jnp.concatenate([observations['observation'], observations['goal']], axis=-1) 111 | actions = agent.actor(observations, temperature=temperature).sample(seed=seed) 112 | actions = jnp.clip(actions, -1, 1) 113 | return actions 114 | 115 | def create_agent( 116 | seed: int, 117 | observations: jnp.ndarray, 118 | actions: jnp.ndarray, 119 | actor_lr: float, 120 | use_tanh: bool, 121 | state_dependent_std: bool, 122 | use_layer_norm: bool, 123 | hidden_dims: Sequence[int], 124 | opt_decay_schedule: str, 125 | max_steps: Optional[int] = None, 126 | **kwargs): 127 | 128 | print('Extra kwargs:', kwargs) 129 | 130 | rng = jax.random.PRNGKey(seed) 131 | rng, actor_key, critic_key, value_key = jax.random.split(rng, 4) 132 | 133 | action_dim = actions.shape[-1] 134 | actor_def = Policy(hidden_dims, action_dim=action_dim, 135 | log_std_min=-5.0, state_dependent_std=state_dependent_std, tanh_squash_distribution=use_tanh, mlp_kwargs=dict(use_layer_norm=use_layer_norm)) 136 | 137 | if opt_decay_schedule == "cosine": 138 | schedule_fn = optax.cosine_decay_schedule(-actor_lr, max_steps) 139 | actor_tx = optax.chain(optax.scale_by_adam(), 140 | optax.scale_by_schedule(schedule_fn)) 141 | else: 142 | actor_tx = optax.adam(learning_rate=actor_lr) 143 | 144 | actor_params = actor_def.init(actor_key, observations)['params'] 145 | actor = TrainState.create(actor_def, actor_params, tx=actor_tx) 146 | 147 | config = flax.core.FrozenDict(dict( 148 | actor_lr=actor_lr, 149 | )) 150 | 151 | return BCAgent(rng, actor=actor, config=config) 152 | 153 | ############################### 154 | # Run Script. Loads data, logs to wandb, and runs the training loop. 
155 | ############################### 156 | 157 | def main(_): 158 | if FLAGS.goal_conditioned: 159 | assert 'gc' in FLAGS.env_name 160 | else: 161 | assert 'gc' not in FLAGS.env_name 162 | 163 | # Create wandb logger 164 | setup_wandb(FLAGS.agent.to_dict(), **FLAGS.wandb) 165 | 166 | env = make_env(FLAGS.env_name) 167 | eval_env = make_env(FLAGS.env_name) 168 | 169 | dataset = d4rl_utils.get_dataset(env, FLAGS.env_name) 170 | dataset = d4rl_utils.normalize_dataset(FLAGS.env_name, dataset) 171 | if FLAGS.goal_conditioned: 172 | dataset = GCDataset(dataset, **FLAGS.gcdataset.to_dict()) 173 | example_batch = dataset.sample(1) 174 | example_obs = np.concatenate([example_batch['observations'], example_batch['goals']], axis=-1) 175 | debug_batch = dataset.sample(100) 176 | print("Masks Look Like", debug_batch['masks']) 177 | print("Rewards Look Like", debug_batch['rewards']) 178 | else: 179 | example_batch = dataset.sample(1); example_obs = example_batch['observations'] # example_batch is also needed below for the action shape. 180 | 181 | agent = create_agent(FLAGS.seed, 182 | example_obs, 183 | example_batch['actions'], 184 | max_steps=FLAGS.max_steps, 185 | **FLAGS.agent) 186 | 187 | for i in tqdm.tqdm(range(1, FLAGS.max_steps + 1), 188 | smoothing=0.1, 189 | dynamic_ncols=True): 190 | 191 | batch = dataset.sample(FLAGS.batch_size) 192 | if FLAGS.goal_conditioned: 193 | batch['observations'] = np.concatenate([batch['observations'], batch['goals']], axis=-1) 194 | batch['next_observations'] = np.concatenate([batch['next_observations'], batch['goals']], axis=-1) 195 | 196 | agent, update_info = agent.update(batch) 197 | 198 | if i % FLAGS.log_interval == 0: 199 | train_metrics = {f'training/{k}': v for k, v in update_info.items()} 200 | wandb.log(train_metrics, step=i) 201 | 202 | if i % FLAGS.eval_interval == 0: 203 | record_video = i % FLAGS.video_interval == 0 204 | policy_fn = partial(supply_rng(agent.sample_actions), temperature=0.0) 205 | eval_info, trajs = evaluate(policy_fn, eval_env, num_episodes=FLAGS.eval_episodes, record_video=record_video, return_trajectories=True) 206 | eval_metrics = {} 207 | for k in ['episode.return', 'episode.length']: 208 | eval_metrics[f'evaluation/{k}'] = eval_info[k] 209 | print(f'evaluation/{k}: {eval_info[k]}') 210 | eval_metrics['evaluation/episode.return.normalized'] = eval_env.get_normalized_score(eval_info['episode.return']) 211 | print(f'evaluation/episode.return.normalized: {eval_metrics["evaluation/episode.return.normalized"]}') 212 | if record_video: 213 | wandb.log({'video': eval_info['video']}, step=i) 214 | 215 | # Antmaze Specific Logging 216 | if 'antmaze-large' in FLAGS.env_name or 'maze2d-large' in FLAGS.env_name: 217 | import fre.common.envs.d4rl.d4rl_ant as d4rl_ant 218 | # Make an image of the trajectories.
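# trajectory_image scatters each evaluation episode's XY path over the maze layout
# (see plot_trajectories in common/envs/d4rl/d4rl_ant.py); the result is logged as a
# wandb.Image under 'trajectories' below.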
219 | traj_image = d4rl_ant.trajectory_image(eval_env, trajs) 220 | eval_metrics['trajectories'] = wandb.Image(traj_image) 221 | 222 | wandb.log(eval_metrics, step=i) 223 | 224 | if __name__ == '__main__': 225 | app.run(main) -------------------------------------------------------------------------------- /algs/iql.py: -------------------------------------------------------------------------------- 1 | import os 2 | from absl import app, flags 3 | from functools import partial 4 | import numpy as np 5 | import tqdm 6 | import jax 7 | import jax.numpy as jnp 8 | import flax 9 | import optax 10 | import wandb 11 | from ml_collections import config_flags 12 | import pickle 13 | from flax.training import checkpoints 14 | import ml_collections 15 | 16 | from fre.common.envs.gc_utils import GCDataset 17 | from fre.common.envs.data_transforms import ActionDiscretizeCluster, ActionDiscretizeBins, ActionTransform 18 | from fre.common.envs.env_helper import make_env, get_dataset 19 | from fre.common.wandb import setup_wandb, default_wandb_config, get_flag_dict 20 | from fre.common.evaluation import evaluate 21 | from fre.common.utils import supply_rng 22 | from fre.common.typing import * 23 | from fre.common.train_state import TrainState, target_update 24 | from fre.common.networks.basic import Policy, ValueCritic, Critic, ensemblize 25 | 26 | 27 | ############################### 28 | # Configs 29 | ############################### 30 | 31 | 32 | FLAGS = flags.FLAGS 33 | flags.DEFINE_string('env_name', 'gc-antmaze-large-diverse-v2', 'Environment name.') 34 | flags.DEFINE_string('name', 'default', '') 35 | 36 | flags.DEFINE_string('save_dir', None, 'Logging dir (if not None, save params).') 37 | 38 | flags.DEFINE_integer('seed', np.random.choice(1000000), 'Random seed.') 39 | flags.DEFINE_integer('eval_episodes', 20, 40 | 'Number of episodes used for evaluation.') 41 | flags.DEFINE_integer('log_interval', 1000, 'Logging interval.') 42 | flags.DEFINE_integer('eval_interval', 50000, 'Eval interval.') 43 | flags.DEFINE_integer('save_interval', 250000, 'Eval interval.') 44 | flags.DEFINE_integer('video_interval', 250000, 'Eval interval.') 45 | flags.DEFINE_integer('batch_size', 1024, 'Mini batch size.') 46 | flags.DEFINE_integer('max_steps', int(1e6), 'Number of training steps.') 47 | flags.DEFINE_integer('goal_conditioned', 0, 'Whether to use goal conditioned relabelling or not.') 48 | 49 | # These variables are passed to the IQLAgent class. 50 | agent_config = ml_collections.ConfigDict({ 51 | 'actor_lr': 3e-4, 52 | 'value_lr': 3e-4, 53 | 'critic_lr': 3e-4, 54 | 'num_qs': 2, 55 | 'actor_hidden_dims': (512, 512, 512), 56 | 'hidden_dims': (512, 512, 512), 57 | 'discount': 0.99, 58 | 'expectile': 0.9, 59 | 'temperature': 3.0, # 0 for behavior cloning. 60 | 'dropout_rate': 0, 61 | 'use_tanh': 0, 62 | 'state_dependent_std': 0, 63 | 'use_layer_norm': 1, 64 | 'tau': 0.005, 65 | 'opt_decay_schedule': 'none', 66 | 'action_transform_type': 'none', 67 | 'actor_loss_type': 'awr', # or ddpg 68 | 'bc_weight': 0.0, # for ddpg 69 | }) 70 | 71 | wandb_config = default_wandb_config() 72 | wandb_config.update({ 73 | 'name': 'iql_{env_name}', 74 | }) 75 | 76 | config_flags.DEFINE_config_dict('wandb', wandb_config, lock_config=False) 77 | config_flags.DEFINE_config_dict('agent', agent_config, lock_config=False) 78 | config_flags.DEFINE_config_dict('gcdataset', GCDataset.get_default_config(), lock_config=False) 79 | 80 | ############################### 81 | # Agent. Contains the neural networks, training logic, and sampling. 
82 | ############################### 83 | 84 | def expectile_loss(diff, expectile=0.8): 85 | weight = jnp.where(diff > 0, expectile, (1 - expectile)) 86 | return weight * (diff**2) 87 | 88 | class IQLAgent(flax.struct.PyTreeNode): 89 | rng: PRNGKey 90 | critic: TrainState 91 | target_critic: TrainState 92 | value: TrainState 93 | actor: TrainState 94 | config: dict = flax.struct.field(pytree_node=False) 95 | 96 | @jax.jit 97 | def update(agent, batch: Batch) -> InfoDict: 98 | def critic_loss_fn(critic_params): 99 | next_v = agent.value(batch['next_observations']) 100 | target_q = batch['rewards'] + agent.config['discount'] * batch['masks'] * next_v 101 | qs = agent.critic(batch['observations'], batch['actions'], params=critic_params) # [num_q, batch] 102 | critic_loss = ((qs - target_q[None])**2).mean() 103 | return critic_loss, { 104 | 'critic_loss': critic_loss, 105 | 'q1': qs[0].mean(), 106 | } 107 | 108 | def value_loss_fn(value_params): 109 | qs = agent.target_critic(batch['observations'], batch['actions']) 110 | q = jnp.min(qs, axis=0) # Min over ensemble. 111 | v = agent.value(batch['observations'], params=value_params) 112 | value_loss = expectile_loss(q-v, agent.config['expectile']).mean() 113 | return value_loss, { 114 | 'value_loss': value_loss, 115 | 'v': v.mean(), 116 | } 117 | 118 | def actor_loss_fn(actor_params): 119 | if agent.config['actor_loss_type'] == 'awr': 120 | v = agent.value(batch['observations']) 121 | qs = agent.critic(batch['observations'], batch['actions']) 122 | q = jnp.min(qs, axis=0) # Min over ensemble. 123 | exp_a = jnp.exp((q - v) * agent.config['temperature']) 124 | exp_a = jnp.minimum(exp_a, 100.0) 125 | 126 | actions = batch['actions'] 127 | if agent.config['action_transform_type'] == "cluster": 128 | actions = agent.config['action_transform'].action_to_ids(actions) 129 | dist = agent.actor(batch['observations'], params=actor_params) 130 | log_probs = dist.log_prob(actions) 131 | actor_loss = -(exp_a * log_probs).mean() 132 | 133 | action_std = dist.stddev().mean() if agent.config['action_transform_type'] == 'none' else 0 134 | return actor_loss, { 135 | 'actor_loss': actor_loss, 136 | 'action_std': action_std, 137 | 'adv': (q - v).mean(), 138 | 'adv_min': (q - v).min(), 139 | 'adv_max': (q - v).max(), 140 | } 141 | elif agent.config['actor_loss_type'] == 'ddpg': 142 | dist = agent.actor(batch['observations'], params=actor_params) 143 | normalized_actions = jnp.tanh(dist.loc) 144 | qs = agent.critic(batch['observations'], normalized_actions) 145 | q = jnp.min(qs, axis=0) # Min over ensemble. 
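# DDPG-style actor objective: maximize the ensemble-min Q-value of the policy's
# tanh-squashed mean action, plus a behavior-cloning log-likelihood term scaled by
# `bc_weight` that keeps the policy close to the dataset actions.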
146 | 147 | q_loss = -q.mean() 148 | 149 | log_probs = dist.log_prob(batch['actions']) 150 | bc_loss = -((agent.config['bc_weight'] * log_probs)).mean() # Abuse the name 'temperature' for the BC coefficient 151 | 152 | actor_loss = (q_loss + bc_loss).mean() 153 | return actor_loss, { 154 | 'bc_loss': bc_loss, 155 | 'agent_q': q.mean(), 156 | } 157 | else: 158 | raise NotImplementedError 159 | 160 | new_critic, critic_info = agent.critic.apply_loss_fn(loss_fn=critic_loss_fn, has_aux=True) 161 | new_target_critic = target_update(agent.critic, agent.target_critic, agent.config['target_update_rate']) 162 | new_value, value_info = agent.value.apply_loss_fn(loss_fn=value_loss_fn, has_aux=True) 163 | new_actor, actor_info = agent.actor.apply_loss_fn(loss_fn=actor_loss_fn, has_aux=True) 164 | 165 | return agent.replace(critic=new_critic, target_critic=new_target_critic, value=new_value, actor=new_actor), { 166 | **critic_info, **value_info, **actor_info 167 | } 168 | 169 | @jax.jit 170 | def sample_actions(agent, 171 | observations: np.ndarray, 172 | *, 173 | seed: PRNGKey, 174 | temperature: float = 1.0) -> jnp.ndarray: 175 | if type(observations) is dict: 176 | observations = jnp.concatenate([observations['observation'], observations['goal']], axis=-1) 177 | actions = agent.actor(observations, temperature=temperature).sample(seed=seed) 178 | if agent.config['action_transform_type'] == "cluster": 179 | actions = agent.config['action_transform'].ids_to_action(actions) 180 | if agent.config['actor_loss_type'] == 'ddpg': 181 | actions = jnp.tanh(actions) 182 | actions = jnp.clip(actions, -1, 1) 183 | return actions 184 | 185 | # Initializes all the networks, etc. for the agent. 186 | def create_agent( 187 | seed: int, 188 | observations: jnp.ndarray, 189 | actions: jnp.ndarray, 190 | actor_lr: float, 191 | value_lr: float, 192 | critic_lr: float, 193 | hidden_dims: Sequence[int], 194 | actor_hidden_dims: Sequence[int], 195 | discount: float, 196 | tau: float, 197 | expectile: float, 198 | temperature: float, 199 | use_tanh: bool, 200 | state_dependent_std: bool, 201 | use_layer_norm: bool, 202 | opt_decay_schedule: str, 203 | actor_loss_type: str, 204 | bc_weight: float, 205 | num_qs: int, 206 | action_transform_type: str, 207 | action_transform: ActionTransform = None, 208 | max_steps: Optional[int] = None, 209 | **kwargs): 210 | 211 | print('Extra kwargs:', kwargs) 212 | 213 | rng = jax.random.PRNGKey(seed) 214 | rng, actor_key, critic_key, value_key = jax.random.split(rng, 4) 215 | 216 | action_dim = actions.shape[-1] 217 | if action_transform_type == "cluster": 218 | num_clusters = action_transform.num_clusters 219 | actor_def = Policy(actor_hidden_dims, action_dim=num_clusters, is_discrete=True, mlp_kwargs=dict(use_layer_norm=use_layer_norm)) 220 | else: 221 | actor_def = Policy(actor_hidden_dims, action_dim=action_dim, 222 | log_std_min=-5.0, state_dependent_std=state_dependent_std, tanh_squash_distribution=use_tanh, mlp_kwargs=dict(use_layer_norm=use_layer_norm)) 223 | 224 | if opt_decay_schedule == "cosine": 225 | schedule_fn = optax.cosine_decay_schedule(-actor_lr, max_steps) 226 | actor_tx = optax.chain(optax.scale_by_adam(), 227 | optax.scale_by_schedule(schedule_fn)) 228 | else: 229 | actor_tx = optax.adam(learning_rate=actor_lr) 230 | 231 | actor_params = actor_def.init(actor_key, observations)['params'] 232 | actor = TrainState.create(actor_def, actor_params, tx=actor_tx) 233 | 234 | critic_def = ensemblize(Critic, num_qs=num_qs)(hidden_dims, 
mlp_kwargs=dict(use_layer_norm=use_layer_norm)) 235 | critic_params = critic_def.init(critic_key, observations, actions)['params'] 236 | critic = TrainState.create(critic_def, critic_params, tx=optax.adam(learning_rate=critic_lr)) 237 | target_critic = TrainState.create(critic_def, critic_params) 238 | 239 | value_def = ValueCritic(hidden_dims, mlp_kwargs=dict(use_layer_norm=use_layer_norm)) 240 | value_params = value_def.init(value_key, observations)['params'] 241 | value = TrainState.create(value_def, value_params, tx=optax.adam(learning_rate=value_lr)) 242 | 243 | config = flax.core.FrozenDict(dict( 244 | discount=discount, temperature=temperature, expectile=expectile, target_update_rate=tau, action_transform_type=action_transform_type, 245 | action_transform=action_transform, actor_loss_type=actor_loss_type, bc_weight=bc_weight 246 | )) 247 | 248 | return IQLAgent(rng, critic=critic, target_critic=target_critic, value=value, actor=actor, config=config) 249 | 250 | 251 | ############################### 252 | # Run Script. Loads data, logs to wandb, and runs the training loop. 253 | ############################### 254 | 255 | 256 | def main(_): 257 | if FLAGS.goal_conditioned: 258 | assert 'gc' in FLAGS.env_name 259 | else: 260 | assert 'gc' not in FLAGS.env_name 261 | 262 | np.random.seed(FLAGS.seed) 263 | 264 | # Create wandb logger 265 | setup_wandb(FLAGS.agent.to_dict(), **FLAGS.wandb) 266 | 267 | env = make_env(FLAGS.env_name) 268 | eval_env = make_env(FLAGS.env_name) 269 | 270 | dataset = get_dataset(env, FLAGS.env_name) 271 | if FLAGS.goal_conditioned: 272 | dataset = GCDataset(dataset, **FLAGS.gcdataset.to_dict()) 273 | example_batch = dataset.sample(1) 274 | example_obs = np.concatenate([example_batch['observations'], example_batch['goals']], axis=-1) 275 | debug_batch = dataset.sample(100) 276 | print("Masks Look Like", debug_batch['masks']) 277 | print("Rewards Look Like", debug_batch['rewards']) 278 | else: 279 | example_obs = dataset.sample(1)['observations'] 280 | example_batch = dataset.sample(1) 281 | 282 | if FLAGS.agent['action_transform_type'] == 'cluster': 283 | if FLAGS.goal_conditioned: 284 | action_transform = ActionDiscretizeCluster(1024, dataset.dataset["actions"][::100]) 285 | else: 286 | action_transform = ActionDiscretizeCluster(1024, dataset["actions"][::100]) 287 | else: 288 | action_transform = None 289 | 290 | agent = create_agent(FLAGS.seed, 291 | example_obs, 292 | example_batch['actions'], 293 | max_steps=FLAGS.max_steps, 294 | action_transform=action_transform, 295 | **FLAGS.agent) 296 | 297 | for i in tqdm.tqdm(range(1, FLAGS.max_steps + 1), 298 | smoothing=0.1, 299 | dynamic_ncols=True): 300 | 301 | batch = dataset.sample(FLAGS.batch_size) 302 | if FLAGS.goal_conditioned: 303 | batch['observations'] = np.concatenate([batch['observations'], batch['goals']], axis=-1) 304 | batch['next_observations'] = np.concatenate([batch['next_observations'], batch['goals']], axis=-1) 305 | 306 | agent, update_info = agent.update(batch) 307 | 308 | if i % FLAGS.log_interval == 0: 309 | train_metrics = {f'training/{k}': v for k, v in update_info.items()} 310 | wandb.log(train_metrics, step=i) 311 | 312 | if i % FLAGS.eval_interval == 0: 313 | record_video = i % FLAGS.video_interval == 0 314 | policy_fn = partial(supply_rng(agent.sample_actions), temperature=0.0) 315 | eval_info, trajs = evaluate(policy_fn, eval_env, num_episodes=FLAGS.eval_episodes, record_video=record_video, return_trajectories=True) 316 | eval_metrics = {} 317 | for k in ['episode.return', 
'episode.length']: 318 | eval_metrics[f'evaluation/{k}'] = eval_info[k] 319 | print(f'evaluation/{k}: {eval_info[k]}') 320 | try: 321 | eval_metrics['evaluation/episode.return.normalized'] = eval_env.get_normalized_score(eval_info['episode.return']) 322 | print(f'evaluation/episode.return.normalized: {eval_metrics["evaluation/episode.return.normalized"]}') 323 | except: 324 | pass 325 | if record_video: 326 | wandb.log({'video': eval_info['video']}, step=i) 327 | 328 | # Antmaze Specific Logging 329 | if 'antmaze-large' in FLAGS.env_name or 'maze2d-large' in FLAGS.env_name: 330 | import fre.common.envs.d4rl.d4rl_ant as d4rl_ant 331 | # Make an image of the trajectories. 332 | traj_image = d4rl_ant.trajectory_image(eval_env, trajs) 333 | eval_metrics['trajectories'] = wandb.Image(traj_image) 334 | 335 | # Make an image of the value function. 336 | if 'antmaze-large' in FLAGS.env_name or 'maze2d-large' in FLAGS.env_name: 337 | def get_gcvalue(state, goal): 338 | obgoal = jnp.concatenate([state, goal], axis=-1) 339 | return agent.value(obgoal) 340 | pred_value_img = d4rl_ant.value_image(eval_env, dataset, get_gcvalue) 341 | eval_metrics['v'] = wandb.Image(pred_value_img) 342 | 343 | # Maze2d Action Distribution 344 | if 'maze2d-large' in FLAGS.env_name: 345 | # Make a plot of the actions. 346 | traj_actions = np.concatenate([t['action'] for t in trajs], axis=0) # (T, A) 347 | import matplotlib.pyplot as plt 348 | plt.figure() 349 | plt.scatter(traj_actions[::100, 0], traj_actions[::100, 1], alpha=0.4) 350 | wandb.log({'actions_traj': wandb.Image(plt)}, step=i) 351 | 352 | data_actions = batch['actions'] 353 | import matplotlib.pyplot as plt 354 | plt.figure() 355 | plt.scatter(data_actions[:, 0], data_actions[:, 1], alpha=0.2) 356 | wandb.log({'actions_data': wandb.Image(plt)}, step=i) 357 | 358 | wandb.log(eval_metrics, step=i) 359 | 360 | if i % FLAGS.save_interval == 0 and FLAGS.save_dir is not None: 361 | checkpoints.save_checkpoint(FLAGS.save_dir, agent, i) 362 | 363 | if __name__ == '__main__': 364 | app.run(main) -------------------------------------------------------------------------------- /common/dataset.py: -------------------------------------------------------------------------------- 1 | ############################### 2 | # 3 | # Dataset Pytrees for offline data, replay buffers, etc. 4 | # 5 | ############################### 6 | 7 | import numpy as np 8 | from fre.common.typing import Data, Array 9 | from flax.core.frozen_dict import FrozenDict 10 | from jax import tree_util 11 | 12 | 13 | def get_size(data: Data) -> int: 14 | sizes = tree_util.tree_map(lambda arr: len(arr), data) 15 | return max(tree_util.tree_leaves(sizes)) 16 | 17 | 18 | class Dataset(FrozenDict): 19 | """ 20 | A class for storing (and retrieving batches of) data in nested dictionary format. 
21 | 22 | Example: 23 | dataset = Dataset({ 24 | 'observations': { 25 | 'image': np.random.randn(100, 28, 28, 1), 26 | 'state': np.random.randn(100, 4), 27 | }, 28 | 'actions': np.random.randn(100, 2), 29 | }) 30 | 31 | batch = dataset.sample(32) 32 | # Batch should have nested shape: { 33 | # 'observations': {'image': (32, 28, 28, 1), 'state': (32, 4)}, 34 | # 'actions': (32, 2) 35 | # } 36 | """ 37 | 38 | @classmethod 39 | def create( 40 | cls, 41 | observations: Data, 42 | actions: Array, 43 | rewards: Array, 44 | masks: Array, 45 | next_observations: Data, 46 | freeze=True, 47 | **extra_fields 48 | ): 49 | data = { 50 | "observations": observations, 51 | "actions": actions, 52 | "rewards": rewards, 53 | "masks": masks, 54 | "next_observations": next_observations, 55 | **extra_fields, 56 | } 57 | # Force freeze 58 | if freeze: 59 | tree_util.tree_map(lambda arr: arr.setflags(write=False), data) 60 | return cls(data) 61 | 62 | def __init__(self, *args, **kwargs): 63 | super().__init__(*args, **kwargs) 64 | self.size = get_size(self._dict) 65 | 66 | def sample(self, batch_size: int, indx=None): 67 | """ 68 | Sample a batch of data from the dataset. Use `indx` to specify a specific 69 | set of indices to retrieve. Otherwise, a random sample will be drawn. 70 | 71 | Returns a dictionary with the same structure as the original dataset. 72 | """ 73 | if indx is None: 74 | indx = np.random.randint(self.size, size=batch_size) 75 | return self.get_subset(indx) 76 | 77 | def get_subset(self, indx): 78 | return tree_util.tree_map(lambda arr: arr[indx], self._dict) 79 | 80 | 81 | class ReplayBuffer(Dataset): 82 | """ 83 | Dataset where data is added to the buffer. 84 | 85 | Example: 86 | example_transition = { 87 | 'observations': { 88 | 'image': np.random.randn(28, 28, 1), 89 | 'state': np.random.randn(4), 90 | }, 91 | 'actions': np.random.randn(2), 92 | } 93 | buffer = ReplayBuffer.create(example_transition, size=1000) 94 | buffer.add_transition(example_transition) 95 | batch = buffer.sample(32) 96 | 97 | """ 98 | 99 | @classmethod 100 | def create(cls, transition: Data, size: int): 101 | def create_buffer(example): 102 | example = np.array(example) 103 | return np.zeros((size, *example.shape), dtype=example.dtype) 104 | 105 | buffer_dict = tree_util.tree_map(create_buffer, transition) 106 | return cls(buffer_dict) 107 | 108 | @classmethod 109 | def create_from_initial_dataset(cls, init_dataset: dict, size: int): 110 | def create_buffer(init_buffer): 111 | buffer = np.zeros((size, *init_buffer.shape[1:]), dtype=init_buffer.dtype) 112 | buffer[: len(init_buffer)] = init_buffer 113 | return buffer 114 | 115 | buffer_dict = tree_util.tree_map(create_buffer, init_dataset) 116 | dataset = cls(buffer_dict) 117 | dataset.size = dataset.pointer = get_size(init_dataset) 118 | return dataset 119 | 120 | def __init__(self, *args, **kwargs): 121 | super().__init__(*args, **kwargs) 122 | 123 | self.max_size = get_size(self._dict) 124 | self.size = 0 125 | self.pointer = 0 126 | 127 | def add_transition(self, transition): 128 | def set_idx(buffer, new_element): 129 | buffer[self.pointer] = new_element 130 | 131 | tree_util.tree_map(set_idx, self._dict, transition) 132 | self.pointer = (self.pointer + 1) % self.max_size 133 | self.size = max(self.pointer, self.size) 134 | 135 | def clear(self): 136 | self.size = self.pointer = 0 137 | -------------------------------------------------------------------------------- /common/envs/bandit/bandit.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gym 3 | 4 | # Here's a super simple bandit environment that follows the OpenAI Gym API. 5 | # There is one continuous action. The observation is always zero. 6 | # A reward of 1 is given if the action is either 0.5 or -0.5. 7 | 8 | class BanditEnv(gym.Env): 9 | def __init__(self): 10 | self.action_space = gym.spaces.Box(low=-1, high=1, shape=(1,)) 11 | self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,)) 12 | self._state = None 13 | self.width = 0.15 14 | 15 | def reset(self): 16 | self._state = np.zeros(1) 17 | return self._state 18 | 19 | def step(self, action): 20 | reward = 0 21 | if (np.abs(action[0] - 0.5) < self.width) or (np.abs(action[0] + 0.5) < self.width): 22 | reward = 1 23 | self.last_action = action 24 | return self._state, reward, True, {} 25 | 26 | def render(self, mode='human'): 27 | # Render the last action on a line. Also indicate where the reward is. Return this as a numpy array. 28 | img = np.ones((20, 100, 3), dtype=np.uint8) * 255 29 | # Render reward zones in green. 0-100 means actions between -1 and 1. 30 | center_low = 25 31 | center_high = 75 32 | width_int = int(self.width * 50) 33 | img[:, center_low-width_int:center_low+width_int, :] = [0, 255, 0] 34 | img[:, center_high-width_int:center_high+width_int, :] = [0, 255, 0] 35 | # Render the last action in red. 36 | action = self.last_action[0] 37 | action = int((action + 1) * 50) 38 | img[:, action:action+1, :] = [255, 0, 0] 39 | return img 40 | 41 | 42 | 43 | 44 | def close(self): 45 | pass -------------------------------------------------------------------------------- /common/envs/d4rl/antmaze_actions.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kvfrans/fre/c9b5a1d8d88a69bbe25da0d88d39ef1a9c4cf39a/common/envs/d4rl/antmaze_actions.npy -------------------------------------------------------------------------------- /common/envs/d4rl/d4rl_ant.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | from matplotlib import patches 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas 9 | from functools import partial 10 | from mpl_toolkits.axes_grid1 import make_axes_locatable 11 | 12 | import os 13 | import os.path as osp 14 | 15 | import gym 16 | import d4rl 17 | import numpy as np 18 | import functools as ft 19 | import math 20 | import matplotlib.gridspec as gridspec 21 | 22 | from fre.common.envs.gc_utils import GCDataset 23 | 24 | class MazeWrapper(gym.Wrapper): 25 | def __init__(self, env_name): 26 | self.env = gym.make(env_name) 27 | self.env.render(mode='rgb_array', width=200, height=200) 28 | self.env_name = env_name 29 | self.inner_env = get_inner_env(self.env) 30 | if 'antmaze' in env_name: 31 | if 'medium' in env_name: 32 | self.env.viewer.cam.lookat[0] = 10 33 | self.env.viewer.cam.lookat[1] = 10 34 | self.env.viewer.cam.distance = 40 35 | self.env.viewer.cam.elevation = -90 36 | elif 'umaze' in env_name: 37 | self.env.viewer.cam.lookat[0] = 4 38 | self.env.viewer.cam.lookat[1] = 4 39 | self.env.viewer.cam.distance = 30 40 | self.env.viewer.cam.elevation = -90 41 | elif 'large' in env_name: 42 | self.env.viewer.cam.lookat[0] = 18 43 | self.env.viewer.cam.lookat[1] = 13 44 | 
self.env.viewer.cam.distance = 55 45 | self.env.viewer.cam.elevation = -90 46 | self.inner_env.goal_sampler = ft.partial(valid_goal_sampler, self.inner_env) 47 | elif 'maze2d' in env_name: 48 | if 'open' in env_name: 49 | pass 50 | elif 'large' in env_name: 51 | self.env.viewer.cam.lookat[0] = 5 52 | self.env.viewer.cam.lookat[1] = 6.5 53 | self.env.viewer.cam.distance = 15 54 | self.env.viewer.cam.elevation = -90 55 | self.env.viewer.cam.azimuth = 180 56 | self.draw_ant_maze = get_inner_env(gym.make('antmaze-large-diverse-v2')) 57 | self.action_space = self.env.action_space 58 | 59 | def render(self, *args, **kwargs): 60 | img = self.env.render(*args, **kwargs) 61 | if 'maze2d' in self.env_name: 62 | img = img[::-1] 63 | return img 64 | 65 | # ======== BELOW is helper stuff for drawing and visualizing ======== # 66 | 67 | def get_starting_boundary(self): 68 | if 'antmaze' in self.env_name: 69 | self = self.inner_env 70 | else: 71 | self = self.draw_ant_maze 72 | torso_x, torso_y = self._init_torso_x, self._init_torso_y 73 | S = self._maze_size_scaling 74 | return (0 - S / 2 + S - torso_x, 0 - S/2 + S - torso_y), (len(self._maze_map[0]) * S - torso_x - S/2 - S, len(self._maze_map) * S - torso_y - S/2 - S) 75 | 76 | def XY(self, n=20, m=30): 77 | bl, tr = self.get_starting_boundary() 78 | X = np.linspace(bl[0] + 0.04 * (tr[0] - bl[0]) , tr[0] - 0.04 * (tr[0] - bl[0]), m) 79 | Y = np.linspace(bl[1] + 0.04 * (tr[1] - bl[1]) , tr[1] - 0.04 * (tr[1] - bl[1]), n) 80 | 81 | X,Y = np.meshgrid(X,Y) 82 | states = np.array([X.flatten(), Y.flatten()]).T 83 | return states 84 | 85 | def four_goals(self): 86 | self = self.inner_env 87 | 88 | valid_cells = [] 89 | goal_cells = [] 90 | 91 | for i in range(len(self._maze_map)): 92 | for j in range(len(self._maze_map[0])): 93 | if self._maze_map[i][j] in [0, 'r', 'g']: 94 | valid_cells.append(self._rowcol_to_xy((i, j), add_random_noise=False)) 95 | 96 | goals = [] 97 | goals.append(max(valid_cells, key=lambda x: -x[0]-x[1])) 98 | goals.append(max(valid_cells, key=lambda x: x[0]-x[1])) 99 | goals.append(max(valid_cells, key=lambda x: x[0]+x[1])) 100 | goals.append(max(valid_cells, key=lambda x: -x[0] + x[1])) 101 | return goals 102 | 103 | def draw(self, ax=None, scale=1.0): 104 | if not ax: ax = plt.gca() 105 | if 'antmaze' in self.env_name: 106 | self = self.inner_env 107 | else: 108 | self = self.draw_ant_maze 109 | torso_x, torso_y = self._init_torso_x, self._init_torso_y 110 | S = self._maze_size_scaling 111 | if scale < 1.0: 112 | S *= 0.965 113 | torso_x -= 0.7 114 | torso_y -= 0.95 115 | for i in range(len(self._maze_map)): 116 | for j in range(len(self._maze_map[0])): 117 | struct = self._maze_map[i][j] 118 | if struct == 1: 119 | rect = patches.Rectangle((j *S - torso_x - S/ 2, 120 | i * S- torso_y - S/ 2), 121 | S, 122 | S, linewidth=1, edgecolor='none', facecolor='grey', alpha=1.0) 123 | 124 | ax.add_patch(rect) 125 | ax.set_xlim(0 - S /2 + 0.6 * S - torso_x, len(self._maze_map[0]) * S - torso_x - S/2 - S * 0.6) 126 | ax.set_ylim(0 - S/2 + 0.6 * S - torso_y, len(self._maze_map) * S - torso_y - S/2 - S * 0.6) 127 | ax.axis('off') 128 | 129 | class CenteredMaze(MazeWrapper): 130 | start_loc: str = "center" 131 | 132 | def __init__(self, env_name, start_loc="center"): 133 | super().__init__(env_name) 134 | self.start_loc = start_loc 135 | self.t = 0 136 | 137 | def step(self, action): 138 | next_obs, r, done, info = self.env.step(action) 139 | if 'antmaze' in self.env_name: 140 | info['x'], info['y'] = self.get_xy() 141 | self.t += 1 142 | done = 
self.t >= 2000 143 | return next_obs, r, done, info 144 | 145 | def reset(self, **kwargs): 146 | self.t = 0 147 | obs = self.env.reset(**kwargs) 148 | if 'maze2d' in self.env_name: 149 | if self.start_loc == 'center' or self.start_loc == 'center2': 150 | obs = self.env.reset_to_location([4, 5.8]) 151 | elif self.start_loc == 'original': 152 | obs = self.env.reset_to_location([0.9, 0.9]) 153 | else: 154 | raise NotImplementedError 155 | elif 'antmaze' in self.env_name: 156 | if self.start_loc == 'center' or self.start_loc == 'center2': 157 | self.env.set_xy([20, 15]) 158 | obs[:2] = [20, 15] 159 | elif self.start_loc == 'original': 160 | pass 161 | else: 162 | raise NotImplementedError 163 | return obs 164 | 165 | class GoalReachingMaze(MazeWrapper): 166 | def __init__(self, env_name): 167 | super().__init__(env_name) 168 | self.observation_space = gym.spaces.Dict({ 169 | 'observation': self.env.observation_space, 170 | 'goal': self.env.observation_space, 171 | }) 172 | 173 | def step(self, action): 174 | next_obs, r, done, info = self.env.step(action) 175 | 176 | if 'antmaze' in self.env_name: 177 | achieved = self.get_xy() 178 | desired = self.target_goal 179 | elif 'maze2d' in self.env_name: 180 | achieved = next_obs[:2] 181 | desired = self.env.get_target() 182 | distance = np.linalg.norm(achieved - desired) 183 | info['x'], info['y'] = achieved 184 | info['achieved_goal'] = np.array(achieved) 185 | info['desired_goal'] = np.copy(desired) 186 | info['success'] = float(distance < 0.5) 187 | done = 'TimeLimit.truncated' in info or info['success'] 188 | 189 | return self.get_obs(next_obs), r, done, info 190 | 191 | def get_obs(self, obs): 192 | if 'antmaze' in self.env_name: 193 | desired = self.target_goal 194 | elif 'maze2d' in self.env_name: 195 | desired = self.env.get_target() 196 | target_goal = obs.copy() 197 | target_goal[:2] = desired 198 | if 'antmaze' in self.env_name: 199 | obs = discretize_obs(obs) 200 | target_goal = discretize_obs(target_goal) 201 | return dict(observation=obs, goal=target_goal) 202 | 203 | def reset(self, **kwargs): 204 | obs = self.env.reset(**kwargs) 205 | if 'maze2d' in self.env_name: 206 | obs = self.env.reset_to_location([0.9, 0.9]) 207 | return self.get_obs(obs) 208 | 209 | def get_normalized_score(self, score): 210 | return score 211 | 212 | # =================================== 213 | # HELPER FUNCTIONS FOR OB DISCRETIZATION 214 | # =================================== 215 | 216 | def discretize_obs(ob, num_bins=32, disc_type='tanh', disc_temperature=1.0): 217 | min_ob = np.array([0, 0]) 218 | max_ob = np.array([35, 35]) 219 | disc_dims = 2 220 | bins = np.linspace(min_ob, max_ob, num_bins).T # [num_bins,] values from min_ob to max_ob 221 | bin_size = (max_ob - min_ob) / num_bins 222 | if disc_type == 'twohot': 223 | raise NotImplementedError 224 | elif disc_type == 'tanh': 225 | orig_ob = ob 226 | ob = np.expand_dims(ob, -1) 227 | # Convert each discretized dimension into num_bins dimensions. Value of each dimension is tanh of the distance from the bin center. 
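# Worked example (assuming AntMaze's 29-dimensional raw state): with num_bins=32 and
# disc_dims=2, the (x, y) position expands to 2 * 32 = 64 tanh features while the
# remaining 27 dimensions pass through unchanged, giving the 91-dim observation
# asserted in plot_value below.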
228 | bin_diff = ob[..., :disc_dims, :] - bins[:disc_dims] 229 | bin_diff_normalized = bin_diff / np.expand_dims(bin_size[:disc_dims], -1) * disc_temperature 230 | bin_tanh = np.tanh(bin_diff_normalized).reshape(*orig_ob.shape[:-1], -1) 231 | disc_ob = np.concatenate([bin_tanh, orig_ob[..., disc_dims:]], axis=-1) 232 | return disc_ob 233 | else: 234 | raise NotImplementedError 235 | 236 | # =================================== 237 | # HELPER FUNCTIONS FOR VISUALIZATION 238 | # =================================== 239 | 240 | def get_canvas_image(canvas): 241 | canvas.draw() 242 | out_image = np.frombuffer(canvas.tostring_rgb(), dtype='uint8') 243 | out_image = out_image.reshape(canvas.get_width_height()[::-1] + (3,)) 244 | return out_image 245 | 246 | def valid_goal_sampler(self, np_random): 247 | valid_cells = [] 248 | goal_cells = [] 249 | # print('Hello') 250 | 251 | for i in range(len(self._maze_map)): 252 | for j in range(len(self._maze_map[0])): 253 | if self._maze_map[i][j] in [0, 'r', 'g']: 254 | valid_cells.append((i, j)) 255 | 256 | # If there is a 'goal' designated, use that. Otherwise, any valid cell can 257 | # be a goal. 258 | sample_choices = valid_cells 259 | cell = sample_choices[np_random.choice(len(sample_choices))] 260 | xy = self._rowcol_to_xy(cell, add_random_noise=True) 261 | 262 | random_x = np.random.uniform(low=0, high=0.5) * 0.25 * self._maze_size_scaling 263 | random_y = np.random.uniform(low=0, high=0.5) * 0.25 * self._maze_size_scaling 264 | 265 | xy = (max(xy[0] + random_x, 0), max(xy[1] + random_y, 0)) 266 | 267 | return xy 268 | 269 | 270 | def get_inner_env(env): 271 | if hasattr(env, '_maze_size_scaling'): 272 | return env 273 | elif hasattr(env, 'env'): 274 | return get_inner_env(env.env) 275 | elif hasattr(env, 'wrapped_env'): 276 | return get_inner_env(env.wrapped_env) 277 | return env 278 | 279 | 280 | # =================================== 281 | # PLOT VALUE FUNCTION 282 | # =================================== 283 | 284 | def value_image(env, dataset, value_fn): 285 | """ 286 | Visualize the value function. 287 | Args: 288 | env: The environment. 289 | value_fn: a function with signature value_fn([# states, state_dim]) -> [#states, 1] 290 | Returns: 291 | A numpy array of the image. 292 | """ 293 | fig, axs = plt.subplots(2, 2, tight_layout=True) 294 | axs_flat = axs.flatten() 295 | canvas = FigureCanvas(fig) 296 | if type(dataset) is GCDataset: 297 | dataset = dataset.dataset 298 | if 'antmaze' in env.env_name: 299 | goals = env.four_goals() 300 | goal_states = dataset['observations'][0] 301 | goal_states = goal_states[-29:] # Remove discretized observations. 
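# This slice assumes the stored observation is the discretized AntMaze state (91-dim),
# so the last 29 entries are the raw state; its XY is overwritten with each goal below
# and then re-discretized.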
302 | goal_states = np.tile(goal_states, (len(goals), 1)) 303 | goal_states[:, :2] = goals 304 | goal_states = discretize_obs(goal_states) 305 | elif 'maze2d' in env.env_name: 306 | goals = np.array([[0.8, 0.8], [1, 9.7], [6.8, 9], [6.8, 1]]) 307 | goal_states = dataset['observations'][0] 308 | goal_states = np.tile(goal_states, (len(goals), 1)) 309 | goal_states[:, :2] = goals 310 | for i in range(4): 311 | plot_value(goal_states[i], env, dataset, value_fn, axs_flat[i]) 312 | image = get_canvas_image(canvas) 313 | plt.close(fig) 314 | return image 315 | 316 | def plot_value(goal_observation, env, dataset, value_fn, ax): 317 | N = 14 318 | M = 20 319 | ob_xy = env.XY(n=N, m=M) 320 | 321 | goal_observation = np.tile(goal_observation, (ob_xy.shape[0], 1)) # (N*M, 29) 322 | 323 | base_observation = np.copy(dataset['observations'][0]) 324 | xy_observations = np.tile(base_observation, (ob_xy.shape[0], 1)) # (N*M, 29) 325 | if 'antmaze' in env.env_name: 326 | xy_observations = xy_observations[:, -29:] # Remove discretized observations. 327 | xy_observations[:, :2] = ob_xy # Set to XY. 328 | xy_observations = discretize_obs(xy_observations) # Discretize again. 329 | assert xy_observations.shape[1] == 91 330 | elif 'maze2d' in env.env_name: 331 | ob_xy_scaled = ob_xy / 3.5 332 | ob_xy_scaled = ob_xy_scaled[:, [1, 0]] 333 | xy_observations[:, :2] = ob_xy_scaled 334 | assert xy_observations.shape[1] == 4 # (x, y, vx, vy) 335 | values = value_fn(xy_observations, goal_observation) # (N*M, 1) 336 | 337 | x, y = ob_xy[:, 0], ob_xy[:, 1] 338 | x = x.reshape(N, M) 339 | y = y.reshape(N, M) * 0.975 + 0.7 340 | values = values.reshape(N, M) 341 | mesh = ax.pcolormesh(x, y, values, cmap='viridis') 342 | 343 | env.draw(ax, scale=0.95) 344 | 345 | 346 | # =================================== 347 | # PLOT TRAJECTORIES 348 | # =================================== 349 | 350 | # Makes an image of the trajectory the Ant follows. 351 | def trajectory_image(env, trajectories, **kwargs): 352 | fig = plt.figure(tight_layout=True) 353 | canvas = FigureCanvas(fig) 354 | 355 | plot_trajectories(env, trajectories, fig, plt.gca(), **kwargs) 356 | 357 | plt.tight_layout() 358 | image = get_canvas_image(canvas) 359 | plt.close(fig) 360 | return image 361 | 362 | # Helper that plots the XY coordinates as scatter plots. 363 | def plot_trajectories(env, trajectories, fig, ax, color_list=None): 364 | if color_list is None: 365 | from itertools import cycle 366 | color_cycle = plt.rcParams['axes.prop_cycle'].by_key()['color'] 367 | color_list = cycle(color_cycle) 368 | 369 | for color, trajectory in zip(color_list, trajectories): 370 | obs = np.array(trajectory['observation']) 371 | 372 | # convert back to xy? 
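# For Ant environments, XY comes from the per-step `info` dict (the maze wrappers above
# store it there); for maze2d, the first two observation dimensions are mapped back to
# maze coordinates with a fixed affine rescaling.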
373 | if 'ant' in env.env_name: 374 | all_x = [] 375 | all_y = [] 376 | for info in trajectory['info']: 377 | all_x.append(info['x']) 378 | all_y.append(info['y']) 379 | all_x = np.array(all_x) 380 | all_y = np.array(all_y) 381 | elif 'maze2d' in env.env_name: 382 | all_x = obs[:, 1] * 4 - 3.2 383 | all_y = obs[:, 0] * 4 - 3.2 384 | ax.scatter(all_x, all_y, s=5, c=color, alpha=0.2) 385 | ax.scatter(all_x[-1], all_y[-1], s=50, c=color, marker='*', alpha=1, edgecolors='black') 386 | 387 | env.draw(ax) -------------------------------------------------------------------------------- /common/envs/d4rl/d4rl_utils.py: -------------------------------------------------------------------------------- 1 | import d4rl 2 | import d4rl.gym_mujoco 3 | import gym 4 | import numpy as np 5 | from jax import tree_util 6 | 7 | 8 | import fre.common.envs.d4rl.d4rl_ant as d4rl_ant 9 | from fre.common.dataset import Dataset 10 | 11 | 12 | # Note on AntMaze. Reward = 1 at the goal, and Terminal = 1 at the goal. 13 | # Masks = Does the episode end due to final state? 14 | # Dones_float = Does the episode end due to time limit? OR does the episode end due to final state? 15 | def get_dataset(env: gym.Env, env_name: str, clip_to_eps: bool = True, 16 | eps: float = 1e-5, dataset=None, filter_terminals=False, obs_dtype=np.float32): 17 | if dataset is None: 18 | dataset = d4rl.qlearning_dataset(env) 19 | 20 | if clip_to_eps: 21 | lim = 1 - eps 22 | dataset['actions'] = np.clip(dataset['actions'], -lim, lim) 23 | 24 | # Mask everything that is marked as a terminal state. 25 | # For AntMaze, this should mask the end of each trajectory. 26 | masks = 1.0 - dataset['terminals'] 27 | 28 | # In the AntMaze data, terminal is 1 when at the goal. But the episode doesn't end. 29 | # This just ensures that we treat AntMaze trajectories as non-ending. 30 | if "antmaze" in env_name or "maze2d" in env_name: 31 | dataset['terminals'] = np.zeros_like(dataset['terminals']) 32 | 33 | # if 'antmaze' in env_name: 34 | # print("Discretizing AntMaze observations.") 35 | # print("Raw observations looks like", dataset['observations'].shape[1:]) 36 | # dataset['observations'] = d4rl_ant.discretize_obs(dataset['observations']) 37 | # dataset['next_observations'] = d4rl_ant.discretize_obs(dataset['next_observations']) 38 | # print("Discretized observations looks like", dataset['observations'].shape[1:]) 39 | 40 | # Compute dones if terminal OR orbservation jumps. 
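# An observation "jump" means next_observations[i] does not match observations[i + 1],
# which marks a trajectory boundary in the concatenated D4RL buffer even when no
# terminal flag is set (e.g. a timeout-truncated episode).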
41 | dones_float = np.zeros_like(dataset['rewards']) 42 | 43 | imputed_next_observations = np.roll(dataset['observations'], -1, axis=0) 44 | same_obs = np.all(np.isclose(imputed_next_observations, dataset['next_observations'], atol=1e-5), axis=-1) 45 | dones_float = 1.0 - same_obs.astype(np.float32) 46 | dones_float += dataset['terminals'] 47 | dones_float[-1] = 1.0 48 | dones_float = np.clip(dones_float, 0.0, 1.0) 49 | 50 | observations = dataset['observations'].astype(obs_dtype) 51 | next_observations = dataset['next_observations'].astype(obs_dtype) 52 | 53 | return Dataset.create( 54 | observations=observations, 55 | actions=dataset['actions'].astype(np.float32), 56 | rewards=dataset['rewards'].astype(np.float32), 57 | masks=masks.astype(np.float32), 58 | dones_float=dones_float.astype(np.float32), 59 | next_observations=next_observations, 60 | ) 61 | 62 | def get_normalization(dataset): 63 | returns = [] 64 | ret = 0 65 | for r, term in zip(dataset['rewards'], dataset['dones_float']): 66 | ret += r 67 | if term: 68 | returns.append(ret) 69 | ret = 0 70 | return (max(returns) - min(returns)) / 1000 71 | 72 | def normalize_dataset(env_name, dataset): 73 | print("Normalizing", env_name) 74 | if 'antmaze' in env_name or 'maze2d' in env_name: 75 | return dataset.copy({'rewards': dataset['rewards']- 1.0}) 76 | else: 77 | normalizing_factor = get_normalization(dataset) 78 | print(f'Normalizing factor: {normalizing_factor}') 79 | dataset = dataset.copy({'rewards': dataset['rewards'] / normalizing_factor}) 80 | return dataset 81 | 82 | # Flattens environment with a dictionary of observation,goal to a single concatenated observation. 83 | class GoalReachingFlat(gym.Wrapper): 84 | """A wrapper that maps actions from [-1,1] to [low, hgih].""" 85 | def __init__(self, env): 86 | super().__init__(env) 87 | self.observation_space = gym.spaces.Box( 88 | low=-np.inf, high=np.inf, shape=(self.observation_space['observation'].shape[0] + self.observation_space['goal'].shape[0],), dtype=np.float32) 89 | 90 | def step(self, action): 91 | ob, reward, done, info = self.env.step(action) 92 | ob_flat = np.concatenate([ob['observation'], ob['goal']]) 93 | return ob_flat, reward, done, info 94 | 95 | def reset(self, **kwargs): 96 | ob = self.env.reset(**kwargs) 97 | ob_flat = np.concatenate([ob['observation'], ob['goal']]) 98 | return ob_flat 99 | 100 | def parse_trajectories(dataset): 101 | trajectory_ids = np.where(dataset['dones_float'] == 1)[0] + 1 102 | trajectory_ids = np.concatenate([[0], trajectory_ids]) 103 | num_trajectories = trajectory_ids.shape[0] - 1 104 | print("There are {} trajectories. 
Some traj lens are {}".format(num_trajectories, [trajectory_ids[i + 1] - trajectory_ids[i] for i in range(min(5, num_trajectories))])) 105 | trajectories = [] 106 | for i in range(len(trajectory_ids) - 1): 107 | trajectories.append(tree_util.tree_map(lambda arr: arr[trajectory_ids[i]:trajectory_ids[i + 1]], dataset._dict)) 108 | return trajectories 109 | 110 | class KitchenRenderWrapper(gym.Wrapper): 111 | def render(self, *args, **kwargs): 112 | from dm_control.mujoco import engine 113 | camera = engine.MovableCamera(self.sim, 1920, 2560) 114 | camera.set_pose(distance=2.2, lookat=[-0.2, .5, 2.], azimuth=70, elevation=-35) 115 | img = camera.render() 116 | return img 117 | -------------------------------------------------------------------------------- /common/envs/data_transforms.py: -------------------------------------------------------------------------------- 1 | ############################### 2 | # 3 | # Helpful utilities for processing actions, observations. 4 | # 5 | ############################### 6 | 7 | import numpy as np 8 | import jax.numpy as jnp 9 | 10 | class ActionTransform(): 11 | pass 12 | 13 | class ActionDiscretizeBins(ActionTransform): 14 | def __init__(self, bins_per_dim, action_dim): 15 | self.bins_per_dim = bins_per_dim 16 | self.action_dim = action_dim 17 | self.bins = np.linspace(-1, 1, bins_per_dim + 1) 18 | 19 | # Assumes action is in [-1, 1]. 20 | def action_to_ids(self, action): 21 | ids = np.digitize(action, self.bins) - 1 22 | ids = np.clip(ids, 0, self.bins_per_dim - 1) 23 | return ids 24 | 25 | def ids_to_action(self, ids): 26 | action = (self.bins[ids] + self.bins[ids + 1]) / 2 27 | return action 28 | 29 | class ActionDiscretizeCluster(ActionTransform): 30 | def __init__(self, num_clusters, data_actions): 31 | self.num_clusters = num_clusters 32 | assert len(data_actions.shape) == 2 # (data_size, action_dim) 33 | print("Clustering actions of shape", data_actions.shape) 34 | 35 | # Cluster the data. 36 | from sklearn.cluster import KMeans 37 | kmeans = KMeans(n_clusters=num_clusters, random_state=0).fit(data_actions) 38 | self.centers = kmeans.cluster_centers_ 39 | # self.labels = kmeans.labels_ 40 | self.centers = jnp.array(self.centers) 41 | print("Average cluster error is", kmeans.inertia_ / len(data_actions)) 42 | print("Average cluster error per dimension is", (kmeans.inertia_ / len(data_actions)) / data_actions.shape[1]) 43 | # print(self.centers.shape) 44 | 45 | def action_to_ids(self, action): 46 | if len(action.shape) == 1: 47 | action = action[None] 48 | assert len(action.shape) == 2 # (batch, action_dim,) 49 | # Find the closest cluster center. 
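# Broadcasting: centers (1, num_clusters, action_dim) against action (batch, 1, action_dim)
# gives pairwise distances of shape (batch, num_clusters); argmin then picks the nearest
# k-means center id for each action.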
50 | dists = jnp.linalg.norm(self.centers[None] - action[:, None], axis=-1) 51 | ids = jnp.argmin(dists, axis=-1) 52 | return ids 53 | 54 | def ids_to_action(self, ids): 55 | action = self.centers[ids] 56 | return action 57 | 58 | # Test 59 | # action_discretize_bins = ActionDiscretizeBins(32, 2) 60 | # action = np.array([-1, -0.999, -0.5, 0, 0.5, 0.999, 1]) 61 | # ids = action_discretize_bins.action_to_ids(action) 62 | # print(ids) 63 | # action_recreate = action_discretize_bins.ids_to_action(ids) 64 | # print(action_recreate) 65 | # assert np.abs(action - action_recreate).max() < 0.1 66 | 67 | # action_discretize_cluster = ActionDiscretizeCluster(32, np.random.uniform(low=-1, high=1, size=(10000, 1))) 68 | # action = np.array([-1, -0.999, -0.5, 0, 0.5, 0.999, 1])[:, None] # [7, 1] 69 | # ids = action_discretize_cluster.action_to_ids(action) 70 | # print(ids) 71 | # action_recreate = action_discretize_cluster.ids_to_action(ids) 72 | # print(action_recreate) 73 | # assert np.abs(action - action_recreate).max() < 0.1 -------------------------------------------------------------------------------- /common/envs/dmc/__init__.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym.envs.registration import register 3 | 4 | 5 | def make( 6 | domain_name, 7 | task_name, 8 | seed=1, 9 | visualize_reward=True, 10 | from_pixels=False, 11 | height=84, 12 | width=84, 13 | camera_id=0, 14 | frame_skip=1, 15 | episode_length=1000, 16 | environment_kwargs=None, 17 | time_limit=None, 18 | channels_first=True 19 | ): 20 | env_id = 'dmc_%s_%s_%s-v1' % (domain_name, task_name, seed) 21 | 22 | if from_pixels: 23 | assert not visualize_reward, 'cannot use visualize reward when learning from pixels' 24 | 25 | # shorten episode length 26 | max_episode_steps = (episode_length + frame_skip - 1) // frame_skip 27 | 28 | if not env_id in gym.envs.registry.env_specs: 29 | task_kwargs = {} 30 | if seed is not None: 31 | task_kwargs['random'] = seed 32 | if time_limit is not None: 33 | task_kwargs['time_limit'] = time_limit 34 | register( 35 | id=env_id, 36 | # entry_point='dmc2gym.wrappers:DMCWrapper', 37 | entry_point='fre.common.envs.dmc.wrappers:DMCWrapper', 38 | kwargs=dict( 39 | domain_name=domain_name, 40 | task_name=task_name, 41 | task_kwargs=task_kwargs, 42 | environment_kwargs=environment_kwargs, 43 | visualize_reward=visualize_reward, 44 | from_pixels=from_pixels, 45 | height=height, 46 | width=width, 47 | camera_id=camera_id, 48 | frame_skip=frame_skip, 49 | channels_first=channels_first, 50 | ), 51 | max_episode_steps=max_episode_steps, 52 | ) 53 | return gym.make(env_id) 54 | -------------------------------------------------------------------------------- /common/envs/dmc/jaco.py: -------------------------------------------------------------------------------- 1 | """A task where the goal is to move the hand close to a target prop or site.""" 2 | 3 | import collections 4 | 5 | from dm_control import composer 6 | from dm_control.composer import initializers 7 | from dm_control.composer.variation import distributions 8 | from dm_control.entities import props 9 | from dm_control.manipulation.shared import arenas 10 | from dm_control.manipulation.shared import cameras 11 | from dm_control.manipulation.shared import constants 12 | from dm_control.manipulation.shared import observations 13 | from dm_control.manipulation.shared import robots 14 | from dm_control.manipulation.shared import workspaces 15 | from dm_control.utils import rewards 16 | from dm_env 
import specs 17 | import numpy as np 18 | 19 | _ReachWorkspace = collections.namedtuple( 20 | '_ReachWorkspace', ['target_bbox', 'tcp_bbox', 'arm_offset']) 21 | 22 | # Ensures that the props are not touching the table before settling. 23 | _PROP_Z_OFFSET = 0.001 24 | 25 | _DUPLO_WORKSPACE = _ReachWorkspace( 26 | target_bbox=workspaces.BoundingBox(lower=(-0.1, -0.1, _PROP_Z_OFFSET), 27 | upper=(0.1, 0.1, _PROP_Z_OFFSET)), 28 | tcp_bbox=workspaces.BoundingBox(lower=(-0.1, -0.1, 0.2), 29 | upper=(0.1, 0.1, 0.4)), 30 | arm_offset=robots.ARM_OFFSET) 31 | 32 | _SITE_WORKSPACE = _ReachWorkspace( 33 | target_bbox=workspaces.BoundingBox(lower=(-0.2, -0.2, 0.02), 34 | upper=(0.2, 0.2, 0.4)), 35 | tcp_bbox=workspaces.BoundingBox(lower=(-0.2, -0.2, 0.02), 36 | upper=(0.2, 0.2, 0.4)), 37 | arm_offset=robots.ARM_OFFSET) 38 | 39 | _TARGET_RADIUS = 0.05 40 | _TIME_LIMIT = 10. 41 | 42 | TASKS = [('reach_top_left', np.array([-0.09, 0.09, _PROP_Z_OFFSET])), 43 | ('reach_top_right', np.array([0.09, 0.09, _PROP_Z_OFFSET])), 44 | ('reach_bottom_left', np.array([-0.09, -0.09, _PROP_Z_OFFSET])), 45 | ('reach_bottom_right', np.array([0.09, -0.09, _PROP_Z_OFFSET]))] 46 | 47 | 48 | def make(task_id, obs_type, seed): 49 | obs_settings = observations.VISION if obs_type == 'pixels' else observations.PERFECT_FEATURES 50 | task = _reach(task_id, obs_settings=obs_settings, use_site=True) 51 | return composer.Environment(task, 52 | time_limit=_TIME_LIMIT, 53 | random_state=seed) 54 | 55 | 56 | class MultiTaskReach(composer.Task): 57 | """Bring the hand close to a target prop or site.""" 58 | 59 | def __init__(self, task_id, arena, arm, hand, prop, obs_settings, 60 | workspace, control_timestep): 61 | """Initializes a new `Reach` task. 62 | Args: 63 | arena: `composer.Entity` instance. 64 | arm: `robot_base.RobotArm` instance. 65 | hand: `robot_base.RobotHand` instance. 66 | prop: `composer.Entity` instance specifying the prop to reach to, or None 67 | in which case the target is a fixed site whose position is specified by 68 | the workspace. 69 | obs_settings: `observations.ObservationSettings` instance. 70 | workspace: `_ReachWorkspace` specifying the placement of the prop and TCP. 71 | control_timestep: Float specifying the control timestep in seconds. 72 | """ 73 | self._arena = arena 74 | self._arm = arm 75 | self._hand = hand 76 | self._arm.attach(self._hand) 77 | self._arena.attach_offset(self._arm, offset=workspace.arm_offset) 78 | self.control_timestep = control_timestep 79 | self._tcp_initializer = initializers.ToolCenterPointInitializer( 80 | self._hand, 81 | self._arm, 82 | position=distributions.Uniform(*workspace.tcp_bbox), 83 | quaternion=workspaces.DOWN_QUATERNION) 84 | 85 | # Add custom camera observable. 86 | self._task_observables = cameras.add_camera_observables( 87 | arena, obs_settings, cameras.FRONT_CLOSE) 88 | 89 | if task_id == 'reach_multitask': 90 | self._targets = [target for (_, target) in TASKS] 91 | else: 92 | self._targets = [ 93 | target for (task, target) in TASKS if task == task_id 94 | ] 95 | assert len(self._targets) > 0 96 | 97 | #target_pos_distribution = distributions.Uniform(*TASKS[task_id]) 98 | self._prop = prop 99 | if prop: 100 | # The prop itself is used to visualize the target location. 
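# NOTE: `target_pos_distribution` only exists in the commented-out line above, so this prop
# branch is effectively unused here; make() always builds the task with use_site=True (prop=None).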
101 | self._make_target_site(parent_entity=prop, visible=False) 102 | self._target = self._arena.add_free_entity(prop) 103 | self._prop_placer = initializers.PropPlacer( 104 | props=[prop], 105 | position=target_pos_distribution, 106 | quaternion=workspaces.uniform_z_rotation, 107 | settle_physics=True) 108 | else: 109 | if len(self._targets) == 1: 110 | self._target = self._make_target_site(parent_entity=arena, 111 | visible=True) 112 | 113 | #obs = observable.MJCFFeature('pos', self._target) 114 | # obs.configure(**obs_settings.prop_pose._asdict()) 115 | #self._task_observables['target_position'] = obs 116 | 117 | # Add sites for visualizing the prop and target bounding boxes. 118 | workspaces.add_bbox_site(body=self.root_entity.mjcf_model.worldbody, 119 | lower=workspace.tcp_bbox.lower, 120 | upper=workspace.tcp_bbox.upper, 121 | rgba=constants.GREEN, 122 | name='tcp_spawn_area') 123 | workspaces.add_bbox_site(body=self.root_entity.mjcf_model.worldbody, 124 | lower=workspace.target_bbox.lower, 125 | upper=workspace.target_bbox.upper, 126 | rgba=constants.BLUE, 127 | name='target_spawn_area') 128 | 129 | def _make_target_site(self, parent_entity, visible): 130 | return workspaces.add_target_site( 131 | body=parent_entity.mjcf_model.worldbody, 132 | radius=_TARGET_RADIUS, 133 | visible=visible, 134 | rgba=constants.RED, 135 | name='target_site') 136 | 137 | @property 138 | def root_entity(self): 139 | return self._arena 140 | 141 | @property 142 | def arm(self): 143 | return self._arm 144 | 145 | @property 146 | def hand(self): 147 | return self._hand 148 | 149 | def get_reward_spec(self): 150 | n = len(self._targets) 151 | return specs.Array(shape=(n,), dtype=np.float32, name='reward') 152 | 153 | @property 154 | def task_observables(self): 155 | return self._task_observables 156 | 157 | def get_reward(self, physics): 158 | hand_pos = physics.bind(self._hand.tool_center_point).xpos 159 | rews = [] 160 | for target_pos in self._targets: 161 | distance = np.linalg.norm(hand_pos - target_pos) 162 | reward = rewards.tolerance(distance, 163 | bounds=(0, _TARGET_RADIUS), 164 | margin=_TARGET_RADIUS) 165 | rews.append(reward) 166 | rews = np.array(rews).astype(np.float32) 167 | if len(self._targets) == 1: 168 | return rews[0] 169 | return rews 170 | 171 | def initialize_episode(self, physics, random_state): 172 | self._hand.set_grasp(physics, close_factors=random_state.uniform()) 173 | self._tcp_initializer(physics, random_state) 174 | if self._prop: 175 | self._prop_placer(physics, random_state) 176 | else: 177 | if len(self._targets) == 1: 178 | physics.bind(self._target).pos = self._targets[0] 179 | 180 | 181 | def _reach(task_id, obs_settings, use_site): 182 | """Configure and instantiate a `Reach` task. 183 | Args: 184 | obs_settings: An `observations.ObservationSettings` instance. 185 | use_site: Boolean, if True then the target will be a fixed site, otherwise 186 | it will be a moveable Duplo brick. 187 | Returns: 188 | An instance of `reach.Reach`. 
189 | """ 190 | arena = arenas.Standard() 191 | arm = robots.make_arm(obs_settings=obs_settings) 192 | hand = robots.make_hand(obs_settings=obs_settings) 193 | if use_site: 194 | workspace = _SITE_WORKSPACE 195 | prop = None 196 | else: 197 | workspace = _DUPLO_WORKSPACE 198 | prop = props.Duplo(observable_options=observations.make_options( 199 | obs_settings, observations.FREEPROP_OBSERVABLES)) 200 | task = MultiTaskReach(task_id, 201 | arena=arena, 202 | arm=arm, 203 | hand=hand, 204 | prop=prop, 205 | obs_settings=obs_settings, 206 | workspace=workspace, 207 | control_timestep=constants.CONTROL_TIMESTEP) 208 | return task -------------------------------------------------------------------------------- /common/envs/dmc/wrappers.py: -------------------------------------------------------------------------------- 1 | from gym import core, spaces 2 | from dm_control import suite 3 | from dm_env import specs 4 | import numpy as np 5 | 6 | 7 | def _spec_to_box(spec, dtype): 8 | def extract_min_max(s): 9 | assert s.dtype == np.float64 or s.dtype == np.float32 10 | dim = int(np.prod(s.shape)) 11 | if type(s) == specs.Array: 12 | bound = np.inf * np.ones(dim, dtype=np.float32) 13 | return -bound, bound 14 | elif type(s) == specs.BoundedArray: 15 | zeros = np.zeros(dim, dtype=np.float32) 16 | return s.minimum + zeros, s.maximum + zeros 17 | 18 | mins, maxs = [], [] 19 | for s in spec: 20 | mn, mx = extract_min_max(s) 21 | mins.append(mn) 22 | maxs.append(mx) 23 | low = np.concatenate(mins, axis=0).astype(dtype) 24 | high = np.concatenate(maxs, axis=0).astype(dtype) 25 | assert low.shape == high.shape 26 | return spaces.Box(low, high, dtype=dtype) 27 | 28 | 29 | def _flatten_obs(obs): 30 | obs_pieces = [] 31 | for v in obs.values(): 32 | flat = np.array([v]) if np.isscalar(v) else v.ravel() 33 | obs_pieces.append(flat) 34 | return np.concatenate(obs_pieces, axis=0) 35 | 36 | 37 | class DMCWrapper(core.Env): 38 | def __init__( 39 | self, 40 | domain_name, 41 | task_name, 42 | task_kwargs=None, 43 | visualize_reward={}, 44 | from_pixels=False, 45 | height=84, 46 | width=84, 47 | camera_id=0, 48 | frame_skip=1, 49 | environment_kwargs=None, 50 | channels_first=True 51 | ): 52 | assert 'random' in task_kwargs, 'please specify a seed, for deterministic behaviour' 53 | self._from_pixels = from_pixels 54 | self._height = height 55 | self._width = width 56 | self._camera_id = camera_id 57 | self._frame_skip = frame_skip 58 | self._channels_first = channels_first 59 | 60 | # create task 61 | if domain_name == 'jaco': 62 | import fre.common.envs.dmc.jaco as jaco 63 | self._env = jaco.make(task_id=task_name, obs_type=jaco.observations.PERFECT_FEATURES, seed=1) 64 | else: 65 | self._env = suite.load( 66 | domain_name=domain_name, 67 | task_name=task_name, 68 | task_kwargs=task_kwargs, 69 | visualize_reward=visualize_reward, 70 | environment_kwargs=environment_kwargs 71 | ) 72 | 73 | # true and normalized action spaces 74 | self._true_action_space = _spec_to_box([self._env.action_spec()], np.float32) 75 | self._norm_action_space = spaces.Box( 76 | low=-1.0, 77 | high=1.0, 78 | shape=self._true_action_space.shape, 79 | dtype=np.float32 80 | ) 81 | 82 | # create observation space 83 | if from_pixels: 84 | shape = [3, height, width] if channels_first else [height, width, 3] 85 | self._observation_space = spaces.Box( 86 | low=0, high=255, shape=shape, dtype=np.uint8 87 | ) 88 | else: 89 | self._observation_space = _spec_to_box( 90 | self._env.observation_spec().values(), 91 | np.float64 92 | ) 93 | 94 | 
self._state_space = _spec_to_box( 95 | self._env.observation_spec().values(), 96 | np.float64 97 | ) 98 | 99 | self.current_state = None 100 | 101 | # set seed 102 | self.seed(seed=task_kwargs.get('random', 1)) 103 | 104 | def __getattr__(self, name): 105 | return getattr(self._env, name) 106 | 107 | def _get_obs(self, time_step): 108 | if self._from_pixels: 109 | obs = self.render( 110 | height=self._height, 111 | width=self._width, 112 | camera_id=self._camera_id 113 | ) 114 | if self._channels_first: 115 | obs = obs.transpose(2, 0, 1).copy() 116 | else: 117 | obs = _flatten_obs(time_step.observation) 118 | 119 | return obs 120 | 121 | def _convert_action(self, action): 122 | action = action.astype(np.float32) 123 | true_delta = self._true_action_space.high - self._true_action_space.low 124 | norm_delta = self._norm_action_space.high - self._norm_action_space.low 125 | action = (action - self._norm_action_space.low) / norm_delta 126 | action = action * true_delta + self._true_action_space.low 127 | action = action.astype(np.float32) 128 | return action 129 | 130 | @property 131 | def observation_space(self): 132 | return self._observation_space 133 | 134 | @property 135 | def state_space(self): 136 | return self._state_space 137 | 138 | @property 139 | def action_space(self): 140 | return self._norm_action_space 141 | 142 | @property 143 | def reward_range(self): 144 | return 0, self._frame_skip 145 | 146 | def seed(self, seed): 147 | self._true_action_space.seed(seed) 148 | self._norm_action_space.seed(seed) 149 | self._observation_space.seed(seed) 150 | 151 | def step(self, action): 152 | assert self._norm_action_space.contains(action) 153 | action = self._convert_action(action) 154 | assert self._true_action_space.contains(action) 155 | reward = 0 156 | extra = {'internal_state': self._env.physics.get_state().copy()} 157 | 158 | for _ in range(self._frame_skip): 159 | time_step = self._env.step(action) 160 | reward += time_step.reward or 0 161 | done = time_step.last() 162 | if done: 163 | break 164 | obs = self._get_obs(time_step) 165 | self.current_state = _flatten_obs(time_step.observation) 166 | extra['discount'] = time_step.discount 167 | return obs, reward, done, extra 168 | 169 | def reset(self): 170 | time_step = self._env.reset() 171 | self.current_state = _flatten_obs(time_step.observation) 172 | obs = self._get_obs(time_step) 173 | return obs 174 | 175 | def render(self, mode='rgb_array', height=None, width=None, camera_id=0): 176 | assert mode == 'rgb_array', 'only support rgb_array mode, given %s' % mode 177 | height = height or self._height 178 | width = width or self._width 179 | camera_id = camera_id or self._camera_id 180 | return self._env.physics.render( 181 | height=height, width=width, camera_id=camera_id 182 | ) 183 | -------------------------------------------------------------------------------- /common/envs/env_helper.py: -------------------------------------------------------------------------------- 1 | ############################### 2 | # 3 | # Helper that initializes environments with the proper imports. 4 | # Returns an environment that is: 5 | # - Action normalized. 6 | # - Video rendering works. 7 | # - Episode monitor attached. 
8 | # 9 | ############################### 10 | 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | import os 14 | import os.path as osp 15 | import gym 16 | import numpy as np 17 | import functools as ft 18 | from fre.common.train_state import NormalizeActionWrapper 19 | from fre.common.envs.wrappers import EpisodeMonitor 20 | 21 | 22 | # Supported envs: 23 | env_list = [ 24 | # From Gym 25 | 'HalfCheetah-v2', 26 | 'Hopper-v2', 27 | 'Walker2d-v2', 28 | 'Pendulum-v1', 29 | 'CartPole-v1', 30 | 'Acrobot-v1', 31 | 'MountainCar-v0', 32 | 'MountainCarContinuous-v0', 33 | # From DMC 34 | 'pendulum_swingup', 35 | 'acrobot_swingup', 36 | 'acrobot_swingup_sparse', 37 | 'cartpole_swingup', # has exorl dataset. 38 | 'cartpole_swingup_sparse', 39 | 'pointmass_easy', 40 | 'reacher_easy', 41 | 'reacher_hard', 42 | 'cheetah_run', # has exorl dataset. 43 | 'hopper_hop', 44 | 'walker_stand', # has exorl dataset. 45 | 'walker_walk', # has exorl dataset. 46 | 'walker_run', # has exorl dataset. 47 | 'quadruped_walk', # has exorl dataset. 48 | 'quadruped_run', # has exorl dataset. 49 | 'humanoid_stand', 50 | 'humanoid_run', 51 | 'jaco_reach_top_left', # has exorl dataset. 52 | 'jaco_reach_bottom_right', # has exorl dataset. 53 | # TODO: Atari games 54 | # Offline D4RL envs 55 | 'antmaze-large-diverse-v2', # Start in the corner, goal is in the top corner. 56 | 'gc-antmaze-large-diverse-v2', # Start in the corner, goal is in the top corner. 57 | 'center-antmaze-large-diverse-v2', # Start in the center, goal is UNDEFINED (this is for RARE rewards). 58 | 'maze2d-large-v1', 59 | 'gc-maze2d-large-v1', 60 | 'center-maze2d-large-v1', 61 | # D4RL mujoco 62 | 'halfcheetah-expert-v2', 63 | 'walker2d-expert-v2', 64 | 'hopper-expert-v2', 65 | 'kitchen-complete-v0', # broken 66 | 'kitchen-mixed-v0' # broken 67 | ] 68 | 69 | # Making an environment.
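# Illustrative usage (a sketch; env names should come from env_list above):
#   env = make_env('walker_run')                  # DMC task via the dmc2gym path, action-normalized
#   env = make_env('antmaze-large-diverse-v2')    # offline D4RL AntMaze
#   ob = env.reset()
#   ob, reward, done, info = env.step(env.action_space.sample())
# make_env attaches EpisodeMonitor, so info['episode'] reports return/length when an episode ends.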
70 | def make_env(env_name, **kwargs): 71 | if 'exorl' in env_name: 72 | import os 73 | os.environ['DISPLAY'] = ':0' 74 | import fre.common.envs.exorl.dmc as dmc 75 | _, env_name, task_name = env_name.split('_', 2) 76 | def make_env(env_name, task_name): 77 | env = dmc.make(f'{env_name}_{task_name}', obs_type='states', frame_stack=1, action_repeat=1, seed=0) 78 | env = dmc.DMCWrapper(env, 0) 79 | return env 80 | env = make_env(env_name, task_name) 81 | env.reset() 82 | elif '_' in env_name: # DMC Control 83 | import fre.common.envs.dmc as dmc2gym 84 | import os 85 | os.environ['DISPLAY'] = ':0' 86 | suite, task = env_name.split('_', 1) 87 | print(suite, task) 88 | if suite == 'pointmass': 89 | suite = 'point_mass' 90 | frame_skip = kwargs['frame_skip'] if 'frame_skip' in kwargs else 2 91 | visualize_reward = kwargs['visualize_reward'] if 'visualize_reward' in kwargs else False 92 | env = dmc2gym.make( 93 | domain_name=suite, 94 | task_name=task, seed=1, 95 | frame_skip=frame_skip, 96 | visualize_reward=visualize_reward) 97 | env = NormalizeActionWrapper(env) 98 | elif 'antmaze' in env_name: 99 | from fre.common.envs.d4rl.d4rl_ant import CenteredMaze, GoalReachingMaze, MazeWrapper 100 | if 'gc-antmaze' in env_name: 101 | env = GoalReachingMaze('antmaze-large-diverse-v2') 102 | elif 'center-antmaze' in env_name: 103 | env = CenteredMaze('antmaze-large-diverse-v2') 104 | else: 105 | env = MazeWrapper('antmaze-large-diverse-v2') 106 | elif 'maze2d' in env_name: 107 | from fre.common.envs.d4rl.d4rl_ant import CenteredMaze, GoalReachingMaze, MazeWrapper 108 | if 'gc-maze2d' in env_name: 109 | env = GoalReachingMaze('maze2d-large-v1') 110 | elif 'center-maze2d' in env_name: 111 | env = CenteredMaze('maze2d-large-v1') 112 | else: 113 | env = CenteredMaze('maze2d-large-v1', start_loc='original') 114 | elif 'halfcheetah-' in env_name or 'walker2d-' in env_name or 'hopper-' in env_name: # D4RL Mujoco 115 | import d4rl 116 | import d4rl.gym_mujoco 117 | env = gym.make(env_name) 118 | elif 'kitchen' in env_name: # This doesn't work yet. 119 | import os 120 | os.environ['DISPLAY'] = ':0' 121 | from fre.common.envs.d4rl.d4rl_utils import KitchenRenderWrapper 122 | env = KitchenRenderWrapper(gym.make(env_name)) 123 | elif 'bandit' in env_name: 124 | from fre.common.envs.bandit.bandit import BanditEnv 125 | env = BanditEnv() 126 | else: 127 | env = gym.make(env_name) 128 | env = EpisodeMonitor(env) 129 | return env 130 | 131 | # For getting offline data. 
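# Illustrative pairing (sketch): the loader is picked from the env name, e.g.
#   env = make_env('walker_run')
#   dataset = get_dataset(env, 'walker_run')   # ExORL buffer, relabeled with the walker_run task reward
# D4RL-style names ('antmaze', 'maze2d', 'halfcheetah-expert-v2', ...) instead load through
# d4rl_utils and are normalized by normalize_dataset.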
132 | def get_dataset(env, env_name, **kwargs): 133 | if 'exorl' in env_name: 134 | from fre.common.envs.exorl.exorl_utils import get_dataset 135 | env_name_short = env_name.split('_', 1)[1] 136 | return get_dataset(env, env_name_short, **kwargs) 137 | elif 'ant' in env_name or 'maze2d' in env_name or 'kitchen' in env_name or 'halfcheetah' in env_name or 'walker2d' in env_name or 'hopper' in env_name: 138 | from fre.common.envs.d4rl.d4rl_utils import get_dataset, normalize_dataset 139 | dataset = get_dataset(env, env_name, **kwargs) 140 | dataset = normalize_dataset(env_name, dataset) 141 | return dataset 142 | elif 'cartpole' in env_name or 'cheetah' in env_name or 'jaco' in env_name or 'quadruped' in env_name or 'walker' in env_name: 143 | from fre.common.envs.exorl.exorl_utils import get_dataset 144 | return get_dataset(env, env_name, **kwargs) 145 | 146 | def make_vec_env(env_name, num_envs, **kwargs): 147 | from gym.vector import SyncVectorEnv 148 | envs = [lambda : make_env(env_name, **kwargs) for _ in range(num_envs)] 149 | env = SyncVectorEnv(envs) 150 | return env -------------------------------------------------------------------------------- /common/envs/exorl/custom_dmc_tasks/__init__.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | from . import cheetah 3 | from . import walker 4 | from . import hopper 5 | from . import quadruped 6 | from . import jaco 7 | 8 | 9 | def make(domain, task, 10 | task_kwargs=None, 11 | environment_kwargs=None, 12 | visualize_reward: bool = False): 13 | 14 | if domain == 'cheetah': 15 | return cheetah.make(task, 16 | task_kwargs=task_kwargs, 17 | environment_kwargs=environment_kwargs, 18 | visualize_reward=visualize_reward) 19 | elif domain == 'walker': 20 | return walker.make(task, 21 | task_kwargs=task_kwargs, 22 | environment_kwargs=environment_kwargs, 23 | visualize_reward=visualize_reward) 24 | elif domain == 'hopper': 25 | return hopper.make(task, 26 | task_kwargs=task_kwargs, 27 | environment_kwargs=environment_kwargs, 28 | visualize_reward=visualize_reward) 29 | elif domain == 'quadruped': 30 | return quadruped.make(task, 31 | task_kwargs=task_kwargs, 32 | environment_kwargs=environment_kwargs, 33 | visualize_reward=visualize_reward) 34 | elif domain == 'point_mass_maze': 35 | return point_mass_maze.make(task, 36 | task_kwargs=task_kwargs, 37 | environment_kwargs=environment_kwargs, 38 | visualize_reward=visualize_reward) 39 | 40 | else: 41 | raise ValueError(f'{task} not found') 42 | 43 | assert None 44 | 45 | 46 | def make_jaco(task, obs_type, seed) -> tp.Any: 47 | return jaco.make(task, obs_type, seed) 48 | -------------------------------------------------------------------------------- /common/envs/exorl/custom_dmc_tasks/cheetah.py: -------------------------------------------------------------------------------- 1 | """Cheetah Domain.""" 2 | 3 | import collections 4 | import os 5 | import typing as tp 6 | from typing import Any, Tuple 7 | 8 | from dm_control import mujoco 9 | from dm_control.rl import control 10 | from dm_control.suite import base 11 | from dm_control.suite import common 12 | from dm_control.utils import containers 13 | from dm_control.utils import rewards 14 | from dm_control.utils import io as resources 15 | 16 | _DEFAULT_TIME_LIMIT: int 17 | _RUN_SPEED: int 18 | _SPIN_SPEED: int 19 | 20 | # How long the simulation will run, in seconds. 21 | _DEFAULT_TIME_LIMIT = 10 22 | 23 | # Running speed above which reward is 1. 
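# (These speed constants feed rewards.tolerance in Cheetah.get_reward below: the reward is 1 once
# the signed speed/spin exceeds the bound and decays linearly to 0 as the speed drops to zero.)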
24 | _RUN_SPEED = 10 25 | _WALK_SPEED = 2 26 | _SPIN_SPEED = 5 27 | 28 | SUITE = containers.TaggedTasks() 29 | 30 | 31 | def make(task, 32 | task_kwargs=None, 33 | environment_kwargs=None, 34 | visualize_reward: bool = False): 35 | task_kwargs = task_kwargs or {} 36 | if environment_kwargs is not None: 37 | task_kwargs = task_kwargs.copy() 38 | task_kwargs['environment_kwargs'] = environment_kwargs 39 | env = SUITE[task](**task_kwargs) 40 | env.task.visualize_reward = visualize_reward 41 | return env 42 | 43 | 44 | def get_model_and_assets() -> Tuple[Any, Any]: 45 | """Returns a tuple containing the model XML string and a dict of assets.""" 46 | root_dir = os.path.dirname(os.path.dirname(__file__)) 47 | xml = resources.GetResource( 48 | os.path.join(root_dir, 'custom_dmc_tasks', 'cheetah.xml')) 49 | return xml, common.ASSETS 50 | 51 | 52 | @SUITE.add('benchmarking') 53 | def walk(time_limit: int = _DEFAULT_TIME_LIMIT, 54 | random=None, 55 | environment_kwargs=None): 56 | """Returns the run task.""" 57 | physics = Physics.from_xml_string(*get_model_and_assets()) 58 | task = Cheetah(move_speed=_WALK_SPEED, forward=True, flip=False, random=random) 59 | environment_kwargs = environment_kwargs or {} 60 | return control.Environment(physics, 61 | task, 62 | time_limit=time_limit, 63 | **environment_kwargs) 64 | 65 | 66 | @SUITE.add('benchmarking') 67 | def walk_backward(time_limit: int = _DEFAULT_TIME_LIMIT, 68 | random=None, 69 | environment_kwargs=None): 70 | """Returns the run task.""" 71 | physics = Physics.from_xml_string(*get_model_and_assets()) 72 | task = Cheetah(move_speed=_WALK_SPEED, forward=False, flip=False, random=random) 73 | environment_kwargs = environment_kwargs or {} 74 | return control.Environment(physics, 75 | task, 76 | time_limit=time_limit, 77 | **environment_kwargs) 78 | 79 | 80 | @SUITE.add('benchmarking') 81 | def run_backward(time_limit: int = _DEFAULT_TIME_LIMIT, 82 | random=None, 83 | environment_kwargs=None): 84 | """Returns the run task.""" 85 | physics = Physics.from_xml_string(*get_model_and_assets()) 86 | task = Cheetah(forward=False, flip=False, random=random) 87 | environment_kwargs = environment_kwargs or {} 88 | return control.Environment(physics, 89 | task, 90 | time_limit=time_limit, 91 | **environment_kwargs) 92 | 93 | 94 | @SUITE.add('benchmarking') 95 | def flip(time_limit: int = _DEFAULT_TIME_LIMIT, 96 | random=None, 97 | environment_kwargs=None): 98 | """Returns the run task.""" 99 | physics = Physics.from_xml_string(*get_model_and_assets()) 100 | task = Cheetah(move_speed=_WALK_SPEED, forward=True, flip=True, random=random) 101 | environment_kwargs = environment_kwargs or {} 102 | return control.Environment(physics, 103 | task, 104 | time_limit=time_limit, 105 | **environment_kwargs) 106 | 107 | 108 | @SUITE.add('benchmarking') 109 | def flip_backward(time_limit: int = _DEFAULT_TIME_LIMIT, 110 | random=None, 111 | environment_kwargs=None): 112 | """Returns the run task.""" 113 | physics = Physics.from_xml_string(*get_model_and_assets()) 114 | task = Cheetah(move_speed=_WALK_SPEED, forward=False, flip=True, random=random) 115 | environment_kwargs = environment_kwargs or {} 116 | return control.Environment(physics, 117 | task, 118 | time_limit=time_limit, 119 | **environment_kwargs) 120 | 121 | 122 | class Physics(mujoco.Physics): 123 | """Physics simulation with additional features for the Cheetah domain.""" 124 | 125 | def speed(self) -> Any: 126 | """Returns the horizontal speed of the Cheetah.""" 127 | return 
self.named.data.sensordata['torso_subtreelinvel'][0] 128 | 129 | def angmomentum(self) -> Any: 130 | """Returns the angular momentum of torso of the Cheetah about Y axis.""" 131 | return self.named.data.subtree_angmom['torso'][1] 132 | 133 | 134 | class Cheetah(base.Task): 135 | """A `Task` to train a running Cheetah.""" 136 | 137 | def __init__(self, move_speed=_RUN_SPEED, forward=True, flip=False, random=None) -> None: 138 | self._move_speed = move_speed 139 | self._forward = 1 if forward else -1 140 | self._flip = flip 141 | super(Cheetah, self).__init__(random=random) 142 | self._timeout_progress = 0 143 | 144 | def initialize_episode(self, physics) -> None: 145 | """Sets the state of the environment at the start of each episode.""" 146 | # The indexing below assumes that all joints have a single DOF. 147 | assert physics.model.nq == physics.model.njnt 148 | is_limited = physics.model.jnt_limited == 1 149 | lower, upper = physics.model.jnt_range[is_limited].T 150 | physics.data.qpos[is_limited] = self.random.uniform(lower, upper) 151 | 152 | # Stabilize the model before the actual simulation. 153 | for _ in range(200): 154 | physics.step() 155 | 156 | physics.data.time = 0 157 | self._timeout_progress = 0 158 | super().initialize_episode(physics) 159 | 160 | def get_observation(self, physics) -> tp.Dict[str, Any]: 161 | """Returns an observation of the state, ignoring horizontal position.""" 162 | obs = collections.OrderedDict() 163 | # Ignores horizontal position to maintain translational invariance. 164 | obs['position'] = physics.data.qpos[1:].copy() 165 | obs['velocity'] = physics.velocity() 166 | return obs 167 | 168 | def get_reward(self, physics) -> Any: 169 | """Returns a reward to the agent.""" 170 | if self._flip: 171 | reward = rewards.tolerance(self._forward * physics.angmomentum(), 172 | bounds=(_SPIN_SPEED, float('inf')), 173 | margin=_SPIN_SPEED, 174 | value_at_margin=0, 175 | sigmoid='linear') 176 | 177 | else: 178 | reward = rewards.tolerance(self._forward * physics.speed(), 179 | bounds=(self._move_speed, float('inf')), 180 | margin=self._move_speed, 181 | value_at_margin=0, 182 | sigmoid='linear') 183 | return reward 184 | -------------------------------------------------------------------------------- /common/envs/exorl/custom_dmc_tasks/cheetah.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /common/envs/exorl/custom_dmc_tasks/hopper.py: -------------------------------------------------------------------------------- 1 | """Hopper domain.""" 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import collections 8 | import os 9 | import typing as tp 10 | from typing import Any, Tuple 11 | 12 | import numpy as np 13 | from dm_control import mujoco 14 | from dm_control.rl import control 15 | from dm_control.suite import base 16 | from dm_control.suite import common 17 | from dm_control.suite.utils import randomizers 18 | from dm_control.utils import containers 19 | from dm_control.utils import rewards 20 | from dm_control.utils import io as resources 21 | 22 | _CONTROL_TIMESTEP: float 23 | _DEFAULT_TIME_LIMIT: int 24 | _HOP_SPEED: int 25 | _SPIN_SPEED: int 26 | _STAND_HEIGHT: float 27 | 28 | SUITE = containers.TaggedTasks() 29 | 30 | _CONTROL_TIMESTEP = .02 # (Seconds) 31 | 
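# (With the 0.02 s control timestep above and the 20 s default time limit below, one episode is
# 1000 control steps.)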
32 | # Default duration of an episode, in seconds. 33 | _DEFAULT_TIME_LIMIT = 20 34 | 35 | # Minimal height of torso over foot above which stand reward is 1. 36 | _STAND_HEIGHT = 0.6 37 | 38 | # Hopping speed above which hop reward is 1. 39 | _HOP_SPEED = 2 40 | _SPIN_SPEED = 5 41 | 42 | 43 | def make(task, 44 | task_kwargs=None, 45 | environment_kwargs=None, 46 | visualize_reward: bool = False): 47 | task_kwargs = task_kwargs or {} 48 | if environment_kwargs is not None: 49 | task_kwargs = task_kwargs.copy() 50 | task_kwargs['environment_kwargs'] = environment_kwargs 51 | env = SUITE[task](**task_kwargs) 52 | env.task.visualize_reward = visualize_reward 53 | return env 54 | 55 | 56 | def get_model_and_assets() -> Tuple[Any, Any]: 57 | """Returns a tuple containing the model XML string and a dict of assets.""" 58 | root_dir = os.path.dirname(os.path.dirname(__file__)) 59 | xml = resources.GetResource( 60 | os.path.join(root_dir, 'custom_dmc_tasks', 'hopper.xml')) 61 | return xml, common.ASSETS 62 | 63 | 64 | @SUITE.add('benchmarking') 65 | def hop_backward(time_limit: int = _DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): 66 | """Returns a Hopper that strives to hop forward.""" 67 | physics = Physics.from_xml_string(*get_model_and_assets()) 68 | task = Hopper(hopping=True, forward=False, flip=False, random=random) 69 | environment_kwargs = environment_kwargs or {} 70 | return control.Environment(physics, 71 | task, 72 | time_limit=time_limit, 73 | control_timestep=_CONTROL_TIMESTEP, 74 | **environment_kwargs) 75 | 76 | 77 | @SUITE.add('benchmarking') 78 | def flip(time_limit: int = _DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): 79 | """Returns a Hopper that strives to hop forward.""" 80 | physics = Physics.from_xml_string(*get_model_and_assets()) 81 | task = Hopper(hopping=True, forward=True, flip=True, random=random) 82 | environment_kwargs = environment_kwargs or {} 83 | return control.Environment(physics, 84 | task, 85 | time_limit=time_limit, 86 | control_timestep=_CONTROL_TIMESTEP, 87 | **environment_kwargs) 88 | 89 | 90 | @SUITE.add('benchmarking') 91 | def flip_backward(time_limit: int = _DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): 92 | """Returns a Hopper that strives to hop forward.""" 93 | physics = Physics.from_xml_string(*get_model_and_assets()) 94 | task = Hopper(hopping=True, forward=False, flip=True, random=random) 95 | environment_kwargs = environment_kwargs or {} 96 | return control.Environment(physics, 97 | task, 98 | time_limit=time_limit, 99 | control_timestep=_CONTROL_TIMESTEP, 100 | **environment_kwargs) 101 | 102 | 103 | class Physics(mujoco.Physics): 104 | """Physics simulation with additional features for the Hopper domain.""" 105 | 106 | def height(self) -> Any: 107 | """Returns height of torso with respect to foot.""" 108 | return (self.named.data.xipos['torso', 'z'] - 109 | self.named.data.xipos['foot', 'z']) 110 | 111 | def speed(self) -> Any: 112 | """Returns horizontal speed of the Hopper.""" 113 | return self.named.data.sensordata['torso_subtreelinvel'][0] 114 | 115 | def touch(self) -> Any: 116 | """Returns the signals from two foot touch sensors.""" 117 | return np.log1p(self.named.data.sensordata[['touch_toe', 118 | 'touch_heel']]) 119 | 120 | def angmomentum(self) -> Any: 121 | """Returns the angular momentum of torso of the Cheetah about Y axis.""" 122 | return self.named.data.subtree_angmom['torso'][1] 123 | 124 | 125 | class Hopper(base.Task): 126 | """A Hopper's `Task` to train a standing and a jumping 
Hopper.""" 127 | 128 | def __init__(self, hopping, forward=True, flip=False, random=None) -> None: 129 | """Initialize an instance of `Hopper`. 130 | 131 | Args: 132 | hopping: Boolean, if True the task is to hop forwards, otherwise it is to 133 | balance upright. 134 | random: Optional, either a `numpy.random.RandomState` instance, an 135 | integer seed for creating a new `RandomState`, or None to select a seed 136 | automatically (default). 137 | """ 138 | self._hopping = hopping 139 | self._forward = 1 if forward else -1 140 | self._flip = flip 141 | self._timeout_progress = 0 142 | super(Hopper, self).__init__(random=random) 143 | 144 | def initialize_episode(self, physics) -> None: 145 | """Sets the state of the environment at the start of each episode.""" 146 | randomizers.randomize_limited_and_rotational_joints( 147 | physics, self.random) 148 | self._timeout_progress = 0 149 | super(Hopper, self).initialize_episode(physics) 150 | 151 | def get_observation(self, physics) -> tp.Dict[str, Any]: 152 | """Returns an observation of positions, velocities and touch sensors.""" 153 | obs = collections.OrderedDict() 154 | # Ignores horizontal position to maintain translational invariance: 155 | obs['position'] = physics.data.qpos[1:].copy() 156 | obs['velocity'] = physics.velocity() 157 | obs['touch'] = physics.touch() 158 | return obs 159 | 160 | def get_reward(self, physics) -> Any: 161 | """Returns a reward applicable to the performed task.""" 162 | standing = rewards.tolerance(physics.height(), (_STAND_HEIGHT, 2)) 163 | assert self._hopping 164 | if self._flip: 165 | hopping = rewards.tolerance(self._forward * physics.angmomentum(), 166 | bounds=(_SPIN_SPEED, float('inf')), 167 | margin=_SPIN_SPEED, 168 | value_at_margin=0, 169 | sigmoid='linear') 170 | else: 171 | hopping = rewards.tolerance(self._forward * physics.speed(), 172 | bounds=(_HOP_SPEED, float('inf')), 173 | margin=_HOP_SPEED / 2, 174 | value_at_margin=0.5, 175 | sigmoid='linear') 176 | return standing * hopping 177 | -------------------------------------------------------------------------------- /common/envs/exorl/custom_dmc_tasks/hopper.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /common/envs/exorl/custom_dmc_tasks/jaco.py: -------------------------------------------------------------------------------- 1 | """A task where the goal is to move the hand close to a target prop or site.""" 2 | 3 | import collections 4 | 5 | from dm_control import composer 6 | from dm_control.composer import initializers 7 | from dm_control.composer.variation import distributions 8 | from dm_control.entities import props 9 | from dm_control.manipulation.shared import arenas 10 | from dm_control.manipulation.shared import cameras 11 | from dm_control.manipulation.shared import constants 12 | from dm_control.manipulation.shared import observations 13 | from dm_control.manipulation.shared import robots 14 | from dm_control.manipulation.shared import workspaces 15 | from dm_control.utils import rewards 16 | from dm_env import specs 17 | import numpy as np 18 | 19 | _ReachWorkspace = collections.namedtuple( 20 | '_ReachWorkspace', ['target_bbox', 'tcp_bbox', 'arm_offset']) 21 | 22 | # Ensures that the props are not touching the table before settling. 
23 | _PROP_Z_OFFSET = 0.001 24 | 25 | _DUPLO_WORKSPACE = _ReachWorkspace( 26 | target_bbox=workspaces.BoundingBox(lower=(-0.1, -0.1, _PROP_Z_OFFSET), 27 | upper=(0.1, 0.1, _PROP_Z_OFFSET)), 28 | tcp_bbox=workspaces.BoundingBox(lower=(-0.1, -0.1, 0.2), 29 | upper=(0.1, 0.1, 0.4)), 30 | arm_offset=robots.ARM_OFFSET) 31 | 32 | _SITE_WORKSPACE = _ReachWorkspace( 33 | target_bbox=workspaces.BoundingBox(lower=(-0.2, -0.2, 0.02), 34 | upper=(0.2, 0.2, 0.4)), 35 | tcp_bbox=workspaces.BoundingBox(lower=(-0.2, -0.2, 0.02), 36 | upper=(0.2, 0.2, 0.4)), 37 | arm_offset=robots.ARM_OFFSET) 38 | 39 | _TARGET_RADIUS = 0.05 40 | _TIME_LIMIT = 10. 41 | 42 | TASKS = [('reach_top_left', np.array([-0.09, 0.09, _PROP_Z_OFFSET])), 43 | ('reach_top_right', np.array([0.09, 0.09, _PROP_Z_OFFSET])), 44 | ('reach_bottom_left', np.array([-0.09, -0.09, _PROP_Z_OFFSET])), 45 | ('reach_bottom_right', np.array([0.09, -0.09, _PROP_Z_OFFSET]))] 46 | 47 | 48 | def make(task_id, obs_type, seed): 49 | obs_settings = observations.VISION if obs_type == 'pixels' else observations.PERFECT_FEATURES 50 | task = _reach(task_id, obs_settings=obs_settings, use_site=True) 51 | return composer.Environment(task, 52 | time_limit=_TIME_LIMIT, 53 | random_state=seed) 54 | 55 | 56 | class MultiTaskReach(composer.Task): 57 | """Bring the hand close to a target prop or site.""" 58 | 59 | def __init__(self, task_id, arena, arm, hand, prop, obs_settings, 60 | workspace, control_timestep): 61 | """Initializes a new `Reach` task. 62 | 63 | Args: 64 | arena: `composer.Entity` instance. 65 | arm: `robot_base.RobotArm` instance. 66 | hand: `robot_base.RobotHand` instance. 67 | prop: `composer.Entity` instance specifying the prop to reach to, or None 68 | in which case the target is a fixed site whose position is specified by 69 | the workspace. 70 | obs_settings: `observations.ObservationSettings` instance. 71 | workspace: `_ReachWorkspace` specifying the placement of the prop and TCP. 72 | control_timestep: Float specifying the control timestep in seconds. 73 | """ 74 | self._arena = arena 75 | self._arm = arm 76 | self._hand = hand 77 | self._arm.attach(self._hand) 78 | self._arena.attach_offset(self._arm, offset=workspace.arm_offset) 79 | self.control_timestep = control_timestep 80 | self._tcp_initializer = initializers.ToolCenterPointInitializer( 81 | self._hand, 82 | self._arm, 83 | position=distributions.Uniform(*workspace.tcp_bbox), 84 | quaternion=workspaces.DOWN_QUATERNION) 85 | 86 | # Add custom camera observable. 87 | self._task_observables = cameras.add_camera_observables( 88 | arena, obs_settings, cameras.FRONT_CLOSE) 89 | 90 | if task_id == 'reach_multitask': 91 | self._targets = [target for (_, target) in TASKS] 92 | else: 93 | self._targets = [ 94 | target for (task, target) in TASKS if task == task_id 95 | ] 96 | assert len(self._targets) > 0 97 | 98 | #target_pos_distribution = distributions.Uniform(*TASKS[task_id]) 99 | self._prop = prop 100 | if prop: 101 | # The prop itself is used to visualize the target location. 
102 | self._make_target_site(parent_entity=prop, visible=False) 103 | self._target = self._arena.add_free_entity(prop) 104 | self._prop_placer = initializers.PropPlacer( 105 | props=[prop], 106 | position=target_pos_distribution, 107 | quaternion=workspaces.uniform_z_rotation, 108 | settle_physics=True) 109 | else: 110 | if len(self._targets) == 1: 111 | self._target = self._make_target_site(parent_entity=arena, 112 | visible=True) 113 | 114 | #obs = observable.MJCFFeature('pos', self._target) 115 | # obs.configure(**obs_settings.prop_pose._asdict()) 116 | #self._task_observables['target_position'] = obs 117 | 118 | # Add sites for visualizing the prop and target bounding boxes. 119 | workspaces.add_bbox_site(body=self.root_entity.mjcf_model.worldbody, 120 | lower=workspace.tcp_bbox.lower, 121 | upper=workspace.tcp_bbox.upper, 122 | rgba=constants.GREEN, 123 | name='tcp_spawn_area') 124 | workspaces.add_bbox_site(body=self.root_entity.mjcf_model.worldbody, 125 | lower=workspace.target_bbox.lower, 126 | upper=workspace.target_bbox.upper, 127 | rgba=constants.BLUE, 128 | name='target_spawn_area') 129 | 130 | def _make_target_site(self, parent_entity, visible): 131 | return workspaces.add_target_site( 132 | body=parent_entity.mjcf_model.worldbody, 133 | radius=_TARGET_RADIUS, 134 | visible=visible, 135 | rgba=constants.RED, 136 | name='target_site') 137 | 138 | @property 139 | def root_entity(self): 140 | return self._arena 141 | 142 | @property 143 | def arm(self): 144 | return self._arm 145 | 146 | @property 147 | def hand(self): 148 | return self._hand 149 | 150 | def get_reward_spec(self): 151 | n = len(self._targets) 152 | return specs.Array(shape=(n,), dtype=np.float32, name='reward') 153 | 154 | @property 155 | def task_observables(self): 156 | return self._task_observables 157 | 158 | def get_reward(self, physics): 159 | hand_pos = physics.bind(self._hand.tool_center_point).xpos 160 | rews = [] 161 | for target_pos in self._targets: 162 | distance = np.linalg.norm(hand_pos - target_pos) 163 | reward = rewards.tolerance(distance, 164 | bounds=(0, _TARGET_RADIUS), 165 | margin=_TARGET_RADIUS) 166 | rews.append(reward) 167 | rews = np.array(rews).astype(np.float32) 168 | if len(self._targets) == 1: 169 | return rews[0] 170 | return rews 171 | 172 | def initialize_episode(self, physics, random_state): 173 | self._hand.set_grasp(physics, close_factors=random_state.uniform()) 174 | self._tcp_initializer(physics, random_state) 175 | if self._prop: 176 | self._prop_placer(physics, random_state) 177 | else: 178 | if len(self._targets) == 1: 179 | physics.bind(self._target).pos = self._targets[0] 180 | 181 | 182 | def _reach(task_id, obs_settings, use_site): 183 | """Configure and instantiate a `Reach` task. 184 | 185 | Args: 186 | obs_settings: An `observations.ObservationSettings` instance. 187 | use_site: Boolean, if True then the target will be a fixed site, otherwise 188 | it will be a moveable Duplo brick. 189 | 190 | Returns: 191 | An instance of `reach.Reach`. 
192 | """ 193 | arena = arenas.Standard() 194 | arm = robots.make_arm(obs_settings=obs_settings) 195 | hand = robots.make_hand(obs_settings=obs_settings) 196 | if use_site: 197 | workspace = _SITE_WORKSPACE 198 | prop = None 199 | else: 200 | workspace = _DUPLO_WORKSPACE 201 | prop = props.Duplo(observable_options=observations.make_options( 202 | obs_settings, observations.FREEPROP_OBSERVABLES)) 203 | task = MultiTaskReach(task_id, 204 | arena=arena, 205 | arm=arm, 206 | hand=hand, 207 | prop=prop, 208 | obs_settings=obs_settings, 209 | workspace=workspace, 210 | control_timestep=constants.CONTROL_TIMESTEP) 211 | return task 212 | -------------------------------------------------------------------------------- /common/envs/exorl/custom_dmc_tasks/quadruped.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /common/envs/exorl/custom_dmc_tasks/walker.py: -------------------------------------------------------------------------------- 1 | """Planar Walker Domain.""" 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import collections 8 | from typing import Any, Tuple 9 | import typing as tp 10 | import os 11 | 12 | from dm_control import mujoco 13 | from dm_control.rl import control 14 | from dm_control.suite import base 15 | from dm_control.suite import common 16 | from dm_control.suite.utils import randomizers 17 | from dm_control.utils import containers 18 | from dm_control.utils import rewards 19 | from dm_control.utils import io as resources 20 | 21 | _CONTROL_TIMESTEP: float 22 | _DEFAULT_TIME_LIMIT: int 23 | _RUN_SPEED: int 24 | _SPIN_SPEED: int 25 | _STAND_HEIGHT: float 26 | _WALK_SPEED: int 27 | # from dm_control import suite # TODO useless? 28 | 29 | _DEFAULT_TIME_LIMIT = 25 30 | _CONTROL_TIMESTEP = .025 31 | 32 | # Minimal height of torso over foot above which stand reward is 1. 33 | _STAND_HEIGHT = 1.2 34 | 35 | # Horizontal speeds (meters/second) above which move reward is 1. 
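# (PlanarWalker.get_reward below combines a stand/upright term with a move term that saturates at
# these speeds; the two are multiplied, so the walker must stay upright while moving.)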
36 | _WALK_SPEED = 1 37 | _RUN_SPEED = 8 38 | _SPIN_SPEED = 5 39 | 40 | SUITE = containers.TaggedTasks() 41 | 42 | 43 | def make(task, 44 | task_kwargs=None, 45 | environment_kwargs=None, 46 | visualize_reward: bool = False): 47 | task_kwargs = task_kwargs or {} 48 | if environment_kwargs is not None: 49 | task_kwargs = task_kwargs.copy() 50 | task_kwargs['environment_kwargs'] = environment_kwargs 51 | env = SUITE[task](**task_kwargs) 52 | env.task.visualize_reward = visualize_reward 53 | return env 54 | 55 | 56 | def get_model_and_assets() -> Tuple[Any, Any]: 57 | """Returns a tuple containing the model XML string and a dict of assets.""" 58 | root_dir = os.path.dirname(os.path.dirname(__file__)) 59 | xml = resources.GetResource(os.path.join(root_dir, 'custom_dmc_tasks', 60 | 'walker.xml')) 61 | return xml, common.ASSETS 62 | 63 | 64 | @SUITE.add('benchmarking') 65 | def flip(time_limit: int = _DEFAULT_TIME_LIMIT, 66 | random=None, 67 | environment_kwargs=None): 68 | """Returns the Run task.""" 69 | physics = Physics.from_xml_string(*get_model_and_assets()) 70 | task = PlanarWalker(move_speed=_RUN_SPEED, 71 | forward=True, 72 | flip=True, 73 | random=random) 74 | environment_kwargs = environment_kwargs or {} 75 | return control.Environment(physics, 76 | task, 77 | time_limit=time_limit, 78 | control_timestep=_CONTROL_TIMESTEP, 79 | **environment_kwargs) 80 | 81 | 82 | class Physics(mujoco.Physics): 83 | """Physics simulation with additional features for the Walker domain.""" 84 | 85 | def torso_upright(self) -> Any: 86 | """Returns projection from z-axes of torso to the z-axes of world.""" 87 | return self.named.data.xmat['torso', 'zz'] 88 | 89 | def torso_height(self) -> Any: 90 | """Returns the height of the torso.""" 91 | return self.named.data.xpos['torso', 'z'] 92 | 93 | def horizontal_velocity(self) -> Any: 94 | """Returns the horizontal velocity of the center-of-mass.""" 95 | return self.named.data.sensordata['torso_subtreelinvel'][0] 96 | 97 | def orientations(self) -> Any: 98 | """Returns planar orientations of all bodies.""" 99 | return self.named.data.xmat[1:, ['xx', 'xz']].ravel() 100 | 101 | def angmomentum(self) -> Any: 102 | """Returns the angular momentum of torso of the Cheetah about Y axis.""" 103 | return self.named.data.subtree_angmom['torso'][1] 104 | 105 | 106 | class PlanarWalker(base.Task): 107 | """A planar walker task.""" 108 | 109 | def __init__(self, move_speed, forward=True, flip=False, random=None) -> None: 110 | """Initializes an instance of `PlanarWalker`. 111 | 112 | Args: 113 | move_speed: A float. If this value is zero, reward is given simply for 114 | standing up. Otherwise this specifies a target horizontal velocity for 115 | the walking task. 116 | random: Optional, either a `numpy.random.RandomState` instance, an 117 | integer seed for creating a new `RandomState`, or None to select a seed 118 | automatically (default). 119 | """ 120 | self._move_speed = move_speed 121 | self._forward = 1 if forward else -1 122 | self._flip = flip 123 | super(PlanarWalker, self).__init__(random=random) 124 | 125 | def initialize_episode(self, physics) -> None: 126 | """Sets the state of the environment at the start of each episode. 127 | 128 | In 'standing' mode, use initial orientation and small velocities. 129 | In 'random' mode, randomize joint angles and let fall to the floor. 130 | 131 | Args: 132 | physics: An instance of `Physics`. 
133 | 134 | """ 135 | randomizers.randomize_limited_and_rotational_joints( 136 | physics, self.random) 137 | super(PlanarWalker, self).initialize_episode(physics) 138 | 139 | def get_observation(self, physics) -> tp.Dict[str, Any]: 140 | """Returns an observation of body orientations, height and velocites.""" 141 | obs = collections.OrderedDict() 142 | obs['orientations'] = physics.orientations() 143 | obs['height'] = physics.torso_height() 144 | obs['velocity'] = physics.velocity() 145 | return obs 146 | 147 | def get_reward(self, physics) -> Any: 148 | """Returns a reward to the agent.""" 149 | standing = rewards.tolerance(physics.torso_height(), 150 | bounds=(_STAND_HEIGHT, float('inf')), 151 | margin=_STAND_HEIGHT / 2) 152 | upright = (1 + physics.torso_upright()) / 2 153 | stand_reward = (3 * standing + upright) / 4 154 | 155 | if self._flip: 156 | move_reward = rewards.tolerance(self._forward * 157 | physics.angmomentum(), 158 | bounds=(_SPIN_SPEED, float('inf')), 159 | margin=_SPIN_SPEED, 160 | value_at_margin=0, 161 | sigmoid='linear') 162 | else: 163 | move_reward = rewards.tolerance( 164 | self._forward * physics.horizontal_velocity(), 165 | bounds=(self._move_speed, float('inf')), 166 | margin=self._move_speed / 2, 167 | value_at_margin=0.5, 168 | sigmoid='linear') 169 | 170 | return stand_reward * (5 * move_reward + 1) / 6 171 | -------------------------------------------------------------------------------- /common/envs/exorl/custom_dmc_tasks/walker.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /common/envs/exorl/exorl_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import glob 4 | import tqdm 5 | import numpy as np 6 | from collections import defaultdict 7 | 8 | # from fre.common.envs.d4rl import d4rl_utils 9 | from fre.common.dataset import Dataset 10 | 11 | # get path relative to 'fre' package 12 | data_path = os.path.dirname(os.path.abspath(__file__)) 13 | data_path = Path(data_path).parents[2] 14 | data_path = os.path.join(data_path, 'data/exorl/') 15 | print("Path to exorl data is", data_path) 16 | 17 | def get_dataset(env, env_name, method='rnd', dmc_dataset_size=10000000, use_task_reward=True): 18 | 19 | # dmc_dataset_size /= 10 20 | # print("WARNING: Only using 10 percent of exorl data.") 21 | 22 | domain_name, task_name = env_name.split('_', 1) 23 | 24 | path = os.path.join(data_path, domain_name, method) 25 | if not os.path.exists(path): 26 | print("Downloading exorl data.") 27 | os.makedirs(path) 28 | url = "https://dl.fbaipublicfiles.com/exorl/" + domain_name + "/" + method + ".zip" 29 | print("Downloading from", url) 30 | os.system("wget " + url + " -P " + path) 31 | os.system("unzip " + path + "/" + method + ".zip -d " + path) 32 | 33 | # process data into Dataset object. 
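# (The relabeled data is cached as '<task_name>.npy' next to the raw buffer, so the slow
# physics-based reward computation below only runs on the first call per domain/method/task.)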
34 | path = os.path.join(data_path, domain_name, method, 'buffer') 35 | npzs = sorted(glob.glob(f'{path}/*.npz')) 36 | dataset_npy = os.path.join(data_path, domain_name, method, task_name + '.npy') 37 | if os.path.exists(dataset_npy): 38 | dataset = np.load(dataset_npy, allow_pickle=True).item() 39 | else: 40 | print("Calculating exorl rewards.") 41 | dataset = defaultdict(list) 42 | num_steps = 0 43 | for i, npz in tqdm.tqdm(enumerate(npzs)): 44 | traj_data = dict(np.load(npz)) 45 | dataset['observations'].append(traj_data['observation'][:-1, :]) 46 | dataset['next_observations'].append(traj_data['observation'][1:, :]) 47 | dataset['actions'].append(traj_data['action'][1:, :]) 48 | dataset['physics'].append(traj_data['physics'][1:, :]) # Note that this corresponds to next_observations (i.e., r(s, a, s') = r(s') -- following the original DMC rewards) 49 | 50 | if use_task_reward: 51 | # TODO: make this faster and sanity check it 52 | rewards = [] 53 | reward_spec = env.reward_spec() 54 | states = traj_data['physics'] 55 | for j in range(states.shape[0]): 56 | with env.physics.reset_context(): 57 | env.physics.set_state(states[j]) 58 | reward = env.task.get_reward(env.physics) 59 | reward = np.full(reward_spec.shape, reward, reward_spec.dtype) 60 | rewards.append(reward) 61 | traj_data['reward'] = np.array(rewards, dtype=reward_spec.dtype) 62 | dataset['rewards'].append(traj_data['reward'][1:]) 63 | else: 64 | dataset['rewards'].append(traj_data['reward'][1:, 0]) 65 | 66 | terminals = np.full((len(traj_data['observation']) - 1,), False) 67 | dataset['terminals'].append(terminals) 68 | num_steps += len(traj_data['observation']) - 1 69 | if num_steps >= dmc_dataset_size: 70 | break 71 | print("Loaded {} steps".format(num_steps)) 72 | for k, v in dataset.items(): 73 | dataset[k] = np.concatenate(v, axis=0) 74 | np.save(dataset_npy, dataset) 75 | 76 | 77 | 78 | # Processing 79 | masks = 1.0 - dataset['terminals'] 80 | dones_float = dataset['terminals'] 81 | 82 | return Dataset.create( 83 | observations=dataset['observations'], 84 | actions=dataset['actions'], 85 | rewards=dataset['rewards'], 86 | masks=masks, 87 | dones_float=dones_float, 88 | next_observations=dataset['next_observations'], 89 | ) -------------------------------------------------------------------------------- /common/envs/gc_utils.py: -------------------------------------------------------------------------------- 1 | from fre.common.dataset import Dataset 2 | from flax.core.frozen_dict import FrozenDict 3 | from flax.core import freeze 4 | import dataclasses 5 | import numpy as np 6 | import jax 7 | import ml_collections 8 | 9 | @dataclasses.dataclass 10 | class GCDataset: 11 | dataset: Dataset 12 | p_randomgoal: float 13 | p_trajgoal: float 14 | p_currgoal: float 15 | geom_sample: int 16 | discount: float 17 | terminal_key: str = 'dones_float' 18 | reward_scale: float = 1.0 19 | reward_shift: float = -1.0 20 | mask_terminal: int = 1 21 | 22 | @staticmethod 23 | def get_default_config(): 24 | return ml_collections.ConfigDict({ 25 | 'p_randomgoal': 0.3, 26 | 'p_trajgoal': 0.5, 27 | 'p_currgoal': 0.2, 28 | 'geom_sample': 1, 29 | 'discount': 0.99, 30 | 'reward_scale': 1.0, 31 | 'reward_shift': -1.0, 32 | 'mask_terminal': 1, 33 | }) 34 | 35 | def __post_init__(self): 36 | self.terminal_locs, = np.nonzero(self.dataset[self.terminal_key] > 0) 37 | assert np.isclose(self.p_randomgoal + self.p_trajgoal + self.p_currgoal, 1.0) 38 | 39 | def sample_goals(self, indx, p_randomgoal=None, p_trajgoal=None, p_currgoal=None): 40 | if 
p_randomgoal is None: 41 | p_randomgoal = self.p_randomgoal 42 | if p_trajgoal is None: 43 | p_trajgoal = self.p_trajgoal 44 | if p_currgoal is None: 45 | p_currgoal = self.p_currgoal 46 | 47 | batch_size = len(indx) 48 | # Random goals 49 | goal_indx = np.random.randint(self.dataset.size, size=batch_size) 50 | 51 | # Goals from the same trajectory 52 | final_state_indx = self.terminal_locs[np.searchsorted(self.terminal_locs, indx)] 53 | 54 | distance = np.random.rand(batch_size) 55 | if self.geom_sample: 56 | us = np.random.rand(batch_size) 57 | middle_goal_indx = np.minimum(indx + np.ceil(np.log(1 - us) / np.log(self.discount)).astype(int), final_state_indx) 58 | else: 59 | middle_goal_indx = np.round((np.minimum(indx + 1, final_state_indx) * distance + final_state_indx * (1 - distance))).astype(int) 60 | 61 | goal_indx = np.where(np.random.rand(batch_size) < p_trajgoal / (1.0 - p_currgoal), middle_goal_indx, goal_indx) 62 | 63 | # Goals at the current state 64 | goal_indx = np.where(np.random.rand(batch_size) < p_currgoal, indx, goal_indx) 65 | 66 | return goal_indx 67 | 68 | def sample(self, batch_size: int, indx=None): 69 | if indx is None: 70 | indx = np.random.randint(self.dataset.size-1, size=batch_size) 71 | 72 | batch = self.dataset.sample(batch_size, indx) 73 | goal_indx = self.sample_goals(indx) 74 | 75 | success = (indx == goal_indx) 76 | batch['rewards'] = success.astype(float) * self.reward_scale + self.reward_shift 77 | batch['goals'] = jax.tree_map(lambda arr: arr[goal_indx], self.dataset['observations']) 78 | 79 | if self.mask_terminal: 80 | batch['masks'] = 1.0 - success.astype(float) 81 | else: 82 | batch['masks'] = np.ones(batch_size) 83 | 84 | return batch 85 | 86 | def sample_traj_random(self, batch_size, num_traj_states, num_random_states, num_random_states_decode): 87 | indx = np.random.randint(self.dataset.size-1, size=batch_size) 88 | batch = self.dataset.sample(batch_size, indx) 89 | indx_expand = np.repeat(indx, num_traj_states-1) # (batch_size * num_traj_states) 90 | traj_indx = self.sample_goals(indx_expand, p_randomgoal=0.0, p_trajgoal=1.0, p_currgoal=0.0) 91 | traj_indx = traj_indx.reshape(batch_size, num_traj_states-1) # (batch_size, num_traj_states) 92 | batch['traj_states'] = jax.tree_map(lambda arr: arr[traj_indx], self.dataset['observations']) 93 | batch['traj_states'] = np.concatenate([batch['observations'][:,None,:], batch['traj_states']], axis=1) 94 | 95 | rand_indx = np.random.randint(self.dataset.size-1, size=batch_size * num_random_states) 96 | rand_indx = rand_indx.reshape(batch_size, num_random_states) 97 | batch['random_states'] = jax.tree_map(lambda arr: arr[rand_indx], self.dataset['observations']) 98 | 99 | rand_indx_decode = np.random.randint(self.dataset.size-1, size=batch_size * num_random_states_decode) 100 | rand_indx_decode = rand_indx_decode.reshape(batch_size, num_random_states_decode) 101 | batch['random_states_decode'] = jax.tree_map(lambda arr: arr[rand_indx_decode], self.dataset['observations']) 102 | return batch 103 | 104 | def flatten_obgoal(obgoal): 105 | return np.concatenate([obgoal['observation'], obgoal['goal']], axis=-1) -------------------------------------------------------------------------------- /common/envs/wrappers.py: -------------------------------------------------------------------------------- 1 | ############################### 2 | # 3 | # Wrappers on top of gym environments 4 | # 5 | ############################### 6 | 7 | from typing import Dict 8 | import gym 9 | import numpy as np 10 | import time 11 | 
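# The wrappers in this file follow the old Gym (<0.26) step API: step() returns
# (observation, reward, done, info), and reset() returns only the observation.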
12 | class EpisodeMonitor(gym.ActionWrapper): 13 | """A class that computes episode returns and lengths.""" 14 | 15 | def __init__(self, env: gym.Env): 16 | super().__init__(env) 17 | self._reset_stats() 18 | self.total_timesteps = 0 19 | 20 | def _reset_stats(self): 21 | self.reward_sum = 0.0 22 | self.episode_length = 0 23 | self.start_time = time.time() 24 | 25 | def step(self, action: np.ndarray): 26 | observation, reward, done, info = self.env.step(action) 27 | 28 | self.reward_sum += reward 29 | self.episode_length += 1 30 | self.total_timesteps += 1 31 | info["total"] = {"timesteps": self.total_timesteps} 32 | 33 | if done: 34 | info["episode"] = {} 35 | info["episode"]["return"] = self.reward_sum 36 | info["episode"]["length"] = self.episode_length 37 | info["episode"]["duration"] = time.time() - self.start_time 38 | 39 | if hasattr(self, "get_normalized_score"): 40 | info["episode"]["normalized_return"] = ( 41 | self.get_normalized_score(info["episode"]["return"]) * 100.0 42 | ) 43 | 44 | return observation, reward, done, info 45 | 46 | def reset(self, **kwargs) -> np.ndarray: 47 | self._reset_stats() 48 | return self.env.reset(**kwargs) 49 | 50 | class RewardOverride(gym.ActionWrapper): 51 | def __init__(self, env: gym.Env): 52 | super().__init__(env) 53 | self.reward_fn = None 54 | 55 | def step(self, action: np.ndarray): 56 | observation, reward, done, info = self.env.step(action) 57 | 58 | if self.env.observation_space.shape[0] == 24: 59 | horizontal_velocity = self.env.physics.horizontal_velocity() 60 | torso_upright = self.env.physics.torso_upright() 61 | torso_height = self.env.physics.torso_height() 62 | aux = np.array([horizontal_velocity, torso_upright, torso_height]) 63 | observation_aux = np.concatenate([observation, aux]) 64 | reward = self.reward_fn(observation_aux) 65 | elif self.env.observation_space.shape[0] == 17: 66 | horizontal_velocity = self.env.physics.speed() 67 | aux = np.array([horizontal_velocity]) 68 | observation_aux = np.concatenate([observation, aux]) 69 | reward = self.reward_fn(observation_aux) 70 | else: 71 | reward = self.reward_fn(observation) 72 | return observation, reward, done, info 73 | 74 | def reset(self, **kwargs) -> np.ndarray: 75 | return self.env.reset(**kwargs) 76 | 77 | class TruncateObservation(gym.ObservationWrapper): 78 | def __init__(self, env: gym.Env, truncate_size: int): 79 | super().__init__(env) 80 | self.truncate_size = truncate_size 81 | 82 | def observation(self, observation: np.ndarray) -> np.ndarray: 83 | return observation[:self.truncate_size] 84 | 85 | class GoalWrapper(gym.ObservationWrapper): 86 | def __init__(self, env: gym.Env): 87 | super().__init__(env) 88 | self.custom_goal = None 89 | 90 | def observation(self, observation: np.ndarray) -> np.ndarray: 91 | if self.custom_goal is not None: 92 | return np.concatenate([observation, self.custom_goal]) 93 | else: 94 | return observation 95 | -------------------------------------------------------------------------------- /common/evaluation.py: -------------------------------------------------------------------------------- 1 | ############################### 2 | # 3 | # Tools for evaluating policies in environments. 
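# evaluate() below rolls out a policy for a fixed number of episodes and returns averaged
# statistics (optionally with per-episode trajectories and a wandb video).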
4 | # 5 | ############################### 6 | 7 | 8 | from typing import Dict 9 | import jax 10 | import gym 11 | import numpy as np 12 | from collections import defaultdict 13 | import time 14 | import wandb 15 | 16 | 17 | def flatten(d, parent_key="", sep="."): 18 | """ 19 | Helper function that flattens a dictionary of dictionaries into a single dictionary. 20 | E.g: flatten({'a': {'b': 1}}) -> {'a.b': 1} 21 | """ 22 | items = [] 23 | for k, v in d.items(): 24 | new_key = parent_key + sep + k if parent_key else k 25 | if hasattr(v, "items"): 26 | items.extend(flatten(v, new_key, sep=sep).items()) 27 | else: 28 | items.append((new_key, v)) 29 | return dict(items) 30 | 31 | 32 | def add_to(dict_of_lists, single_dict): 33 | for k, v in single_dict.items(): 34 | dict_of_lists[k].append(v) 35 | 36 | 37 | def evaluate(policy_fn, env: gym.Env, num_episodes: int, record_video : bool = False, 38 | return_trajectories=False, clip_return_at_goal=False, binary_return=False, use_discrete_xy=False, clip_margin=0): 39 | print("Clip return at goal is", clip_return_at_goal) 40 | stats = defaultdict(list) 41 | frames = [] 42 | trajectories = [] 43 | for i in range(num_episodes): 44 | now = time.time() 45 | trajectory = defaultdict(list) 46 | ob_list = [] 47 | ac_list = [] 48 | observation, done = env.reset(), False 49 | ob_list.append(observation) 50 | while not done: 51 | if use_discrete_xy: 52 | import fre.common.envs.d4rl.d4rl_ant as d4rl_ant 53 | ob_input = d4rl_ant.discretize_obs(observation) 54 | else: 55 | ob_input = observation 56 | action = policy_fn(ob_input) 57 | action = np.array(action) 58 | next_observation, r, done, info = env.step(action) 59 | add_to(stats, flatten(info)) 60 | 61 | if type(observation) is dict: 62 | obs_pure = observation['observation'] 63 | next_obs_pure = next_observation['observation'] 64 | else: 65 | obs_pure = observation 66 | next_obs_pure = next_observation 67 | transition = dict( 68 | observation=obs_pure, 69 | next_observation=next_obs_pure, 70 | action=action, 71 | reward=r, 72 | done=done, 73 | info=info, 74 | ) 75 | observation = next_observation 76 | ob_list.append(observation) 77 | ac_list.append(action) 78 | add_to(trajectory, transition) 79 | 80 | if i <= 3 and record_video: 81 | frames.append(env.render(mode="rgb_array")) 82 | add_to(stats, flatten(info, parent_key="final")) 83 | trajectories.append(trajectory) 84 | print("Finished Episode", i, "in", time.time() - now, "seconds") 85 | 86 | if clip_return_at_goal and 'episode.return' in stats: 87 | print("Episode finished. Return is {}. Length is {}.".format(stats['episode.return'], stats['episode.length'])) 88 | stats['episode.return'] = np.clip(np.array(stats['episode.length']) + np.array(stats['episode.return']) - clip_margin, 0, 1) # Goal is a binary indicator. 89 | print("Clipped return is {}.".format(stats['episode.return'])) 90 | elif binary_return and 'episode.return' in stats: 91 | # Assume that the reward is either 0 or 1 at each timestep. 92 | print("Episode finished. Return is {}. Length is {}.".format(stats['episode.return'], stats['episode.length'])) 93 | stats['episode.return'] = np.clip(np.array(stats['episode.return']), 0, 1) 94 | print("Clipped return is {}.".format(stats['episode.return'])) 95 | 96 | if 'episode.return' in stats: 97 | print("Episode finished. Return is {}. 
Length is {}.".format(stats['episode.return'], stats['episode.length'])) 98 | 99 | for k, v in stats.items(): 100 | stats[k] = np.mean(v) 101 | 102 | if record_video: 103 | stacked = np.stack(frames) 104 | stacked = stacked.transpose(0, 3, 1, 2) 105 | while stacked.shape[2] > 160: 106 | stacked = stacked[:, :, ::2, ::2] 107 | stats['video'] = wandb.Video(stacked, fps=60) 108 | 109 | if return_trajectories: 110 | return stats, trajectories 111 | else: 112 | return stats -------------------------------------------------------------------------------- /common/networks/basic.py: -------------------------------------------------------------------------------- 1 | ############################### 2 | # 3 | # Common Flax Networks. 4 | # 5 | ############################### 6 | 7 | from fre.common.typing import * 8 | 9 | import flax.linen as nn 10 | import jax.numpy as jnp 11 | 12 | import distrax 13 | import flax.linen as nn 14 | import jax.numpy as jnp 15 | from dataclasses import field 16 | 17 | ############################### 18 | # 19 | # Common Networks 20 | # 21 | ############################### 22 | 23 | def mish(x): 24 | return x * jnp.tanh(nn.softplus(x)) 25 | 26 | def default_init(scale: Optional[float] = 1.0): 27 | return nn.initializers.variance_scaling(scale, "fan_avg", "uniform") 28 | 29 | class MLP(nn.Module): 30 | hidden_dims: Sequence[int] 31 | activations: Callable[[jnp.ndarray], jnp.ndarray] = mish 32 | activate_final: int = False 33 | use_layer_norm: bool = True 34 | kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_init() 35 | 36 | def setup(self): 37 | self.layers = [ 38 | nn.Dense(size, kernel_init=self.kernel_init) for size in self.hidden_dims 39 | ] 40 | if self.use_layer_norm: 41 | self.layer_norms = [nn.LayerNorm() for _ in self.hidden_dims] 42 | 43 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 44 | for i, layer in enumerate(self.layers): 45 | x = layer(x) 46 | if i + 1 < len(self.layers) and self.use_layer_norm: 47 | x = self.layer_norms[i](x) 48 | if i + 1 < len(self.layers) or self.activate_final: 49 | x = self.activations(x) 50 | return x 51 | 52 | ############################### 53 | # 54 | # Common RL Networks 55 | # 56 | ############################### 57 | 58 | 59 | # DQN-style critic. 60 | class DiscreteCritic(nn.Module): 61 | hidden_dims: Sequence[int] 62 | n_actions: int 63 | mlp_kwargs: Dict[str, Any] = field(default_factory=dict) 64 | 65 | @nn.compact 66 | def __call__(self, observations: jnp.ndarray) -> jnp.ndarray: 67 | return MLP((*self.hidden_dims, self.n_actions), **self.mlp_kwargs)( 68 | observations 69 | ) 70 | 71 | # Q(s,a) critic. 72 | class Critic(nn.Module): 73 | hidden_dims: Sequence[int] 74 | mlp_kwargs: Dict[str, Any] = field(default_factory=dict) 75 | 76 | @nn.compact 77 | def __call__(self, observations: jnp.ndarray, actions: jnp.ndarray) -> jnp.ndarray: 78 | inputs = jnp.concatenate([observations, actions], -1) 79 | critic = MLP((*self.hidden_dims, 1), **self.mlp_kwargs)(inputs) 80 | return jnp.squeeze(critic, -1) 81 | 82 | # V(s) critic. 83 | class ValueCritic(nn.Module): 84 | hidden_dims: Sequence[int] 85 | mlp_kwargs: Dict[str, Any] = field(default_factory=dict) 86 | 87 | @nn.compact 88 | def __call__(self, observations: jnp.ndarray) -> jnp.ndarray: 89 | critic = MLP((*self.hidden_dims, 1), **self.mlp_kwargs)(observations) 90 | return jnp.squeeze(critic, -1) 91 | 92 | # pi(a|s). Returns a distrax distribution. 
93 | class Policy(nn.Module): 94 | hidden_dims: Sequence[int] 95 | action_dim: int 96 | mlp_kwargs: Dict[str, Any] = field(default_factory=dict) 97 | 98 | is_discrete: bool = False 99 | log_std_min: Optional[float] = -20 100 | log_std_max: Optional[float] = 2 101 | mean_min: Optional[float] = -5 102 | mean_max: Optional[float] = 5 103 | tanh_squash_distribution: bool = False 104 | state_dependent_std: bool = True 105 | final_fc_init_scale: float = 1e-2 106 | 107 | @nn.compact 108 | def __call__( 109 | self, observations: jnp.ndarray, temperature: float = 1.0 110 | ) -> distrax.Distribution: 111 | outputs = MLP( 112 | self.hidden_dims, 113 | activate_final=True, 114 | **self.mlp_kwargs 115 | )(observations) 116 | 117 | if self.is_discrete: 118 | logits = nn.Dense( 119 | self.action_dim, kernel_init=default_init(self.final_fc_init_scale) 120 | )(outputs) 121 | distribution = distrax.Categorical(logits=logits / jnp.maximum(1e-6, temperature)) 122 | else: 123 | means = nn.Dense( 124 | self.action_dim, kernel_init=default_init(self.final_fc_init_scale) 125 | )(outputs) 126 | if self.state_dependent_std: 127 | log_stds = nn.Dense( 128 | self.action_dim, kernel_init=default_init(self.final_fc_init_scale) 129 | )(outputs) 130 | else: 131 | log_stds = self.param("log_stds", nn.initializers.zeros, (self.action_dim,)) 132 | 133 | log_stds = jnp.clip(log_stds, self.log_std_min, self.log_std_max) 134 | means = jnp.clip(means, self.mean_min, self.mean_max) 135 | 136 | distribution = distrax.MultivariateNormalDiag( 137 | loc=means, scale_diag=jnp.exp(log_stds) * temperature 138 | ) 139 | if self.tanh_squash_distribution: 140 | distribution = TransformedWithMode( 141 | distribution, distrax.Block(distrax.Tanh(), ndims=1) 142 | ) 143 | return distribution 144 | 145 | ############################### 146 | # 147 | # Helper Things 148 | # 149 | ############################### 150 | 151 | 152 | class TransformedWithMode(distrax.Transformed): 153 | def mode(self) -> jnp.ndarray: 154 | return self.bijector.forward(self.distribution.mode()) 155 | 156 | def ensemblize(cls, num_qs, out_axes=0, **kwargs): 157 | """ 158 | Useful for making ensembles of Q functions (e.g. double Q in SAC). 159 | 160 | Usage: 161 | 162 | critic_def = ensemblize(Critic, 2)(hidden_dims=hidden_dims) 163 | 164 | """ 165 | return nn.vmap( 166 | cls, 167 | variable_axes={"params": 0}, 168 | split_rngs={"params": True}, 169 | in_axes=None, 170 | out_axes=out_axes, 171 | axis_size=num_qs, 172 | **kwargs 173 | ) -------------------------------------------------------------------------------- /common/networks/transformer.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Optional, Tuple, Type 2 | 3 | import flax.linen as nn 4 | import jax.numpy as jnp 5 | 6 | Array = Any 7 | PRNGKey = Any 8 | Shape = Tuple[int] 9 | Dtype = Any 10 | 11 | 12 | class IdentityLayer(nn.Module): 13 | """Identity layer, convenient for giving a name to an array.""" 14 | 15 | @nn.compact 16 | def __call__(self, x): 17 | return x 18 | 19 | 20 | class AddPositionEmbs(nn.Module): 21 | # Need to define function that adds the position embeddings to the input. 22 | posemb_init: Callable[[PRNGKey, Shape, Dtype], Array] 23 | 24 | @nn.compact 25 | def __call__(self, inputs): 26 | """ 27 | inputs.shape is (batch_size, timesteps, emb_dim). 28 | Output tensor with shape `(batch_size, timesteps, in_dim)`.
29 | """ 30 | assert inputs.ndim == 3, ('Number of dimensions should be 3, but it is: %d' % inputs.ndim) 31 | 32 | position_ids = jnp.arange(inputs.shape[1])[None] # (1, timesteps) 33 | pos_embeddings = nn.Embed( 34 | 128, # Max Positional Embeddings 35 | inputs.shape[2], 36 | embedding_init=self.posemb_init, 37 | dtype=inputs.dtype, 38 | )(position_ids) 39 | print("For Input Shape {}, Pos Embes Shape is {}".format(inputs.shape, pos_embeddings.shape)) 40 | return inputs + pos_embeddings 41 | 42 | # pos_emb_shape = (1, inputs.shape[1], inputs.shape[2]) 43 | # pe = self.param('pos_embedding', self.posemb_init, pos_emb_shape) 44 | # return inputs + pe 45 | 46 | 47 | class MlpBlock(nn.Module): 48 | """Transformer MLP / feed-forward block.""" 49 | 50 | mlp_dim: int 51 | dtype: Dtype = jnp.float32 52 | out_dim: Optional[int] = None 53 | dropout_rate: float = None 54 | kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.xavier_uniform() 55 | bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.normal(stddev=1e-6) 56 | 57 | @nn.compact 58 | def __call__(self, inputs, *, deterministic): 59 | """It's just an MLP, so the input shape is (batch, len, emb).""" 60 | actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim 61 | x = nn.Dense( 62 | features=self.mlp_dim, 63 | dtype=self.dtype, 64 | kernel_init=self.kernel_init, 65 | bias_init=self.bias_init)(inputs) 66 | x = nn.gelu(x) 67 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) 68 | output = nn.Dense( 69 | features=actual_out_dim, 70 | dtype=self.dtype, 71 | kernel_init=self.kernel_init, 72 | bias_init=self.bias_init)(x) 73 | output = nn.Dropout( 74 | rate=self.dropout_rate)(output, deterministic=deterministic) 75 | return output 76 | 77 | 78 | class Encoder1DBlock(nn.Module): 79 | """Transformer encoder layer. 80 | Given a sequence, it passes it through an attention layer, then through a mlp layer. 81 | In each case it is a residual block with a layer norm. 82 | """ 83 | 84 | mlp_dim: int 85 | num_heads: int 86 | causal: bool 87 | dropout_rate: float 88 | attention_dropout_rate: float 89 | dtype: Dtype = jnp.float32 90 | 91 | @nn.compact 92 | def __call__(self, inputs, *, deterministic, train=True): 93 | 94 | if self.causal: 95 | causal_mask = nn.make_causal_mask(jnp.ones((inputs.shape[0], inputs.shape[1]), 96 | dtype="bool"), dtype="bool") 97 | print("Using Causal Mask with shape", causal_mask.shape, "and inputs shape", inputs.shape, ".") 98 | else: 99 | causal_mask = None 100 | 101 | # Attention block. 102 | assert inputs.ndim == 3, f'Expected (batch, seq, hidden) got {inputs.shape}' 103 | x = nn.LayerNorm(dtype=self.dtype)(inputs) 104 | x = nn.MultiHeadDotProductAttention( 105 | dtype=self.dtype, 106 | kernel_init=nn.initializers.xavier_uniform(), 107 | broadcast_dropout=False, 108 | deterministic=deterministic, 109 | dropout_rate=self.attention_dropout_rate, 110 | decode=False, 111 | num_heads=self.num_heads)(x, x, causal_mask) 112 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) 113 | x = x + inputs 114 | 115 | # MLP block. This does NOT change the embedding dimension! 116 | y = nn.LayerNorm(dtype=self.dtype)(x) 117 | y = MlpBlock(mlp_dim=self.mlp_dim, dtype=self.dtype, dropout_rate=self.dropout_rate)(y, deterministic=deterministic) 118 | 119 | return x + y 120 | 121 | 122 | class Transformer(nn.Module): 123 | """Transformer Model Encoder for sequence to sequence translation. 
124 | """ 125 | 126 | num_layers: int 127 | emb_dim: int 128 | mlp_dim: int 129 | num_heads: int 130 | dropout_rate: float 131 | attention_dropout_rate: float 132 | causal: bool = True 133 | 134 | @nn.compact 135 | def __call__(self, x, *, train): 136 | assert x.ndim == 3 # (batch, len, emb) 137 | assert x.shape[-1] == self.emb_dim 138 | 139 | # Input Encoder. Each layer processes x, but the shape of x does not change. 140 | for lyr in range(self.num_layers): 141 | x = Encoder1DBlock( 142 | mlp_dim=self.mlp_dim, 143 | dropout_rate=self.dropout_rate, 144 | attention_dropout_rate=self.attention_dropout_rate, 145 | name=f'encoderblock_{lyr}', 146 | causal=self.causal, 147 | num_heads=self.num_heads)( 148 | x, deterministic=not train, train=train) 149 | encoded = nn.LayerNorm(name='encoder_norm')(x) 150 | 151 | return encoded 152 | 153 | def get_default_config(): 154 | import ml_collections 155 | 156 | config = ml_collections.ConfigDict({ 157 | 'num_layers': 4, 158 | 'emb_dim': 256, 159 | 'mlp_dim': 256, 160 | 'num_heads': 4, 161 | 'dropout_rate': 0.0, 162 | 'attention_dropout_rate': 0.0, 163 | 'causal': True, 164 | }) 165 | return config -------------------------------------------------------------------------------- /common/train_state.py: -------------------------------------------------------------------------------- 1 | ############################### 2 | # 3 | # Structures for managing training of flax networks. 4 | # 5 | ############################### 6 | 7 | from fre.common.typing import * 8 | import flax 9 | import flax.linen as nn 10 | import jax 11 | import jax.numpy as jnp 12 | from jax import tree_util 13 | import optax 14 | import functools 15 | 16 | import gym 17 | 18 | nonpytree_field = functools.partial(flax.struct.field, pytree_node=False) 19 | 20 | 21 | def shard_batch(batch): 22 | d = jax.local_device_count() 23 | 24 | def reshape(x): 25 | assert ( 26 | x.shape[0] % d == 0 27 | ), f"Batch size needs to be divisible by # devices, got {x.shape[0]} and {d}" 28 | return x.reshape((d, x.shape[0] // d, *x.shape[1:])) 29 | 30 | return tree_util.tree_map(reshape, batch) 31 | 32 | 33 | def target_update( 34 | model: "TrainState", target_model: "TrainState", tau: float 35 | ) -> "TrainState": 36 | new_target_params = jax.tree_map( 37 | lambda p, tp: p * tau + tp * (1 - tau), model.params, target_model.params 38 | ) 39 | return target_model.replace(params=new_target_params) 40 | 41 | 42 | class TrainState(flax.struct.PyTreeNode): 43 | """ 44 | Core abstraction of a model in this repository. 
45 | 46 | Creation: 47 | ``` 48 | model_def = nn.Dense(12) # or any other flax.linen Module 49 | params = model_def.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))['params'] 50 | model = TrainState.create(model_def, params, tx=None) # Optionally, pass in an optax optimizer 51 | ``` 52 | 53 | Usage: 54 | ``` 55 | y = model(jnp.ones((1, 4))) # By default, uses the `__call__` method of the model_def and params stored in TrainState 56 | y = model(jnp.ones((1, 4)), params=params) # You can pass in params (useful for gradient computation) 57 | y = model(jnp.ones((1, 4)), method=method) # You can apply a different method as well 58 | ``` 59 | 60 | More complete example: 61 | ``` 62 | def loss(params): 63 | y_pred = model(x, params=params) 64 | return jnp.mean((y - y_pred) ** 2) 65 | 66 | grads = jax.grad(loss)(model.params) 67 | new_model = model.apply_gradients(grads=grads) # Alternatively, new_model = model.apply_loss_fn(loss_fn=loss) 68 | ``` 69 | """ 70 | 71 | step: int 72 | apply_fn: Callable[..., Any] = nonpytree_field() 73 | model_def: Any = nonpytree_field() 74 | params: Params 75 | tx: Optional[optax.GradientTransformation] = nonpytree_field() 76 | opt_state: Optional[optax.OptState] = None 77 | 78 | @classmethod 79 | def create( 80 | cls, 81 | model_def: nn.Module, 82 | params: Params, 83 | tx: Optional[optax.GradientTransformation] = None, 84 | **kwargs, 85 | ) -> "TrainState": 86 | if tx is not None: 87 | opt_state = tx.init(params) 88 | else: 89 | opt_state = None 90 | 91 | return cls( 92 | step=1, 93 | apply_fn=model_def.apply, 94 | model_def=model_def, 95 | params=params, 96 | tx=tx, 97 | opt_state=opt_state, 98 | **kwargs, 99 | ) 100 | 101 | def __call__( 102 | self, 103 | *args, 104 | params=None, 105 | extra_variables: dict = None, 106 | method: ModuleMethod = None, 107 | **kwargs, 108 | ): 109 | """ 110 | Internally calls model_def.apply_fn with the following logic: 111 | 112 | Arguments: 113 | params: If not None, use these params instead of the ones stored in the model. 114 | extra_variables: Additional variables to pass into apply_fn 115 | method: If None, use the `__call__` method of the model_def. If a string, uses 116 | the method of the model_def with that name (e.g. 'encode' -> model_def.encode). 117 | If a function, uses that function. 118 | 119 | """ 120 | if params is None: 121 | params = self.params 122 | 123 | variables = {"params": params} 124 | 125 | if extra_variables is not None: 126 | variables = {**variables, **extra_variables} 127 | 128 | if isinstance(method, str): 129 | method = getattr(self.model_def, method) 130 | 131 | return self.apply_fn(variables, *args, method=method, **kwargs) 132 | 133 | def do(self, method): 134 | return functools.partial(self, method=method) 135 | 136 | def apply_gradients(self, *, grads, **kwargs): 137 | """Updates `step`, `params`, `opt_state` and `**kwargs` in return value. 138 | 139 | Note that internally this function calls `.tx.update()` followed by a call 140 | to `optax.apply_updates()` to update `params` and `opt_state`. 141 | 142 | Args: 143 | grads: Gradients that have the same pytree structure as `.params`. 144 | **kwargs: Additional dataclass attributes that should be `.replace()`-ed. 145 | 146 | Returns: 147 | An updated instance of `self` with `step` incremented by one, `params` 148 | and `opt_state` updated by applying `grads`, and additional attributes 149 | replaced as specified by `kwargs`. 
150 | """ 151 | updates, new_opt_state = self.tx.update(grads, self.opt_state, self.params) 152 | new_params = optax.apply_updates(self.params, updates) 153 | 154 | return self.replace( 155 | step=self.step + 1, 156 | params=new_params, 157 | opt_state=new_opt_state, 158 | **kwargs, 159 | ) 160 | 161 | def apply_loss_fn(self, *, loss_fn, pmap_axis=None, has_aux=False): 162 | """ 163 | Takes a gradient step towards minimizing `loss_fn`. Internally, this calls 164 | `jax.grad` followed by `TrainState.apply_gradients`. If pmap_axis is provided, 165 | additionally it averages gradients (and info) across devices before performing update. 166 | """ 167 | if has_aux: 168 | grads, info = jax.grad(loss_fn, has_aux=has_aux)(self.params) 169 | if pmap_axis is not None: 170 | grads = jax.lax.pmean(grads, axis_name=pmap_axis) 171 | info = jax.lax.pmean(info, axis_name=pmap_axis) 172 | 173 | return self.apply_gradients(grads=grads), info 174 | 175 | else: 176 | grads = jax.grad(loss_fn, has_aux=has_aux)(self.params) 177 | if pmap_axis is not None: 178 | grads = jax.lax.pmean(grads, axis_name=pmap_axis) 179 | return self.apply_gradients(grads=grads) 180 | 181 | class NormalizeActionWrapper(gym.Wrapper): 182 | """A wrapper that maps actions from [-1,1] to [low, high].""" 183 | def __init__(self, env): 184 | super().__init__(env) 185 | self.active = type(env.action_space) == gym.spaces.Box 186 | if self.active: 187 | self.action_low = env.action_space.low 188 | self.action_high = env.action_space.high 189 | self.action_scale = (self.action_high - self.action_low) * 0.5 190 | self.action_mid = (self.action_high + self.action_low) * 0.5 191 | print("Normalizing Action Space from [{}, {}] to [-1, 1]".format(self.action_low[0], self.action_high[0])) 192 | def step(self, action): 193 | if self.active: 194 | action = np.clip(action, -1, 1) 195 | action = action * self.action_scale 196 | action = action + self.action_mid 197 | return self.env.step(action) 198 | 199 | def reset(self, **kwargs): 200 | return self.env.reset(**kwargs) 201 | 202 | -------------------------------------------------------------------------------- /common/typing.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union 2 | import numpy as np 3 | import jax.numpy as jnp 4 | import flax 5 | 6 | PRNGKey = Any 7 | Params = flax.core.FrozenDict[str, Any] 8 | PRNGKey = Any 9 | Shape = Sequence[int] 10 | Dtype = Any # this could be a real type? 11 | InfoDict = Dict[str, float] 12 | Array = Union[np.ndarray, jnp.ndarray] 13 | Data = Union[Array, Dict[str, "Data"]] 14 | Batch = Dict[str, Data] 15 | ModuleMethod = Union[ 16 | str, Callable, None 17 | ] # A method to be passed into TrainState.__call__ 18 | -------------------------------------------------------------------------------- /common/utils.py: -------------------------------------------------------------------------------- 1 | ############################### 2 | # 3 | # Some shared utility functions 4 | # 5 | ############################### 6 | 7 | import jax 8 | 9 | def supply_rng(f, rng=jax.random.PRNGKey(0)): 10 | """ 11 | Wraps a function to supply jax rng. It will remember the rng state for that function.
12 | """ 13 | def wrapped(*args, **kwargs): 14 | nonlocal rng 15 | rng, key = jax.random.split(rng) 16 | return f(*args, seed=key, **kwargs) 17 | 18 | return wrapped 19 | 20 | -------------------------------------------------------------------------------- /common/wandb.py: -------------------------------------------------------------------------------- 1 | """WandB logging helpers. 2 | 3 | Run setup_wandb(hyperparam_dict, ...) to initialize wandb logging. 4 | See default_wandb_config() for a list of available configurations. 5 | 6 | We recommend the following workflow (see examples/mujoco/d4rl_iql.py for a more full example): 7 | 8 | from ml_collections import config_flags 9 | from jaxrl_m.wandb import setup_wandb, default_wandb_config 10 | import wandb 11 | 12 | # This line allows us to change wandb config flags from the command line 13 | config_flags.DEFINE_config_dict('wandb', default_wandb_config(), lock_config=False) 14 | 15 | ... 16 | def main(argv): 17 | hyperparams = ... 18 | setup_wandb(hyperparams, **FLAGS.wandb) 19 | 20 | # Log metrics as you wish now 21 | wandb.log({'metric': 0.0}, step=0) 22 | 23 | 24 | With the following setup, you may set wandb configurations from the command line, e.g. 25 | python main.py --wandb.project=my_project --wandb.group=my_group --wandb.offline 26 | """ 27 | import wandb 28 | 29 | import tempfile 30 | import absl.flags as flags 31 | import ml_collections 32 | from ml_collections.config_dict import FieldReference 33 | import datetime 34 | import wandb 35 | import time 36 | import numpy as np 37 | import os 38 | 39 | 40 | def get_flag_dict(): 41 | flag_dict = {k: getattr(flags.FLAGS, k) for k in flags.FLAGS} 42 | for k in flag_dict: 43 | if isinstance(flag_dict[k], ml_collections.ConfigDict): 44 | flag_dict[k] = flag_dict[k].to_dict() 45 | return flag_dict 46 | 47 | 48 | def default_wandb_config(): 49 | config = ml_collections.ConfigDict() 50 | config.offline = False # Syncs online or not? 
51 | config.project = "jaxrl_m" # WandB Project Name 52 | config.entity = FieldReference(None, field_type=str) # Which entity to log as (default: your own user) 53 | 54 | group_name = FieldReference(None, field_type=str) # Group name 55 | config.exp_prefix = group_name # Group name (deprecated, but kept for backwards compatibility) 56 | config.group = group_name # Group name 57 | 58 | experiment_name = FieldReference(None, field_type=str) # Experiment name 59 | config.name = experiment_name # Run name (will be formatted with flags / variant) 60 | config.exp_descriptor = experiment_name # Run name (deprecated, but kept for backwards compatibility) 61 | 62 | config.unique_identifier = "" # Unique identifier for run (will be automatically generated unless provided) 63 | config.random_delay = 0 # Random delay for wandb.init (in seconds) 64 | return config 65 | 66 | 67 | def setup_wandb( 68 | hyperparam_dict, 69 | entity=None, 70 | project="jaxrl_m", 71 | group=None, 72 | name=None, 73 | unique_identifier="", 74 | offline=False, 75 | random_delay=0, 76 | **additional_init_kwargs, 77 | ): 78 | """ 79 | Utility for setting up wandb logging (based on Young's simplesac): 80 | 81 | Arguments: 82 | - hyperparam_dict: dict of hyperparameters for experiment 83 | - offline: bool, whether to sync online or not 84 | - project: str, wandb project name 85 | - entity: str, wandb entity name (default is your user) 86 | - group: str, Group name for wandb 87 | - name: str, Experiment name for wandb (formatted with FLAGS & hyperparameter_dict) 88 | - unique_identifier: str, Unique identifier for wandb (default is timestamp) 89 | - random_delay: float, Random delay for wandb.init (in seconds) to avoid collisions 90 | - additional_init_kwargs: dict, additional kwargs to pass to wandb.init 91 | Returns: 92 | - wandb.run 93 | 94 | """ 95 | if "exp_descriptor" in additional_init_kwargs: 96 | # Remove deprecated exp_descriptor 97 | additional_init_kwargs.pop("exp_descriptor") 98 | additional_init_kwargs.pop("exp_prefix") 99 | 100 | if not unique_identifier: 101 | if random_delay: 102 | time.sleep(np.random.uniform(0, random_delay)) 103 | unique_identifier = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") 104 | unique_identifier += f"_{np.random.randint(0, 1000000):06d}" 105 | flag_dict = get_flag_dict() 106 | if 'seed' in flag_dict: 107 | unique_identifier += f"_{flag_dict['seed']:02d}" 108 | 109 | if name is not None: 110 | name = name.format(**{**get_flag_dict(), **hyperparam_dict}) 111 | 112 | if group is not None and name is not None: 113 | experiment_id = f"{name}_{unique_identifier}" 114 | elif name is not None: 115 | experiment_id = f"{name}_{unique_identifier}" 116 | else: 117 | experiment_id = None 118 | 119 | # check if dir exists. 
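    # (No existence check is actually needed: the call below just allocates a fresh temporary directory for wandb's local files.)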
120 | wandb_output_dir = tempfile.mkdtemp() 121 | tags = [group] if group is not None else None 122 | 123 | init_kwargs = dict( 124 | config=hyperparam_dict, 125 | project=project, 126 | entity=entity, 127 | tags=tags, 128 | group=group, 129 | dir=wandb_output_dir, 130 | id=experiment_id, 131 | name=name, 132 | settings=wandb.Settings( 133 | start_method="thread", 134 | _disable_stats=False, 135 | ), 136 | mode="offline" if offline else "online", 137 | save_code=True, 138 | ) 139 | 140 | init_kwargs.update(additional_init_kwargs) 141 | run = wandb.init(**init_kwargs) 142 | 143 | wandb.config.update(get_flag_dict()) 144 | 145 | wandb_config = dict( 146 | exp_prefix=group, 147 | exp_descriptor=name, 148 | experiment_id=experiment_id, 149 | ) 150 | wandb.config.update(wandb_config) 151 | return run 152 | -------------------------------------------------------------------------------- /deps/base_container.def: -------------------------------------------------------------------------------- 1 | Bootstrap:docker 2 | From: nvidia/cuda:11.8.0-devel-ubuntu22.04 3 | 4 | # Copy the conda env file into the container for installation 5 | %files 6 | environment.yml /contained/setup/environment.yml 7 | requirements.txt /contained/setup/requirements.txt 8 | 9 | %post -c /bin/bash 10 | apt-get update && apt-get install -y wget 11 | apt-get install -y libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf 12 | wget https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh -P /contained/ 13 | sh /contained/Miniforge3-Linux-x86_64.sh -b -p /contained/miniconda 14 | ls /contained/ 15 | source /contained/miniconda/etc/profile.d/conda.sh 16 | source /contained/miniconda/etc/profile.d/mamba.sh 17 | 18 | export MUJOCO_PY_MJKEY_PATH='/contained/software/mujoco/mjkey.txt' 19 | export MUJOCO_PY_MUJOCO_PATH='/contained/software/mujoco/mujoco210' 20 | export MJKEY_PATH='/contained/software/mujoco/mjkey.txt' 21 | export MJLIB_PATH='/contained/software/mujoco/mujoco210/bin/libmujoco210.so' 22 | export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/contained/software/mujoco/mujoco210/bin" 23 | export D4RL_SUPPRESS_IMPORT_ERROR=1 24 | 25 | mkdir /contained/software 26 | mkdir /contained/software/mujoco 27 | wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz 28 | tar -C /contained/software/mujoco -zxvf mujoco210-linux-x86_64.tar.gz --no-same-owner 29 | 30 | export WANDB_API_KEY='' 31 | 32 | cp -r /contained/software/mujoco /root/.mujoco 33 | CONDA_OVERRIDE_CUDA="11.8" mamba env create -f /contained/setup/environment.yml 34 | 35 | mamba clean --all 36 | 37 | # Trigger mujoco-py build 38 | mamba activate $(cat /contained/setup/environment.yml | egrep "name: .+$" | sed -e 's/^name:[ \t]*//') 39 | python -c 'import gym; gym.make("HalfCheetah-v2")' 40 | 41 | chmod -R 777 /contained 42 | 43 | %environment 44 | # Activate conda environment 45 | source /contained/miniconda/etc/profile.d/conda.sh 46 | source /contained/miniconda/etc/profile.d/mamba.sh 47 | conda activate $(cat /contained/setup/environment.yml | egrep "name: .+$" | sed -e 's/^name:[ \t]*//') 48 | 49 | export MUJOCO_PY_MJKEY_PATH='/contained/software/mujoco/mjkey.txt' 50 | export MUJOCO_PY_MUJOCO_PATH='/contained/software/mujoco/mujoco210' 51 | export MJKEY_PATH='/contained/software/mujoco/mjkey.txt' 52 | export MJLIB_PATH='/contained/software/mujoco/mujoco210/bin/libmujoco210.so' 53 | export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/contained/software/mujoco/mujoco210/bin" 54 | export D4RL_SUPPRESS_IMPORT_ERROR=1 55 | 56 | export 
WANDB_API_KEY='' 57 | 58 | # Set up python path for the research project, put your home dir here. 59 | export PYTHONPATH="$PYTHONPATH:" 60 | 61 | %runscript 62 | #! /bin/bash 63 | python -m "$@" 64 | # Entry point for singularity run 65 | -------------------------------------------------------------------------------- /deps/environment.yml: -------------------------------------------------------------------------------- 1 | name: project-brc 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python=3.9 6 | - pip 7 | - numpy 8 | - scipy 9 | - h5py 10 | - matplotlib 11 | - scikit-learn 12 | - jupyter 13 | - tqdm 14 | - seaborn 15 | - mesalib 16 | - glew 17 | - glfw 18 | - Cython 19 | - conda-forge::opencv 20 | - conda-forge::jax=0.4.14 21 | - conda-forge::jaxlib=0.4.14=*cuda* 22 | - pip: 23 | - -r requirements.txt -------------------------------------------------------------------------------- /deps/requirements.txt: -------------------------------------------------------------------------------- 1 | numba 2 | opt-einsum 3 | numpy 4 | absl-py 5 | termcolor 6 | matplotlib 7 | mujoco 8 | mujoco-py 9 | ml-collections 10 | cython<3 11 | wandb 12 | imageio 13 | moviepy 14 | opensimplex 15 | pygame 16 | libtmux 17 | threadpoolctl==3.1.0 18 | plotly 19 | 20 | tensorflow-probability==0.19.0 21 | d4rl==1.1 22 | dm_control==1.0.15 23 | dm-env==1.6 24 | dm-tree==0.1.8 25 | gym==0.23.1 26 | 27 | flax==0.7.4 28 | optax==0.1.7 29 | orbax==0.1.9 30 | distrax==0.1.4 31 | chex==0.1.82 -------------------------------------------------------------------------------- /experiment/ant_helper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas 4 | 5 | def get_canvas_image(canvas): 6 | canvas.draw() 7 | out_image = np.frombuffer(canvas.tostring_rgb(), dtype='uint8') 8 | out_image = out_image.reshape(canvas.get_width_height()[::-1] + (3,)) 9 | return out_image 10 | 11 | def value_image(env, dataset, value_fn, mask, clip=False): 12 | """ 13 | Visualize the value function. 14 | Args: 15 | env: The environment. 16 | value_fn: a function with signature value_fn([# states, state_dim]) -> [#states, 1] 17 | Returns: 18 | A numpy array of the image. 
19 | """ 20 | fig = plt.figure(tight_layout=True) 21 | canvas = FigureCanvas(fig) 22 | plot_value(env, dataset, value_fn, mask, fig, plt.gca(), clip=clip) 23 | image = get_canvas_image(canvas) 24 | plt.close(fig) 25 | return image 26 | 27 | def plot_value(env, dataset, value_fn, mask, fig, ax, title=None, clip=True): 28 | N = 14 29 | M = 20 30 | ob_xy = env.XY(n=N, m=M) 31 | 32 | base_observation = np.copy(dataset['observations'][0]) 33 | base_observations = np.tile(base_observation, (5, ob_xy.shape[0], 1)) 34 | base_observations[:, :, :2] = ob_xy 35 | base_observations[:, :, 15:17] = 0.0 36 | base_observations[0, :, 15] = 1.0 37 | base_observations[1, :, 16] = 1.0 38 | base_observations[2, :, 15] = -1.0 39 | base_observations[3, :, 16] = -1.0 40 | print("Base observations, ", base_observations.shape) 41 | 42 | 43 | values = [] 44 | for i in range(5): 45 | values.append(value_fn(base_observations[i])) 46 | values = np.stack(values, axis=0) 47 | print("Values", values.shape) 48 | 49 | x, y = ob_xy[:, 0], ob_xy[:, 1] 50 | x = x.reshape(N, M) 51 | y = y.reshape(N, M) * 0.975 + 0.7 52 | values = values.reshape(5, N, M) 53 | values[-1, 10, 0] = np.min(values[-1]) + 0.3 # Hack to make the scaling not show small errors. 54 | print("Clip:", clip) 55 | if clip: 56 | mesh = ax.pcolormesh(x, y, values[-1], cmap='viridis', vmin=-0.1, vmax=1.0) 57 | else: 58 | mesh = ax.pcolormesh(x, y, values[-1], cmap='viridis') 59 | 60 | v = (values[1] - values[3]) / 2 61 | u = (values[0] - values[2]) / 2 62 | uv_dist = np.sqrt(u**2 + v**2) + 1e-6 63 | # Normalize u,v 64 | un = u / uv_dist 65 | vn = v / uv_dist 66 | un[uv_dist < 0.1] = 0 67 | vn[uv_dist < 0.1] = 0 68 | 69 | plt.quiver(x, y, un, vn, color='r', pivot='mid', scale=0.75, scale_units='xy') 70 | 71 | if mask is not None and type(mask) == np.ndarray: 72 | # mask = NxM array of things to unmask. 73 | from matplotlib.colors import LinearSegmentedColormap 74 | colors = [(0,0,0,c) for c in np.linspace(0,1,100)] 75 | cmapred = LinearSegmentedColormap.from_list('mycmap', colors, N=5) 76 | mask_mesh_ax = ax.pcolormesh(x, y, mask, cmap=cmapred) 77 | elif mask is not None and type(mask) is list: 78 | maskmesh = np.ones((N, M)) 79 | for xy in mask: 80 | for xi in range(N): 81 | for yi in range(M): 82 | if np.linalg.norm(np.array(xy) - np.array([x[xi, yi], y[xi, yi]])) < 1.4: 83 | # print(xy, x[xi, yi], y[xi, yi]) 84 | maskmesh[xi,yi] = 0 85 | from matplotlib.colors import LinearSegmentedColormap 86 | colors = [(0,0,0,c) for c in np.linspace(0,1,100)] 87 | cmapred = LinearSegmentedColormap.from_list('mycmap', colors, N=5) 88 | mask_mesh_ax = ax.pcolormesh(x, y, maskmesh, cmap=cmapred) 89 | 90 | env.draw(ax, scale=0.95) 91 | 92 | 93 | 94 | # env.draw(ax, scale=1.0) 95 | 96 | # divider = make_axes_locatable(ax) 97 | # cax = divider.append_axes('right', size='5%', pad=0.05) 98 | # fig.colorbar(mesh, cax=cax, orientation='vertical') 99 | 100 | if title: 101 | ax.set_title(title) -------------------------------------------------------------------------------- /experiment/rewards_eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tqdm 3 | import opensimplex 4 | import jax 5 | import jax.numpy as jnp 6 | from functools import partial 7 | 8 | from fre.experiment.rewards_unsupervised import RewardFunction 9 | 10 | 11 | class VelocityRewardFunction(RewardFunction): 12 | def __init__(self): 13 | pass 14 | 15 | # Select an XY velocity from a future state in the trajectory. 
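    # Returns (params, encode_pairs, decode_pairs, rewards, masks); the encode/decode pairs are states with
    # the scalar reward appended as the last feature, i.e. shape (batch_size, num_pairs, obs_dim + 1).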
16 | def generate_params_and_pairs(self, traj_states, random_states, random_states_decode): 17 | batch_size = traj_states.shape[0] 18 | selected_traj_state_idx = np.random.randint(traj_states.shape[1], size=(batch_size,)) 19 | selected_traj_state = traj_states[np.arange(batch_size), selected_traj_state_idx] # (batch_size, obs_dim) 20 | params = selected_traj_state[:, 15:17] # (batch_size, 2) 21 | params[:batch_size//4] = np.random.uniform(-1, 1, size=(batch_size//4, 2)) # Randomize 25% of the time. 22 | params = params / np.linalg.norm(params, axis=-1, keepdims=True) # Normalize XY 23 | 24 | encode_pairs = np.concatenate([traj_states, random_states], axis=1) 25 | encode_rewards = self.compute_reward(encode_pairs, params[:, None, :])[:, :, None] 26 | encode_pairs = np.concatenate([encode_pairs, encode_rewards], axis=-1) 27 | 28 | decode_pairs = random_states_decode 29 | decode_rewards = self.compute_reward(decode_pairs, params[:, None, :])[:, :, None] 30 | decode_pairs = np.concatenate([random_states_decode, decode_rewards], axis=-1) 31 | 32 | rewards = encode_pairs[:, 0, -1] # (batch_size,) 33 | masks = np.ones_like(rewards) # (batch_size,) 34 | 35 | return params, encode_pairs, decode_pairs, rewards, masks 36 | 37 | def compute_reward(self, states, params): 38 | assert len(states.shape) == len(params.shape), states.shape # (batch_size, obs_dim) OR (batch_size, num_pairs, obs_dim) 39 | xy_vels = states[..., 15:17] * 0.33820298 40 | return np.sum(xy_vels * params, axis=-1) # (batch_size,) 41 | 42 | def make_encoder_pairs_testing(self, params, random_states): 43 | assert len(params.shape) == 2, params.shape # (batch_size, 2) 44 | assert len(random_states.shape) == 3, random_states.shape # (batch_size, num_pairs, obs_dim) 45 | 46 | reward_pair_rews = self.compute_reward(random_states, params[:, None, :])[..., None] 47 | reward_pairs = np.concatenate([random_states, reward_pair_rews], axis=-1) 48 | return reward_pairs # (batch_size, reward_pairs, obs_dim + 1) 49 | 50 | class TestRewMatrix(RewardFunction): 51 | def __init__(self): 52 | self.pos = np.zeros((36, 24)) 53 | self.xvel = np.zeros((36, 24)) 54 | self.yvel = np.zeros((36, 24)) 55 | 56 | def generate_params_and_pairs(self, traj_states, random_states, random_states_decode): 57 | batch_size = traj_states.shape[0] 58 | params = np.zeros((batch_size, 1)) # (batch_size, 1) 59 | 60 | encode_pairs = np.concatenate([traj_states, random_states], axis=1) 61 | encode_rewards = self.compute_reward(encode_pairs, params[:, None, :])[:, :, None] 62 | encode_pairs = np.concatenate([encode_pairs, encode_rewards], axis=-1) 63 | 64 | decode_pairs = random_states_decode 65 | decode_rewards = self.compute_reward(decode_pairs, params[:, None, :])[:, :, None] 66 | decode_pairs = np.concatenate([random_states_decode, decode_rewards], axis=-1) 67 | 68 | rewards = encode_pairs[:, 0, -1] # (batch_size,) 69 | masks = np.ones_like(rewards) # (batch_size,) 70 | 71 | return params, encode_pairs, decode_pairs, rewards, masks 72 | 73 | def compute_reward(self, s, params): 74 | rews = np.zeros_like(s[..., 0]) # (batch, examples) 75 | # XY Vel Reward 76 | xy_vels = s[..., 15:17] * 0.33820298 77 | 78 | x = s[..., 0].astype(int).clip(0, 35) 79 | y = s[..., 1].astype(int).clip(0, 23) 80 | simplex = self.pos[x, y] 81 | simplex_xvel = self.xvel[x, y] 82 | simplex_yvel = self.yvel[x, y] 83 | rews = (simplex > 0.3).astype(float) * 0.5 84 | rews += xy_vels[...,0] * simplex_xvel + xy_vels[...,1] * simplex_yvel 85 | 86 | return rews 87 | 88 | def make_encoder_pairs_testing(self, 
params, random_states): 89 | assert len(params.shape) == 2, params.shape # (batch_size, 2) 90 | assert len(random_states.shape) == 3, random_states.shape # (batch_size, num_pairs, obs_dim) 91 | batch_size = random_states.shape[0] 92 | 93 | # TODO: Be smarter about the states to use here. 94 | 95 | reward_pair_rews = self.compute_reward(random_states, params[:, None, :])[..., None] 96 | reward_pairs = np.concatenate([random_states, reward_pair_rews], axis=-1) 97 | return reward_pairs # (batch_size, reward_pairs, obs_dim + 1) 98 | 99 | class SimplexRewardFunction(RewardFunction): 100 | def __init__(self, num_simplex): 101 | self.simplex_size = num_simplex 102 | self.simplex_seeds_pos = np.zeros((self.simplex_size, 36, 24)) 103 | self.simplex_seeds_xvel = np.zeros((self.simplex_size, 36, 24)) 104 | self.simplex_seeds_yvel = np.zeros((self.simplex_size, 36, 24)) 105 | self.simplex_best_xy = np.zeros((self.simplex_size, 10, 2)) 106 | print("Generating simplex seeds") 107 | xi = np.arange(36) 108 | yi = np.arange(24) 109 | for r in tqdm.tqdm(range(self.simplex_size)): 110 | opensimplex.seed(r) 111 | self.simplex_seeds_pos[r] = opensimplex.noise2array(x=xi/20.0, y=yi/20.0).T 112 | opensimplex.seed(r + self.simplex_size) 113 | self.simplex_seeds_xvel[r] = opensimplex.noise2array(x=xi/20.0, y=yi/20.0).T 114 | opensimplex.seed(r + self.simplex_size * 2) 115 | self.simplex_seeds_yvel[r] = opensimplex.noise2array(x=xi/20.0, y=yi/20.0).T 116 | 117 | best_topn = np.argpartition(self.simplex_seeds_pos[r].flatten(), -10)[-10:] # (10,) 118 | best_xy = np.array(np.unravel_index(best_topn, self.simplex_seeds_pos[r].shape)).T # (10, 2) 119 | self.simplex_best_xy[r] = best_xy 120 | self.simplex_seeds_xvel[np.abs(self.simplex_seeds_xvel) < 0.5] = 0 121 | self.simplex_seeds_yvel[np.abs(self.simplex_seeds_yvel) < 0.5] = 0 122 | 123 | # Sample a random simplex-noise seed id as the reward parameter.
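    # Each id indexes a pre-generated OpenSimplex noise map over the AntMaze XY grid, which defines a position bonus plus preferred x/y velocities.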
124 | def generate_params_and_pairs(self, traj_states, random_states, random_states_decode): 125 | batch_size = traj_states.shape[0] 126 | params = np.random.randint(self.simplex_size, size=(batch_size, 1)) # (batch_size, 1) 127 | 128 | encode_pairs = np.concatenate([traj_states, random_states], axis=1) 129 | encode_rewards = self.compute_reward(encode_pairs, params[:, None, :])[:, :, None] 130 | encode_pairs = np.concatenate([encode_pairs, encode_rewards], axis=-1) 131 | 132 | decode_pairs = random_states_decode 133 | decode_rewards = self.compute_reward(decode_pairs, params[:, None, :])[:, :, None] 134 | decode_pairs = np.concatenate([random_states_decode, decode_rewards], axis=-1) 135 | 136 | rewards = encode_pairs[:, 0, -1] # (batch_size,) 137 | masks = np.ones_like(rewards) # (batch_size,) 138 | 139 | return params, encode_pairs, decode_pairs, rewards, masks 140 | 141 | def compute_reward(self, states, params): 142 | assert len(states.shape) == len(params.shape), states.shape # (batch_size, obs_dim) OR (batch_size, num_pairs, obs_dim) 143 | 144 | simplex_id = params[..., 0].astype(int) 145 | x = states[..., 0].astype(int).clip(0, 35) 146 | y = states[..., 1].astype(int).clip(0, 23) 147 | simplex = self.simplex_seeds_pos[simplex_id, x, y] 148 | simplex_xvel = self.simplex_seeds_xvel[simplex_id, x, y] 149 | simplex_yvel = self.simplex_seeds_yvel[simplex_id, x, y] 150 | rews = -1 + (simplex > 0.3).astype(float) * 0.5 151 | xy_vels = states[..., 15:17] * 0.33820298 152 | rews += xy_vels[...,0] * simplex_xvel + xy_vels[...,1] * simplex_yvel 153 | return rews # (batch_size,) 154 | 155 | def make_encoder_pairs_testing(self, params, random_states): 156 | assert len(params.shape) == 2, params.shape # (batch_size, 2) 157 | assert len(random_states.shape) == 3, random_states.shape # (batch_size, num_pairs, obs_dim) 158 | batch_size = random_states.shape[0] 159 | 160 | # For simplex rewards, make sure to include the top 4 best points. 
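        # (Overwriting a few candidate states with the highest-value XY cells guarantees the encoder sees positive-reward examples for the sampled map.)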
161 | simplex_id = params[..., 0].astype(int) 162 | random_best_4 = np.random.randint(0, 10, size=(batch_size, 4)) 163 | random_states[:, :4, :2] = self.simplex_best_xy[simplex_id[:, None], random_best_4, :] 164 | 165 | reward_pair_rews = self.compute_reward(random_states, params[:, None, :])[..., None] 166 | reward_pairs = np.concatenate([random_states, reward_pair_rews], axis=-1) 167 | return reward_pairs # (batch_size, reward_pairs, obs_dim + 1) 168 | 169 | class TestRewMatrixEdges(TestRewMatrix): 170 | def __init__(self): 171 | super().__init__() 172 | self.pos[:3, :] = 1 173 | self.pos[-3:, :] = 1 174 | self.pos[:, :3] = 1 175 | self.pos[:, -3:] = 1 176 | 177 | class TestRewLoop(TestRewMatrix): 178 | def __init__(self): 179 | super().__init__() 180 | self.pos[22:33, 14:18] = 1 181 | self.xvel[22:33, 14:18] = -1 182 | 183 | self.pos[21:, 0:3] = 1 184 | self.xvel[21:, 0:3] = 1 185 | 186 | self.pos[33:, 3:18] = 1 187 | self.yvel[33:, 3:18] = 1 188 | 189 | self.pos[18:21, 0:7] = 1 190 | self.yvel[18:21, 0:7] = -1 191 | 192 | class TestRewPath(TestRewMatrix): 193 | def __init__(self): 194 | super().__init__() 195 | self.pos[3:21, 7:10] = 1 196 | self.xvel[3:21, 7:10] = -1 197 | 198 | self.pos[0:3, 3:10] = 1 199 | self.yvel[0:3, 3:10] = -1 200 | 201 | self.pos[0:18, 0:3] = 1 202 | self.xvel[0:18, 0:3] = 1 203 | 204 | class TestRewLoop2(TestRewMatrix): 205 | def __init__(self): 206 | super().__init__() 207 | self.pos[22:33, 14:18] = 1 208 | self.pos[21:, 0:3] = 1 209 | self.pos[33:, 3:18] = 1 210 | self.pos[18:21, 0:7] = 1 211 | 212 | class TestRewPath2(TestRewMatrix): 213 | def __init__(self): 214 | super().__init__() 215 | self.pos[3:21, 7:10] = 1 216 | self.pos[0:3, 3:10] = 1 217 | self.pos[0:18, 0:3] = 1 218 | 219 | 220 | # =================== For DMC 221 | 222 | class VelocityRewardFunctionWalker(RewardFunction): 223 | def __init__(self): 224 | pass 225 | 226 | def generate_params_and_pairs(self, traj_states, random_states, random_states_decode): 227 | batch_size = traj_states.shape[0] 228 | params = np.random.uniform(low=0, high=8, size=(batch_size, 1)) # (batch_size, 1) 229 | 230 | encode_pairs = np.concatenate([traj_states, random_states], axis=1) 231 | encode_rewards = self.compute_reward(encode_pairs, params[:, None, :])[:, :, None] 232 | encode_pairs = np.concatenate([encode_pairs, encode_rewards], axis=-1) 233 | 234 | decode_pairs = random_states_decode 235 | decode_rewards = self.compute_reward(decode_pairs, params[:, None, :])[:, :, None] 236 | decode_pairs = np.concatenate([random_states_decode, decode_rewards], axis=-1) 237 | 238 | rewards = encode_pairs[:, 0, -1] # (batch_size,) 239 | masks = np.ones_like(rewards) # (batch_size,) 240 | 241 | return params, encode_pairs, decode_pairs, rewards, masks 242 | 243 | def _sigmoids(self, x, value_at_1, sigmoid): 244 | if sigmoid == 'gaussian': 245 | scale = np.sqrt(-2 * np.log(value_at_1)) 246 | return np.exp(-0.5 * (x*scale)**2) 247 | 248 | elif sigmoid == 'linear': 249 | scale = 1-value_at_1 250 | scaled_x = x*scale 251 | return np.where(abs(scaled_x) < 1, 1 - scaled_x, 0.0) 252 | 253 | def tolerance(self, x, lower, upper, margin=0.0, sigmoid='gaussian', value_at_margin=0.1): 254 | in_bounds = np.logical_and(lower <= x, x <= upper) 255 | d = np.where(x < lower, lower - x, x - upper) / margin 256 | value = np.where(in_bounds, 1.0, self._sigmoids(d, value_at_margin, sigmoid)) 257 | return value 258 | 259 | def compute_reward(self, states, params): 260 | assert len(states.shape) == len(params.shape), states.shape # (batch_size, 
obs_dim) OR (batch_size, num_pairs, obs_dim) 261 | 262 | _STAND_HEIGHT = 1.2 263 | horizontal_velocity = states[..., 24:25] 264 | torso_upright = states[..., 25:26] 265 | torso_height = states[..., 26:27] 266 | standing = self.tolerance(torso_height, lower=_STAND_HEIGHT, upper=float('inf'), margin=_STAND_HEIGHT/2) 267 | upright = (1 + torso_upright) / 2 268 | stand_reward = (3*standing + upright) / 4 269 | move_reward = self.tolerance(horizontal_velocity, 270 | lower=params, 271 | upper=float('inf'), 272 | margin=params/2, 273 | value_at_margin=0.5, 274 | sigmoid='linear') 275 | # move_reward[params == 0] = stand_reward[params == 0] 276 | rew = stand_reward * (5*move_reward + 1) / 6 277 | return rew[..., 0] 278 | 279 | def make_encoder_pairs_testing(self, params, random_states): 280 | assert len(params.shape) == 2, params.shape # (batch_size, 2) 281 | assert len(random_states.shape) == 3, random_states.shape # (batch_size, num_pairs, obs_dim) 282 | 283 | reward_pair_rews = self.compute_reward(random_states, params[:, None, :])[..., None] 284 | reward_pairs = np.concatenate([random_states, reward_pair_rews], axis=-1) 285 | return reward_pairs # (batch_size, reward_pairs, obs_dim + 1) 286 | 287 | class VelocityRewardFunctionCheetah(RewardFunction): 288 | def __init__(self): 289 | pass 290 | 291 | def generate_params_and_pairs(self, traj_states, random_states, random_states_decode): 292 | batch_size = traj_states.shape[0] 293 | params = np.random.uniform(low=-10, high=10, size=(batch_size, 1)) # (batch_size, 1) 294 | 295 | encode_pairs = np.concatenate([traj_states, random_states], axis=1) 296 | encode_rewards = self.compute_reward(encode_pairs, params[:, None, :])[:, :, None] 297 | encode_pairs = np.concatenate([encode_pairs, encode_rewards], axis=-1) 298 | 299 | decode_pairs = random_states_decode 300 | decode_rewards = self.compute_reward(decode_pairs, params[:, None, :])[:, :, None] 301 | decode_pairs = np.concatenate([random_states_decode, decode_rewards], axis=-1) 302 | 303 | rewards = encode_pairs[:, 0, -1] # (batch_size,) 304 | masks = np.ones_like(rewards) # (batch_size,) 305 | 306 | return params, encode_pairs, decode_pairs, rewards, masks 307 | 308 | def _sigmoids(self, x, value_at_1, sigmoid): 309 | if sigmoid == 'linear': 310 | scale = 1-value_at_1 311 | scaled_x = x*scale 312 | return np.where(abs(scaled_x) < 1, 1 - scaled_x, 0.0) 313 | else: 314 | raise NotImplementedError 315 | 316 | def tolerance(self, x, lower, upper, margin=0.0, sigmoid='linear', value_at_margin=0): 317 | in_bounds = np.logical_and(lower <= x, x <= upper) 318 | d = np.where(x < lower, lower - x, x - upper) / margin 319 | value = np.where(in_bounds, 1.0, self._sigmoids(d, value_at_margin, sigmoid)) 320 | return value 321 | 322 | def compute_reward(self, states, params): 323 | assert len(states.shape) == len(params.shape), states.shape # (batch_size, obs_dim) OR (batch_size, num_pairs, obs_dim) 324 | 325 | horizontal_velocity = states[..., 17:18] 326 | sign_of_param = np.sign(params) 327 | horizontal_velocity = horizontal_velocity * sign_of_param 328 | rew = self.tolerance(horizontal_velocity, 329 | lower=np.abs(params), 330 | upper=float('inf'), 331 | margin=np.abs(params), 332 | value_at_margin=0, 333 | sigmoid='linear') 334 | return rew[..., 0] 335 | 336 | def make_encoder_pairs_testing(self, params, random_states): 337 | assert len(params.shape) == 2, params.shape # (batch_size, 2) 338 | assert len(random_states.shape) == 3, random_states.shape # (batch_size, num_pairs, obs_dim) 339 | 340 | 
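        # Label each candidate state with its reward under `params` and append it as the final feature.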
reward_pair_rews = self.compute_reward(random_states, params[:, None, :])[..., None] 341 | reward_pairs = np.concatenate([random_states, reward_pair_rews], axis=-1) 342 | return reward_pairs # (batch_size, reward_pairs, obs_dim + 1) 343 | 344 | # =================== For Kitchen 345 | 346 | class SingleTaskRewardFunction(RewardFunction): 347 | def __init__(self): 348 | self.obs_element_indices = { 349 | "bottom left burner": np.array([11, 12]), 350 | "top left burner": np.array([15, 16]), 351 | "light switch": np.array([17, 18]), 352 | "slide cabinet": np.array([19]), 353 | "hinge cabinet": np.array([20, 21]), 354 | "microwave": np.array([22]), 355 | "kettle": np.array([23, 24, 25, 26, 27, 28, 29]), 356 | } 357 | self.obs_element_goals = { 358 | "bottom left burner": np.array([-0.88, -0.01]), 359 | "top left burner": np.array([-0.92, -0.01]), 360 | "light switch": np.array([-0.69, -0.05]), 361 | "slide cabinet": np.array([0.37]), 362 | "hinge cabinet": np.array([0.0, 1.45]), 363 | "microwave": np.array([-0.75]), 364 | "kettle": np.array([-0.23, 0.75, 1.62, 0.99, 0.0, 0.0, -0.06]), 365 | } 366 | self.dist_thresh = 0.3 367 | self.num_tasks = len(self.obs_element_indices) 368 | 369 | def generate_params_and_pairs(self, traj_states, random_states, random_states_decode): 370 | batch_size = traj_states.shape[0] 371 | params = np.random.randint(self.num_tasks, size=(batch_size, 1)) # (batch_size, 1) 372 | params = np.eye(self.num_tasks)[params[:, 0]] # (batch_size, num_tasks) 373 | 374 | encode_pairs = np.concatenate([traj_states, random_states], axis=1) 375 | encode_rewards = self.compute_reward(encode_pairs, params[:, None, :])[:, :, None] 376 | encode_pairs = np.concatenate([encode_pairs, encode_rewards], axis=-1) 377 | 378 | decode_pairs = random_states_decode 379 | decode_rewards = self.compute_reward(decode_pairs, params[:, None, :])[:, :, None] 380 | decode_pairs = np.concatenate([random_states_decode, decode_rewards], axis=-1) 381 | 382 | rewards = encode_pairs[:, 0, -1] # (batch_size,) 383 | masks = np.ones_like(rewards) # (batch_size,) 384 | 385 | return params, encode_pairs, decode_pairs, rewards, masks 386 | 387 | def compute_reward(self, states, params): 388 | assert len(states.shape) == len(params.shape), states.shape # (batch_size, obs_dim) OR (batch_size, num_pairs, obs_dim) 389 | task_rewards = [] 390 | for task, target_indices in self.obs_element_indices.items(): 391 | task_dists = np.linalg.norm(states[..., target_indices] - self.obs_element_goals[task], axis=-1) 392 | task_completes = (task_dists < self.dist_thresh).astype(float) 393 | task_rewards.append(task_completes) 394 | task_rewards = np.stack(task_rewards, axis=-1) 395 | 396 | return np.sum(task_rewards * params, axis=-1) 397 | 398 | def make_encoder_pairs_testing(self, params, random_states): 399 | assert len(params.shape) == 2, params.shape # (batch_size, 2) 400 | assert len(random_states.shape) == 3, random_states.shape # (batch_size, num_pairs, obs_dim) 401 | 402 | reward_pair_rews = self.compute_reward(random_states, params[:, None, :])[..., None] 403 | reward_pairs = np.concatenate([random_states, reward_pair_rews], axis=-1) 404 | return reward_pairs # (batch_size, reward_pairs, obs_dim + 1) 405 | 406 | -------------------------------------------------------------------------------- /experiment/rewards_unsupervised.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tqdm 3 | import opensimplex 4 | import jax 5 | import jax.numpy as jnp 6 | from 
functools import partial 7 | 8 | class RewardFunction(): 9 | # Given a batch of trajectory states and random states, generate a reward function. 10 | # Return the reward params and the labelled encoder/decoder state-reward pairs (batch_size, num_pairs, obs_dim + 1), plus rewards and masks. 11 | def generate_params_and_pairs(self, traj_states, random_states, random_states_decode): 12 | raise NotImplementedError 13 | 14 | # Given a batch of states and a batch of parameters, compute the reward. 15 | def compute_reward(self, states, params): 16 | raise NotImplementedError 17 | 18 | class GoalReachingRewardFunction(RewardFunction): 19 | def __init__(self): 20 | self.p_current = 0.2 21 | self.p_trajectory = 0.5 22 | self.p_random = 0.3 23 | 24 | # TODO: If this is slow, we can try and JIT it. 25 | # Select a random goal from the provided states. 26 | def generate_params_and_pairs(self, traj_states, random_states, random_states_decode): 27 | all_states = np.concatenate([traj_states, random_states], axis=1) 28 | batch_size = all_states.shape[0] 29 | p_trajectory_normalized = self.p_trajectory / traj_states.shape[1] 30 | p_random_normalized = self.p_random / random_states.shape[1] 31 | probabilities = [self.p_current] + [p_trajectory_normalized] * (traj_states.shape[1]-1) \ 32 | + [p_random_normalized] * random_states.shape[1] 33 | probabilities = np.array(probabilities) / np.sum(probabilities) 34 | selected_goal_idx = np.random.choice(len(probabilities), size=(batch_size,), p=probabilities) 35 | selected_goal = all_states[np.arange(batch_size), selected_goal_idx] 36 | 37 | params = selected_goal # (batch_size, obs_dim) 38 | encode_pairs = np.concatenate([traj_states, random_states], axis=1) 39 | encode_rewards = self.compute_reward(encode_pairs, params[:, None, :])[:, :, None] 40 | encode_pairs = np.concatenate([encode_pairs, encode_rewards], axis=-1) 41 | 42 | decode_pairs = random_states_decode 43 | decode_pairs[:, 0] = params # Decode the goal state too.
44 | decode_rewards = self.compute_reward(decode_pairs, params[:, None, :])[:, :, None] 45 | decode_pairs = np.concatenate([random_states_decode, decode_rewards], axis=-1) 46 | 47 | rewards = encode_pairs[:, 0, -1] # (batch_size,) 48 | masks = -rewards # If (rew=-1, mask=1), else (rew=0, mask=0) 49 | 50 | return params, encode_pairs, decode_pairs, rewards, masks 51 | 52 | def compute_reward(self, states, params, delta=False): 53 | assert len(states.shape) == len(params.shape), states.shape # (batch_size, obs_dim) OR (batch_size, num_pairs, obs_dim) 54 | if states.shape[-1] == 29: # AntMaze 55 | if delta: 56 | dists = np.linalg.norm(states - params, axis=-1) 57 | is_goal = (dists < 0.1) 58 | else: 59 | dists = np.linalg.norm(states[..., :2] - params[..., :2], axis=-1) 60 | is_goal = (dists < 2) 61 | return -1 + is_goal.astype(float) # (batch_size,) 62 | elif states.shape[-1] == 18: # Cheetah 63 | std = np.array([[0.4407440506721877, 10.070289916801876, 0.5172332956856273, 0.5601041145815341, 0.518947027289748, 0.3204431592542281, 0.5501848643154092, 0.3856393812067661, 1.9882502334402663, 1.6377168569884073, 4.308505013609855, 12.144181770553105, 13.537567521831702, 16.88983033626308, 7.715009572436841, 14.345667964212357, 10.6904255152284, 100]]) 64 | assert len(states.shape) == len(params.shape), states.shape # (batch_size, obs_dim) OR (batch_size, num_pairs, obs_dim) 65 | # if len(states.shape) == 3: 66 | # breakpoint() 67 | dists_per_dim = states - params 68 | dists_per_dim = dists_per_dim / std 69 | dists = np.linalg.norm(dists_per_dim, axis=-1) / states.shape[-1] 70 | is_goal = (dists < 0.08) 71 | # print(dists_per_dim) 72 | # print(dists, is_goal) 73 | return -1 + is_goal.astype(float) # (batch_size,) 74 | elif states.shape[-1] == 27: # Walker 75 | std = np.array([[0.7212967364054736, 0.6775020895964047, 0.7638155887842976, 0.6395721376821286, 0.6849394775886244, 0.7078581708129903, 0.7113168519036742, 0.6753408522523937, 0.6818095329625652, 0.7133958718133511, 0.65227578338642, 0.757622576816855, 0.7311826446274479, 0.6745824928740024, 0.36822491550384456, 2.1134839667805805, 1.813353841099317, 10.594648894374815, 17.41041469033713, 17.836743227082106, 22.399097178637533, 16.1492222730888, 15.693574546557201, 18.539929326905067, 100, 100, 100]]) 76 | assert len(states.shape) == len(params.shape), states.shape # (batch_size, obs_dim) OR (batch_size, num_pairs, obs_dim) 77 | dists_per_dim = states - params 78 | dists_per_dim = dists_per_dim / std 79 | dists = np.linalg.norm(dists_per_dim, axis=-1) / states.shape[-1] 80 | is_goal = (dists < 0.2) 81 | return -1 + is_goal.astype(float) # (batch_size,) 82 | elif states.shape[-1] == 30: # Kitchen 83 | dists_per_dim = states - params 84 | dists_per_dim = dists_per_dim 85 | dists = np.linalg.norm(dists_per_dim, axis=-1) / states.shape[-1] 86 | is_goal = (dists < 1e-6) 87 | return -1 + is_goal.astype(float) 88 | else: 89 | raise NotImplementedError 90 | 91 | def make_encoder_pairs_testing(self, params, random_states): 92 | assert len(params.shape) == 2, params.shape # (batch_size, 2) 93 | assert len(random_states.shape) == 3, random_states.shape # (batch_size, num_pairs, obs_dim) 94 | 95 | if random_states.shape[-1] == 29: # AntMaze 96 | random_states[:, 0, :2] = params[:, :2] # Make sure to include the goal.
97 | else: 98 | random_states[:, 0] = params 99 | reward_pair_rews = self.compute_reward(random_states, params[:, None, :])[..., None] 100 | reward_pairs = np.concatenate([random_states, reward_pair_rews], axis=-1) 101 | return reward_pairs # (batch_size, reward_pairs, obs_dim + 1) 102 | 103 | class LinearRewardFunction(RewardFunction): 104 | def __init__(self): 105 | pass 106 | 107 | # Randomly generate a linear weighting over state features. 108 | def generate_params_and_pairs(self, traj_states, random_states, random_states_decode): 109 | assert len(traj_states.shape) == 3, traj_states.shape # (batch_size, traj_len, obs_dim) 110 | batch_size = traj_states.shape[0] 111 | state_len = traj_states.shape[-1] 112 | 113 | params = np.random.uniform(-1, 1, size=(batch_size, state_len)) # Uniform weighting. 114 | random_mask = np.random.uniform(size=(batch_size,state_len)) < 0.9 115 | if state_len == 29: 116 | random_mask[:, :2] = True # Zero out the XY position for antmaze. 117 | random_mask_positive = np.random.randint(2, state_len, size=(batch_size)) 118 | random_mask[np.arange(batch_size), random_mask_positive] = False # Keep at least one weight nonzero. 119 | params[random_mask] = 0 # Zero out some of the weights. 120 | # if state_len == 29: 121 | # params = params / np.linalg.norm(params, axis=-1, keepdims=True) # Normalize XY 122 | 123 | # Remove auxiliary features during training. 124 | if state_len == 27: 125 | params[:, -3:] = 0 126 | if state_len == 18: 127 | params[:, -1:] = 0 128 | 129 | clip_bit = np.random.uniform(size=(batch_size,)) < 0.5 130 | params = np.concatenate([params, clip_bit[:, None]], axis=-1) # (batch_size, obs_dim + 1) 131 | 132 | encode_pairs = np.concatenate([traj_states, random_states], axis=1) 133 | encode_rewards = self.compute_reward(encode_pairs, params[:, None, :])[:, :, None] 134 | encode_pairs = np.concatenate([encode_pairs, encode_rewards], axis=-1) 135 | 136 | decode_pairs = random_states_decode 137 | decode_rewards = self.compute_reward(decode_pairs, params[:, None, :])[:, :, None] 138 | decode_pairs = np.concatenate([random_states_decode, decode_rewards], axis=-1) 139 | 140 | rewards = encode_pairs[:, 0, -1] # (batch_size,) 141 | masks = np.ones_like(rewards) # (batch_size,) 142 | 143 | return params, encode_pairs, decode_pairs, rewards, masks 144 | 145 | def compute_reward(self, states, params): 146 | params_raw = params[..., :-1] 147 | assert len(states.shape) == len(params_raw.shape), states.shape # (batch_size, obs_dim) OR (batch_size, num_pairs, obs_dim) 148 | r = np.sum(states * params_raw, axis=-1) # (batch_size,) 149 | r = np.where(params[..., -1] > 0, np.clip(r, 0, 1), np.clip(r, -1, 1)) 150 | return r 151 | 152 | def make_encoder_pairs_testing(self, params, random_states): 153 | assert len(params.shape) == 2, params.shape # (batch_size, 2) 154 | assert len(random_states.shape) == 3, random_states.shape # (batch_size, num_pairs, obs_dim) 155 | 156 | reward_pair_rews = self.compute_reward(random_states, params[:, None, :])[..., None] 157 | reward_pairs = np.concatenate([random_states, reward_pair_rews], axis=-1) 158 | return reward_pairs # (batch_size, reward_pairs, obs_dim + 1) 159 | 160 | class RandomRewardFunction(RewardFunction): 161 | def __init__(self, num_simplex, obs_len=29): 162 | # Pre-compute parameter matrices.
163 | print("Generating parameter matrices...") 164 | self.simplex_size = num_simplex 165 | np_random = np.random.RandomState(0) 166 | self.param_w1 = np_random.normal(size=(self.simplex_size, obs_len, 32)) * np.sqrt(1/32) 167 | self.param_b1 = np_random.normal(size=(self.simplex_size, 1, 32)) * np.sqrt(16) 168 | self.param_w2 = np_random.normal(size=(self.simplex_size, 32, 1)) * np.sqrt(1/16) 169 | 170 | # Remove auxiliary features during training. 171 | if obs_len == 27: 172 | self.param_w1[:, -3:] = 0 173 | if obs_len == 18: 174 | self.param_w1[:, -1:] = 0 175 | 176 | # Random neural network. 177 | def generate_params_and_pairs(self, traj_states, random_states, random_states_decode): 178 | batch_size = traj_states.shape[0] 179 | params = np.random.randint(self.simplex_size, size=(batch_size, 1)) # (batch_size, 1) 180 | 181 | encode_pairs = np.concatenate([traj_states, random_states], axis=1) 182 | encode_rewards = self.compute_reward(encode_pairs, params[:, None, :])[:, :, None] 183 | encode_pairs = np.concatenate([encode_pairs, encode_rewards], axis=-1) 184 | 185 | decode_pairs = random_states_decode 186 | decode_rewards = self.compute_reward(decode_pairs, params[:, None, :])[:, :, None] 187 | decode_pairs = np.concatenate([random_states_decode, decode_rewards], axis=-1) 188 | 189 | rewards = encode_pairs[:, 0, -1] # (batch_size,) 190 | masks = np.ones_like(rewards) # (batch_size,) 191 | 192 | return params, encode_pairs, decode_pairs, rewards, masks 193 | 194 | def compute_reward(self, states, params): 195 | assert len(states.shape) == len(params.shape), states.shape # (batch_size, obs_dim) OR (batch_size, num_pairs, obs_dim) 196 | 197 | param_id = params[..., 0].astype(int) 198 | param1_w = self.param_w1[param_id] 199 | param1_b = self.param_b1[param_id] 200 | param2_w = self.param_w2[param_id] 201 | 202 | obs = states 203 | x = np.expand_dims(obs, -2) # [batch, (pairs), 1, features_in] 204 | x = np.matmul(x, param1_w) # [batch, (pairs), 1, features_out] 205 | x = x + param1_b 206 | x = np.tanh(x) 207 | x = np.matmul(x, param2_w) # [batch, (pairs), 1, 1] 208 | x = x.squeeze(-1).squeeze(-1) # [batch, (pairs)] 209 | x = np.clip(x, -1, 1) 210 | return x 211 | 212 | def make_encoder_pairs_testing(self, params, random_states): 213 | assert len(params.shape) == 2, params.shape # (batch_size, 2) 214 | assert len(random_states.shape) == 3, random_states.shape # (batch_size, num_pairs, obs_dim) 215 | batch_size = random_states.shape[0] 216 | 217 | reward_pair_rews = self.compute_reward(random_states, params[:, None, :])[..., None] 218 | reward_pairs = np.concatenate([random_states, reward_pair_rews], axis=-1) 219 | return reward_pairs # (batch_size, reward_pairs, obs_dim + 1) 220 | --------------------------------------------------------------------------------
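Each reward family in `rewards_unsupervised.py` follows the same interface: `generate_params_and_pairs` samples reward parameters for a batch, labels the concatenated trajectory/random states to form encoder pairs, labels a separate batch of states to form decoder pairs, and returns `(params, encode_pairs, decode_pairs, rewards, masks)`. The sketch below exercises `GoalReachingRewardFunction` with dummy data; it is illustrative only — the direct import and the Gaussian placeholder states are assumptions, while the 29-dimensional AntMaze observation size comes from `compute_reward` above.
```
# Minimal usage sketch (not from the repo). Assumes it is run from experiment/
# so that rewards_unsupervised.py imports directly; states are random placeholders.
import numpy as np
from rewards_unsupervised import GoalReachingRewardFunction

batch_size, traj_len, num_random, obs_dim = 4, 8, 8, 29  # 29 = AntMaze obs dim
traj_states = np.random.randn(batch_size, traj_len, obs_dim)
random_states = np.random.randn(batch_size, num_random, obs_dim)
random_states_decode = np.random.randn(batch_size, num_random, obs_dim)

fn = GoalReachingRewardFunction()
params, enc, dec, rew, masks = fn.generate_params_and_pairs(
    traj_states, random_states, random_states_decode)

print(params.shape)  # (4, 29): sampled goal state per batch element
print(enc.shape)     # (4, 16, 30): [state, reward] pairs fed to the encoder
print(dec.shape)     # (4, 8, 30):  [state, reward] pairs used as decoder targets
print(rew.shape)     # (4,): -1/0 goal-reaching reward of the first encoder state
print(masks.shape)   # (4,): 1 while the goal is not yet reached, 0 at the goal
```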
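The hand-designed evaluation rewards in the first file shown above (likely `experiment/rewards_eval.py`, given the repo tree) rely on dm_control-style `tolerance` shaping: for `VelocityRewardFunctionCheetah`, the reward is 1 once the sign-corrected velocity reaches the target speed and decays linearly to 0 over a margin equal to that target. A small worked check follows; the module name `rewards_eval` is an assumption, while the velocity index 17 and the 18-dimensional Cheetah observation are taken from `compute_reward` above.
```
# Worked sketch (not from the repo) of the linear tolerance shaping.
import numpy as np
from rewards_eval import VelocityRewardFunctionCheetah  # assumed module name

fn = VelocityRewardFunctionCheetah()
target = np.array([[3.0]])        # params: desired forward velocity, shape (1, 1)

for v in [0.0, 1.5, 3.0, 6.0, -2.0]:
    state = np.zeros((1, 18))     # 18-dim Cheetah observation, velocity at index 17
    state[0, 17] = v
    r = fn.compute_reward(state, target)[0]
    print(f"velocity={v:+.1f} -> reward={r:.2f}")
# Expected: 0.00, 0.50, 1.00, 1.00, 0.00 -- the reward ramps linearly from 0 at rest
# to 1 at the target speed and stays at 1 beyond it; motion against the target sign gives 0.
```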