├── .gitignore
├── README.md
├── algs
│   ├── bc.py
│   └── iql.py
├── common
│   ├── dataset.py
│   ├── envs
│   │   ├── bandit
│   │   │   └── bandit.py
│   │   ├── d4rl
│   │   │   ├── antmaze_actions.npy
│   │   │   ├── d4rl_ant.py
│   │   │   └── d4rl_utils.py
│   │   ├── data_transforms.py
│   │   ├── dmc
│   │   │   ├── __init__.py
│   │   │   ├── jaco.py
│   │   │   └── wrappers.py
│   │   ├── env_helper.py
│   │   ├── exorl
│   │   │   ├── custom_dmc_tasks
│   │   │   │   ├── __init__.py
│   │   │   │   ├── cheetah.py
│   │   │   │   ├── cheetah.xml
│   │   │   │   ├── hopper.py
│   │   │   │   ├── hopper.xml
│   │   │   │   ├── jaco.py
│   │   │   │   ├── quadruped.py
│   │   │   │   ├── quadruped.xml
│   │   │   │   ├── walker.py
│   │   │   │   └── walker.xml
│   │   │   ├── dmc.py
│   │   │   └── exorl_utils.py
│   │   ├── gc_utils.py
│   │   └── wrappers.py
│   ├── evaluation.py
│   ├── networks
│   │   ├── basic.py
│   │   └── transformer.py
│   ├── train_state.py
│   ├── typing.py
│   ├── utils.py
│   └── wandb.py
├── deps
│   ├── base_container.def
│   ├── environment.yml
│   └── requirements.txt
└── experiment
    ├── ant_helper.py
    ├── rewards_eval.py
    ├── rewards_unsupervised.py
    └── run_fre.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | .idea/
161 | *.sh
162 | make_sbatch/
163 | experiment_output/
164 | ipynbs/
165 | .DS_Store
166 | data/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Unsupervised Zero-Shot RL via Functional Reward Representations
2 | Code for "Unsupervised Zero-Shot RL via Functional Reward Representations"
3 |
4 | Kevin Frans, Seohong Park, Pieter Abbeel, Sergey Levine
5 |
6 | [Link to Paper](https://arxiv.org/abs/2402.17135)
7 |
8 | ### Abstract
9 | Can we pre-train a generalist agent from a large amount of unlabeled offline trajectories such that it can be immediately adapted to any new downstream task in a zero-shot manner?
10 | In this work, we present a *functional* reward encoding (FRE) as a general, scalable solution to this *zero-shot RL* problem.
11 | Our main idea is to learn functional representations of arbitrary tasks by encoding their state-reward samples using a transformer-based variational auto-encoder.
12 | This functional encoding not only enables the pre-training of an agent from a wide diversity of general unsupervised reward functions, but also provides a way to solve any new downstream task in a zero-shot manner, given a small number of reward-annotated samples.
13 | We empirically show that FRE agents trained on diverse random unsupervised reward functions can generalize to solve novel tasks in a range of simulated robotic benchmarks, often outperforming previous zero-shot RL and offline RL methods.
14 |
15 | ### Code Instructions
16 | First install the dependencies in the `deps` folder.
17 | ```
18 | cd deps
19 | conda env create -f environment.yml
20 | ```
21 |
22 | For the ExORL experiments, you will first need to download the data using [these instructions](https://github.com/denisyarats/exorl).
23 | Then, download the [auxiliary offline data](https://drive.google.com/drive/folders/1HDkCP6-eLKuyQRPcyO3ei-vhubMRlGct?usp=sharing) and place it in the `data/` folder.
24 |
25 | To run the code for the experiments, use the following commands.
26 |
27 | ```
28 | # AntMaze
29 | python experiment/run_fre.py --env_name antmaze-large-diverse-v2
30 | # ExORL
31 | python experiment/run_fre.py --env_name dmc_walker_walk --agent.warmup_steps 1000000 --max_steps 2000000
32 | python experiment/run_fre.py --env_name dmc_cheetah_run --agent.warmup_steps 1000000 --max_steps 2000000
33 | # Kitchen
34 | python experiment/run_fre.py --env_name kitchen-mixed-v0 --agent.warmup_steps 1000000 --max_steps 2000000
35 | ```
36 |
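37 | The repository also includes standalone offline RL baselines under `algs/`. A minimal invocation sketch (flag names are taken from `algs/bc.py` and `algs/iql.py`; it assumes the `fre` package is importable, and the environment names are only examples):
38 | ```
39 | # IQL baseline
40 | python algs/iql.py --env_name antmaze-large-diverse-v2
41 | # Goal-conditioned BC baseline
42 | python algs/bc.py --env_name gc-antmaze-large-diverse-v2 --goal_conditioned 1
43 | ```
44 | 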
--------------------------------------------------------------------------------
/algs/bc.py:
--------------------------------------------------------------------------------
1 | import os
2 | from absl import app, flags
3 | from functools import partial
4 | import numpy as np
5 | import tqdm
6 | import jax
7 | import jax.numpy as jnp
8 | import optax
9 | import flax
10 | import wandb
11 | from ml_collections import config_flags
12 | import pickle
13 | from flax.training import checkpoints
14 | import ml_collections
15 |
16 | import fre.common.envs.d4rl.d4rl_utils as d4rl_utils
17 | from fre.common.envs.gc_utils import GCDataset
18 | from fre.common.envs.env_helper import make_env
19 | from fre.common.wandb import setup_wandb, default_wandb_config, get_flag_dict
20 | from fre.common.evaluation import evaluate
21 | from fre.common.utils import supply_rng
22 | from fre.common.typing import *
23 | from fre.common.train_state import TrainState, target_update
24 | from fre.common.networks.basic import Policy, ValueCritic, Critic, ensemblize
25 |
26 |
27 | ###############################
28 | # Configs
29 | ###############################
30 |
31 |
32 | FLAGS = flags.FLAGS
33 | flags.DEFINE_string('env_name', 'gc-antmaze-large-diverse-v2', 'Environment name.')
34 | flags.DEFINE_string('name', 'default', '')
35 |
36 | flags.DEFINE_string('save_dir', None, 'Logging dir (if not None, save params).')
37 |
38 | flags.DEFINE_integer('seed', np.random.choice(1000000), 'Random seed.')
39 | flags.DEFINE_integer('eval_episodes', 20,
40 | 'Number of episodes used for evaluation.')
41 | flags.DEFINE_integer('log_interval', 1000, 'Logging interval.')
42 | flags.DEFINE_integer('eval_interval', 50000, 'Eval interval.')
43 | flags.DEFINE_integer('save_interval', 250000, 'Save interval.')
44 | flags.DEFINE_integer('video_interval', 50000, 'Video interval.')
45 | flags.DEFINE_integer('batch_size', 1024, 'Mini batch size.')
46 | flags.DEFINE_integer('max_steps', int(1e6), 'Number of training steps.')
47 | flags.DEFINE_integer('goal_conditioned', 0, 'Whether to use goal conditioned relabelling or not.')
48 |
49 | # These variables are passed to the BCAgent class.
50 | agent_config = ml_collections.ConfigDict({
51 | 'actor_lr': 3e-4,
52 | 'hidden_dims': (512, 512, 512),
53 | 'opt_decay_schedule': 'none',
54 | 'use_tanh': 0,
55 | 'state_dependent_std': 0,
56 | 'use_layer_norm': 1,
57 | })
58 |
59 | wandb_config = default_wandb_config()
60 | wandb_config.update({
61 | 'project': 'mujoco_rlalgs',
62 | 'name': 'bc_{env_name}',
63 | })
64 |
65 |
66 | config_flags.DEFINE_config_dict('wandb', wandb_config, lock_config=False)
67 | config_flags.DEFINE_config_dict('agent', agent_config, lock_config=False)
68 | config_flags.DEFINE_config_dict('gcdataset', GCDataset.get_default_config(), lock_config=False)
69 |
70 | ###############################
71 | # Agent. Contains the neural networks, training logic, and sampling.
72 | ###############################
73 |
74 | class BCAgent(flax.struct.PyTreeNode):
75 | rng: PRNGKey
76 | actor: TrainState
77 | config: dict = flax.struct.field(pytree_node=False)
78 |
79 | @jax.jit
80 | def update(agent, batch: Batch) -> InfoDict:
81 | observations = batch['observations']
82 | actions = batch['actions']
83 |
84 | def actor_loss_fn(actor_params):
85 | dist = agent.actor(observations, params=actor_params)
86 | log_probs = dist.log_prob(actions)
87 | actor_loss = -(log_probs).mean()
88 |
89 | mse_error = jnp.square(dist.loc - actions).mean()
90 |
91 | return actor_loss, {
92 | 'actor_loss': actor_loss,
93 | 'action_std': dist.stddev().mean(),
94 | 'mse_error': mse_error,
95 | }
96 |
97 | new_actor, actor_info = agent.actor.apply_loss_fn(loss_fn=actor_loss_fn, has_aux=True)
98 |
99 | return agent.replace(actor=new_actor), {
100 | **actor_info
101 | }
102 |
103 | @jax.jit
104 | def sample_actions(agent,
105 | observations: np.ndarray,
106 | *,
107 | seed: PRNGKey,
108 | temperature: float = 1.0) -> jnp.ndarray:
109 | if type(observations) is dict:
110 | observations = jnp.concatenate([observations['observation'], observations['goal']], axis=-1)
111 | actions = agent.actor(observations, temperature=temperature).sample(seed=seed)
112 | actions = jnp.clip(actions, -1, 1)
113 | return actions
114 |
115 | def create_agent(
116 | seed: int,
117 | observations: jnp.ndarray,
118 | actions: jnp.ndarray,
119 | actor_lr: float,
120 | use_tanh: bool,
121 | state_dependent_std: bool,
122 | use_layer_norm: bool,
123 | hidden_dims: Sequence[int],
124 | opt_decay_schedule: str,
125 | max_steps: Optional[int] = None,
126 | **kwargs):
127 |
128 | print('Extra kwargs:', kwargs)
129 |
130 | rng = jax.random.PRNGKey(seed)
131 | rng, actor_key, critic_key, value_key = jax.random.split(rng, 4)
132 |
133 | action_dim = actions.shape[-1]
134 | actor_def = Policy(hidden_dims, action_dim=action_dim,
135 | log_std_min=-5.0, state_dependent_std=state_dependent_std, tanh_squash_distribution=use_tanh, mlp_kwargs=dict(use_layer_norm=use_layer_norm))
136 |
137 | if opt_decay_schedule == "cosine":
138 | schedule_fn = optax.cosine_decay_schedule(-actor_lr, max_steps)
139 | actor_tx = optax.chain(optax.scale_by_adam(),
140 | optax.scale_by_schedule(schedule_fn))
141 | else:
142 | actor_tx = optax.adam(learning_rate=actor_lr)
143 |
144 | actor_params = actor_def.init(actor_key, observations)['params']
145 | actor = TrainState.create(actor_def, actor_params, tx=actor_tx)
146 |
147 | config = flax.core.FrozenDict(dict(
148 | actor_lr=actor_lr,
149 | ))
150 |
151 | return BCAgent(rng, actor=actor, config=config)
152 |
153 | ###############################
154 | # Run Script. Loads data, logs to wandb, and runs the training loop.
155 | ###############################
156 |
157 | def main(_):
158 | if FLAGS.goal_conditioned:
159 | assert 'gc' in FLAGS.env_name
160 | else:
161 | assert 'gc' not in FLAGS.env_name
162 |
163 | # Create wandb logger
164 | setup_wandb(FLAGS.agent.to_dict(), **FLAGS.wandb)
165 |
166 | env = make_env(FLAGS.env_name)
167 | eval_env = make_env(FLAGS.env_name)
168 |
169 | dataset = d4rl_utils.get_dataset(env, FLAGS.env_name)
170 | dataset = d4rl_utils.normalize_dataset(FLAGS.env_name, dataset)
171 | if FLAGS.goal_conditioned:
172 | dataset = GCDataset(dataset, **FLAGS.gcdataset.to_dict())
173 | example_batch = dataset.sample(1)
174 | example_obs = np.concatenate([example_batch['observations'], example_batch['goals']], axis=-1)
175 | debug_batch = dataset.sample(100)
176 | print("Masks Look Like", debug_batch['masks'])
177 | print("Rewards Look Like", debug_batch['rewards'])
178 | else:
179 |         example_batch = dataset.sample(1)
180 |         example_obs = example_batch['observations']
181 | agent = create_agent(FLAGS.seed,
182 | example_obs,
183 | example_batch['actions'],
184 | max_steps=FLAGS.max_steps,
185 | **FLAGS.agent)
186 |
187 | for i in tqdm.tqdm(range(1, FLAGS.max_steps + 1),
188 | smoothing=0.1,
189 | dynamic_ncols=True):
190 |
191 | batch = dataset.sample(FLAGS.batch_size)
192 | if FLAGS.goal_conditioned:
193 | batch['observations'] = np.concatenate([batch['observations'], batch['goals']], axis=-1)
194 | batch['next_observations'] = np.concatenate([batch['next_observations'], batch['goals']], axis=-1)
195 |
196 | agent, update_info = agent.update(batch)
197 |
198 | if i % FLAGS.log_interval == 0:
199 | train_metrics = {f'training/{k}': v for k, v in update_info.items()}
200 | wandb.log(train_metrics, step=i)
201 |
202 | if i % FLAGS.eval_interval == 0:
203 | record_video = i % FLAGS.video_interval == 0
204 | policy_fn = partial(supply_rng(agent.sample_actions), temperature=0.0)
205 | eval_info, trajs = evaluate(policy_fn, eval_env, num_episodes=FLAGS.eval_episodes, record_video=record_video, return_trajectories=True)
206 | eval_metrics = {}
207 | for k in ['episode.return', 'episode.length']:
208 | eval_metrics[f'evaluation/{k}'] = eval_info[k]
209 | print(f'evaluation/{k}: {eval_info[k]}')
210 | eval_metrics['evaluation/episode.return.normalized'] = eval_env.get_normalized_score(eval_info['episode.return'])
211 | print(f'evaluation/episode.return.normalized: {eval_metrics["evaluation/episode.return.normalized"]}')
212 | if record_video:
213 | wandb.log({'video': eval_info['video']}, step=i)
214 |
215 | # Antmaze Specific Logging
216 | if 'antmaze-large' in FLAGS.env_name or 'maze2d-large' in FLAGS.env_name:
217 | import fre.common.envs.d4rl.d4rl_ant as d4rl_ant
218 | # Make an image of the trajectories.
219 | traj_image = d4rl_ant.trajectory_image(eval_env, trajs)
220 | eval_metrics['trajectories'] = wandb.Image(traj_image)
221 |
222 | wandb.log(eval_metrics, step=i)
223 |
224 | if __name__ == '__main__':
225 | app.run(main)
--------------------------------------------------------------------------------
/algs/iql.py:
--------------------------------------------------------------------------------
1 | import os
2 | from absl import app, flags
3 | from functools import partial
4 | import numpy as np
5 | import tqdm
6 | import jax
7 | import jax.numpy as jnp
8 | import flax
9 | import optax
10 | import wandb
11 | from ml_collections import config_flags
12 | import pickle
13 | from flax.training import checkpoints
14 | import ml_collections
15 |
16 | from fre.common.envs.gc_utils import GCDataset
17 | from fre.common.envs.data_transforms import ActionDiscretizeCluster, ActionDiscretizeBins, ActionTransform
18 | from fre.common.envs.env_helper import make_env, get_dataset
19 | from fre.common.wandb import setup_wandb, default_wandb_config, get_flag_dict
20 | from fre.common.evaluation import evaluate
21 | from fre.common.utils import supply_rng
22 | from fre.common.typing import *
23 | from fre.common.train_state import TrainState, target_update
24 | from fre.common.networks.basic import Policy, ValueCritic, Critic, ensemblize
25 |
26 |
27 | ###############################
28 | # Configs
29 | ###############################
30 |
31 |
32 | FLAGS = flags.FLAGS
33 | flags.DEFINE_string('env_name', 'gc-antmaze-large-diverse-v2', 'Environment name.')
34 | flags.DEFINE_string('name', 'default', '')
35 |
36 | flags.DEFINE_string('save_dir', None, 'Logging dir (if not None, save params).')
37 |
38 | flags.DEFINE_integer('seed', np.random.choice(1000000), 'Random seed.')
39 | flags.DEFINE_integer('eval_episodes', 20,
40 | 'Number of episodes used for evaluation.')
41 | flags.DEFINE_integer('log_interval', 1000, 'Logging interval.')
42 | flags.DEFINE_integer('eval_interval', 50000, 'Eval interval.')
43 | flags.DEFINE_integer('save_interval', 250000, 'Save interval.')
44 | flags.DEFINE_integer('video_interval', 250000, 'Video interval.')
45 | flags.DEFINE_integer('batch_size', 1024, 'Mini batch size.')
46 | flags.DEFINE_integer('max_steps', int(1e6), 'Number of training steps.')
47 | flags.DEFINE_integer('goal_conditioned', 0, 'Whether to use goal conditioned relabelling or not.')
48 |
49 | # These variables are passed to the IQLAgent class.
50 | agent_config = ml_collections.ConfigDict({
51 | 'actor_lr': 3e-4,
52 | 'value_lr': 3e-4,
53 | 'critic_lr': 3e-4,
54 | 'num_qs': 2,
55 | 'actor_hidden_dims': (512, 512, 512),
56 | 'hidden_dims': (512, 512, 512),
57 | 'discount': 0.99,
58 | 'expectile': 0.9,
59 | 'temperature': 3.0, # 0 for behavior cloning.
60 | 'dropout_rate': 0,
61 | 'use_tanh': 0,
62 | 'state_dependent_std': 0,
63 | 'use_layer_norm': 1,
64 | 'tau': 0.005,
65 | 'opt_decay_schedule': 'none',
66 | 'action_transform_type': 'none',
67 | 'actor_loss_type': 'awr', # or ddpg
68 | 'bc_weight': 0.0, # for ddpg
69 | })
70 |
71 | wandb_config = default_wandb_config()
72 | wandb_config.update({
73 | 'name': 'iql_{env_name}',
74 | })
75 |
76 | config_flags.DEFINE_config_dict('wandb', wandb_config, lock_config=False)
77 | config_flags.DEFINE_config_dict('agent', agent_config, lock_config=False)
78 | config_flags.DEFINE_config_dict('gcdataset', GCDataset.get_default_config(), lock_config=False)
79 |
80 | ###############################
81 | # Agent. Contains the neural networks, training logic, and sampling.
82 | ###############################
83 |
84 | def expectile_loss(diff, expectile=0.8):
85 | weight = jnp.where(diff > 0, expectile, (1 - expectile))
86 | return weight * (diff**2)
87 |
88 | class IQLAgent(flax.struct.PyTreeNode):
89 | rng: PRNGKey
90 | critic: TrainState
91 | target_critic: TrainState
92 | value: TrainState
93 | actor: TrainState
94 | config: dict = flax.struct.field(pytree_node=False)
95 |
96 | @jax.jit
97 | def update(agent, batch: Batch) -> InfoDict:
98 | def critic_loss_fn(critic_params):
99 | next_v = agent.value(batch['next_observations'])
100 | target_q = batch['rewards'] + agent.config['discount'] * batch['masks'] * next_v
101 | qs = agent.critic(batch['observations'], batch['actions'], params=critic_params) # [num_q, batch]
102 | critic_loss = ((qs - target_q[None])**2).mean()
103 | return critic_loss, {
104 | 'critic_loss': critic_loss,
105 | 'q1': qs[0].mean(),
106 | }
107 |
108 | def value_loss_fn(value_params):
109 | qs = agent.target_critic(batch['observations'], batch['actions'])
110 | q = jnp.min(qs, axis=0) # Min over ensemble.
111 | v = agent.value(batch['observations'], params=value_params)
112 | value_loss = expectile_loss(q-v, agent.config['expectile']).mean()
113 | return value_loss, {
114 | 'value_loss': value_loss,
115 | 'v': v.mean(),
116 | }
117 |
118 | def actor_loss_fn(actor_params):
119 | if agent.config['actor_loss_type'] == 'awr':
120 | v = agent.value(batch['observations'])
121 | qs = agent.critic(batch['observations'], batch['actions'])
122 | q = jnp.min(qs, axis=0) # Min over ensemble.
123 | exp_a = jnp.exp((q - v) * agent.config['temperature'])
124 | exp_a = jnp.minimum(exp_a, 100.0)
125 |
126 | actions = batch['actions']
127 | if agent.config['action_transform_type'] == "cluster":
128 | actions = agent.config['action_transform'].action_to_ids(actions)
129 | dist = agent.actor(batch['observations'], params=actor_params)
130 | log_probs = dist.log_prob(actions)
131 | actor_loss = -(exp_a * log_probs).mean()
132 |
133 | action_std = dist.stddev().mean() if agent.config['action_transform_type'] == 'none' else 0
134 | return actor_loss, {
135 | 'actor_loss': actor_loss,
136 | 'action_std': action_std,
137 | 'adv': (q - v).mean(),
138 | 'adv_min': (q - v).min(),
139 | 'adv_max': (q - v).max(),
140 | }
141 | elif agent.config['actor_loss_type'] == 'ddpg':
142 | dist = agent.actor(batch['observations'], params=actor_params)
143 | normalized_actions = jnp.tanh(dist.loc)
144 | qs = agent.critic(batch['observations'], normalized_actions)
145 | q = jnp.min(qs, axis=0) # Min over ensemble.
146 |
147 | q_loss = -q.mean()
148 |
149 | log_probs = dist.log_prob(batch['actions'])
150 |                 bc_loss = -(agent.config['bc_weight'] * log_probs).mean()  # BC regularization term, scaled by 'bc_weight'
151 |
152 | actor_loss = (q_loss + bc_loss).mean()
153 | return actor_loss, {
154 | 'bc_loss': bc_loss,
155 | 'agent_q': q.mean(),
156 | }
157 | else:
158 | raise NotImplementedError
159 |
160 | new_critic, critic_info = agent.critic.apply_loss_fn(loss_fn=critic_loss_fn, has_aux=True)
161 | new_target_critic = target_update(agent.critic, agent.target_critic, agent.config['target_update_rate'])
162 | new_value, value_info = agent.value.apply_loss_fn(loss_fn=value_loss_fn, has_aux=True)
163 | new_actor, actor_info = agent.actor.apply_loss_fn(loss_fn=actor_loss_fn, has_aux=True)
164 |
165 | return agent.replace(critic=new_critic, target_critic=new_target_critic, value=new_value, actor=new_actor), {
166 | **critic_info, **value_info, **actor_info
167 | }
168 |
169 | @jax.jit
170 | def sample_actions(agent,
171 | observations: np.ndarray,
172 | *,
173 | seed: PRNGKey,
174 | temperature: float = 1.0) -> jnp.ndarray:
175 | if type(observations) is dict:
176 | observations = jnp.concatenate([observations['observation'], observations['goal']], axis=-1)
177 | actions = agent.actor(observations, temperature=temperature).sample(seed=seed)
178 | if agent.config['action_transform_type'] == "cluster":
179 | actions = agent.config['action_transform'].ids_to_action(actions)
180 | if agent.config['actor_loss_type'] == 'ddpg':
181 | actions = jnp.tanh(actions)
182 | actions = jnp.clip(actions, -1, 1)
183 | return actions
184 |
185 | # Initializes all the networks, etc. for the agent.
186 | def create_agent(
187 | seed: int,
188 | observations: jnp.ndarray,
189 | actions: jnp.ndarray,
190 | actor_lr: float,
191 | value_lr: float,
192 | critic_lr: float,
193 | hidden_dims: Sequence[int],
194 | actor_hidden_dims: Sequence[int],
195 | discount: float,
196 | tau: float,
197 | expectile: float,
198 | temperature: float,
199 | use_tanh: bool,
200 | state_dependent_std: bool,
201 | use_layer_norm: bool,
202 | opt_decay_schedule: str,
203 | actor_loss_type: str,
204 | bc_weight: float,
205 | num_qs: int,
206 | action_transform_type: str,
207 | action_transform: ActionTransform = None,
208 | max_steps: Optional[int] = None,
209 | **kwargs):
210 |
211 | print('Extra kwargs:', kwargs)
212 |
213 | rng = jax.random.PRNGKey(seed)
214 | rng, actor_key, critic_key, value_key = jax.random.split(rng, 4)
215 |
216 | action_dim = actions.shape[-1]
217 | if action_transform_type == "cluster":
218 | num_clusters = action_transform.num_clusters
219 | actor_def = Policy(actor_hidden_dims, action_dim=num_clusters, is_discrete=True, mlp_kwargs=dict(use_layer_norm=use_layer_norm))
220 | else:
221 | actor_def = Policy(actor_hidden_dims, action_dim=action_dim,
222 | log_std_min=-5.0, state_dependent_std=state_dependent_std, tanh_squash_distribution=use_tanh, mlp_kwargs=dict(use_layer_norm=use_layer_norm))
223 |
224 | if opt_decay_schedule == "cosine":
225 | schedule_fn = optax.cosine_decay_schedule(-actor_lr, max_steps)
226 | actor_tx = optax.chain(optax.scale_by_adam(),
227 | optax.scale_by_schedule(schedule_fn))
228 | else:
229 | actor_tx = optax.adam(learning_rate=actor_lr)
230 |
231 | actor_params = actor_def.init(actor_key, observations)['params']
232 | actor = TrainState.create(actor_def, actor_params, tx=actor_tx)
233 |
234 | critic_def = ensemblize(Critic, num_qs=num_qs)(hidden_dims, mlp_kwargs=dict(use_layer_norm=use_layer_norm))
235 | critic_params = critic_def.init(critic_key, observations, actions)['params']
236 | critic = TrainState.create(critic_def, critic_params, tx=optax.adam(learning_rate=critic_lr))
237 | target_critic = TrainState.create(critic_def, critic_params)
238 |
239 | value_def = ValueCritic(hidden_dims, mlp_kwargs=dict(use_layer_norm=use_layer_norm))
240 | value_params = value_def.init(value_key, observations)['params']
241 | value = TrainState.create(value_def, value_params, tx=optax.adam(learning_rate=value_lr))
242 |
243 | config = flax.core.FrozenDict(dict(
244 | discount=discount, temperature=temperature, expectile=expectile, target_update_rate=tau, action_transform_type=action_transform_type,
245 | action_transform=action_transform, actor_loss_type=actor_loss_type, bc_weight=bc_weight
246 | ))
247 |
248 | return IQLAgent(rng, critic=critic, target_critic=target_critic, value=value, actor=actor, config=config)
249 |
250 |
251 | ###############################
252 | # Run Script. Loads data, logs to wandb, and runs the training loop.
253 | ###############################
254 |
255 |
256 | def main(_):
257 | if FLAGS.goal_conditioned:
258 | assert 'gc' in FLAGS.env_name
259 | else:
260 | assert 'gc' not in FLAGS.env_name
261 |
262 | np.random.seed(FLAGS.seed)
263 |
264 | # Create wandb logger
265 | setup_wandb(FLAGS.agent.to_dict(), **FLAGS.wandb)
266 |
267 | env = make_env(FLAGS.env_name)
268 | eval_env = make_env(FLAGS.env_name)
269 |
270 | dataset = get_dataset(env, FLAGS.env_name)
271 | if FLAGS.goal_conditioned:
272 | dataset = GCDataset(dataset, **FLAGS.gcdataset.to_dict())
273 | example_batch = dataset.sample(1)
274 | example_obs = np.concatenate([example_batch['observations'], example_batch['goals']], axis=-1)
275 | debug_batch = dataset.sample(100)
276 | print("Masks Look Like", debug_batch['masks'])
277 | print("Rewards Look Like", debug_batch['rewards'])
278 | else:
279 | example_obs = dataset.sample(1)['observations']
280 | example_batch = dataset.sample(1)
281 |
282 | if FLAGS.agent['action_transform_type'] == 'cluster':
283 | if FLAGS.goal_conditioned:
284 | action_transform = ActionDiscretizeCluster(1024, dataset.dataset["actions"][::100])
285 | else:
286 | action_transform = ActionDiscretizeCluster(1024, dataset["actions"][::100])
287 | else:
288 | action_transform = None
289 |
290 | agent = create_agent(FLAGS.seed,
291 | example_obs,
292 | example_batch['actions'],
293 | max_steps=FLAGS.max_steps,
294 | action_transform=action_transform,
295 | **FLAGS.agent)
296 |
297 | for i in tqdm.tqdm(range(1, FLAGS.max_steps + 1),
298 | smoothing=0.1,
299 | dynamic_ncols=True):
300 |
301 | batch = dataset.sample(FLAGS.batch_size)
302 | if FLAGS.goal_conditioned:
303 | batch['observations'] = np.concatenate([batch['observations'], batch['goals']], axis=-1)
304 | batch['next_observations'] = np.concatenate([batch['next_observations'], batch['goals']], axis=-1)
305 |
306 | agent, update_info = agent.update(batch)
307 |
308 | if i % FLAGS.log_interval == 0:
309 | train_metrics = {f'training/{k}': v for k, v in update_info.items()}
310 | wandb.log(train_metrics, step=i)
311 |
312 | if i % FLAGS.eval_interval == 0:
313 | record_video = i % FLAGS.video_interval == 0
314 | policy_fn = partial(supply_rng(agent.sample_actions), temperature=0.0)
315 | eval_info, trajs = evaluate(policy_fn, eval_env, num_episodes=FLAGS.eval_episodes, record_video=record_video, return_trajectories=True)
316 | eval_metrics = {}
317 | for k in ['episode.return', 'episode.length']:
318 | eval_metrics[f'evaluation/{k}'] = eval_info[k]
319 | print(f'evaluation/{k}: {eval_info[k]}')
320 | try:
321 | eval_metrics['evaluation/episode.return.normalized'] = eval_env.get_normalized_score(eval_info['episode.return'])
322 | print(f'evaluation/episode.return.normalized: {eval_metrics["evaluation/episode.return.normalized"]}')
323 | except:
324 | pass
325 | if record_video:
326 | wandb.log({'video': eval_info['video']}, step=i)
327 |
328 | # Antmaze Specific Logging
329 | if 'antmaze-large' in FLAGS.env_name or 'maze2d-large' in FLAGS.env_name:
330 | import fre.common.envs.d4rl.d4rl_ant as d4rl_ant
331 | # Make an image of the trajectories.
332 | traj_image = d4rl_ant.trajectory_image(eval_env, trajs)
333 | eval_metrics['trajectories'] = wandb.Image(traj_image)
334 |
335 | # Make an image of the value function.
336 | if 'antmaze-large' in FLAGS.env_name or 'maze2d-large' in FLAGS.env_name:
337 | def get_gcvalue(state, goal):
338 | obgoal = jnp.concatenate([state, goal], axis=-1)
339 | return agent.value(obgoal)
340 | pred_value_img = d4rl_ant.value_image(eval_env, dataset, get_gcvalue)
341 | eval_metrics['v'] = wandb.Image(pred_value_img)
342 |
343 | # Maze2d Action Distribution
344 | if 'maze2d-large' in FLAGS.env_name:
345 | # Make a plot of the actions.
346 | traj_actions = np.concatenate([t['action'] for t in trajs], axis=0) # (T, A)
347 | import matplotlib.pyplot as plt
348 | plt.figure()
349 | plt.scatter(traj_actions[::100, 0], traj_actions[::100, 1], alpha=0.4)
350 | wandb.log({'actions_traj': wandb.Image(plt)}, step=i)
351 |
352 | data_actions = batch['actions']
353 | import matplotlib.pyplot as plt
354 | plt.figure()
355 | plt.scatter(data_actions[:, 0], data_actions[:, 1], alpha=0.2)
356 | wandb.log({'actions_data': wandb.Image(plt)}, step=i)
357 |
358 | wandb.log(eval_metrics, step=i)
359 |
360 | if i % FLAGS.save_interval == 0 and FLAGS.save_dir is not None:
361 | checkpoints.save_checkpoint(FLAGS.save_dir, agent, i)
362 |
363 | if __name__ == '__main__':
364 | app.run(main)
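365 | 
366 | # Worked example of expectile_loss (defined near the top of this file); illustrative numbers only.
367 | # With expectile = 0.9 and diff = q - v:
368 | #   expectile_loss(+1.0, 0.9) = 0.9 * 1.0**2 = 0.9   (V under-estimating Q is penalized heavily)
369 | #   expectile_loss(-1.0, 0.9) = 0.1 * 1.0**2 = 0.1   (V over-estimating Q is penalized lightly)
370 | # The value network therefore regresses toward an upper expectile of Q, as in IQL.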
--------------------------------------------------------------------------------
/common/dataset.py:
--------------------------------------------------------------------------------
1 | ###############################
2 | #
3 | # Dataset Pytrees for offline data, replay buffers, etc.
4 | #
5 | ###############################
6 |
7 | import numpy as np
8 | from fre.common.typing import Data, Array
9 | from flax.core.frozen_dict import FrozenDict
10 | from jax import tree_util
11 |
12 |
13 | def get_size(data: Data) -> int:
14 | sizes = tree_util.tree_map(lambda arr: len(arr), data)
15 | return max(tree_util.tree_leaves(sizes))
16 |
17 |
18 | class Dataset(FrozenDict):
19 | """
20 | A class for storing (and retrieving batches of) data in nested dictionary format.
21 |
22 | Example:
23 | dataset = Dataset({
24 | 'observations': {
25 | 'image': np.random.randn(100, 28, 28, 1),
26 | 'state': np.random.randn(100, 4),
27 | },
28 | 'actions': np.random.randn(100, 2),
29 | })
30 |
31 | batch = dataset.sample(32)
32 | # Batch should have nested shape: {
33 | # 'observations': {'image': (32, 28, 28, 1), 'state': (32, 4)},
34 | # 'actions': (32, 2)
35 | # }
36 | """
37 |
38 | @classmethod
39 | def create(
40 | cls,
41 | observations: Data,
42 | actions: Array,
43 | rewards: Array,
44 | masks: Array,
45 | next_observations: Data,
46 | freeze=True,
47 | **extra_fields
48 | ):
49 | data = {
50 | "observations": observations,
51 | "actions": actions,
52 | "rewards": rewards,
53 | "masks": masks,
54 | "next_observations": next_observations,
55 | **extra_fields,
56 | }
57 | # Force freeze
58 | if freeze:
59 | tree_util.tree_map(lambda arr: arr.setflags(write=False), data)
60 | return cls(data)
61 |
62 | def __init__(self, *args, **kwargs):
63 | super().__init__(*args, **kwargs)
64 | self.size = get_size(self._dict)
65 |
66 | def sample(self, batch_size: int, indx=None):
67 | """
68 | Sample a batch of data from the dataset. Use `indx` to specify a specific
69 | set of indices to retrieve. Otherwise, a random sample will be drawn.
70 |
71 | Returns a dictionary with the same structure as the original dataset.
72 | """
73 | if indx is None:
74 | indx = np.random.randint(self.size, size=batch_size)
75 | return self.get_subset(indx)
76 |
77 | def get_subset(self, indx):
78 | return tree_util.tree_map(lambda arr: arr[indx], self._dict)
79 |
80 |
81 | class ReplayBuffer(Dataset):
82 | """
83 | Dataset where data is added to the buffer.
84 |
85 | Example:
86 | example_transition = {
87 | 'observations': {
88 | 'image': np.random.randn(28, 28, 1),
89 | 'state': np.random.randn(4),
90 | },
91 | 'actions': np.random.randn(2),
92 | }
93 | buffer = ReplayBuffer.create(example_transition, size=1000)
94 | buffer.add_transition(example_transition)
95 | batch = buffer.sample(32)
96 |
97 | """
98 |
99 | @classmethod
100 | def create(cls, transition: Data, size: int):
101 | def create_buffer(example):
102 | example = np.array(example)
103 | return np.zeros((size, *example.shape), dtype=example.dtype)
104 |
105 | buffer_dict = tree_util.tree_map(create_buffer, transition)
106 | return cls(buffer_dict)
107 |
108 | @classmethod
109 | def create_from_initial_dataset(cls, init_dataset: dict, size: int):
110 | def create_buffer(init_buffer):
111 | buffer = np.zeros((size, *init_buffer.shape[1:]), dtype=init_buffer.dtype)
112 | buffer[: len(init_buffer)] = init_buffer
113 | return buffer
114 |
115 | buffer_dict = tree_util.tree_map(create_buffer, init_dataset)
116 | dataset = cls(buffer_dict)
117 | dataset.size = dataset.pointer = get_size(init_dataset)
118 | return dataset
119 |
120 | def __init__(self, *args, **kwargs):
121 | super().__init__(*args, **kwargs)
122 |
123 | self.max_size = get_size(self._dict)
124 | self.size = 0
125 | self.pointer = 0
126 |
127 | def add_transition(self, transition):
128 | def set_idx(buffer, new_element):
129 | buffer[self.pointer] = new_element
130 |
131 | tree_util.tree_map(set_idx, self._dict, transition)
132 | self.pointer = (self.pointer + 1) % self.max_size
133 | self.size = max(self.pointer, self.size)
134 |
135 | def clear(self):
136 | self.size = self.pointer = 0
137 |
--------------------------------------------------------------------------------
/common/envs/bandit/bandit.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import gym
3 |
4 | # Here's a super simple bandit environment that follows the OpenAI Gym API.
5 | # There is one continuous action. The observation is always zero.
6 | # A reward of 1 is given if the action is within `width` of either 0.5 or -0.5.
7 |
8 | class BanditEnv(gym.Env):
9 | def __init__(self):
10 | self.action_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
11 | self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
12 | self._state = None
13 | self.width = 0.15
14 |
15 | def reset(self):
16 | self._state = np.zeros(1)
17 | return self._state
18 |
19 | def step(self, action):
20 | reward = 0
21 | if (np.abs(action[0] - 0.5) < self.width) or (np.abs(action[0] + 0.5) < self.width):
22 | reward = 1
23 | self.last_action = action
24 | return self._state, reward, True, {}
25 |
26 | def render(self, mode='human'):
27 | # Render the last action on a line. Also indicate where the reward is. Return this as a numpy array.
28 | img = np.ones((20, 100, 3), dtype=np.uint8) * 255
29 | # Render reward zones in green. 0-100 means actions between -1 and 1.
30 | center_low = 25
31 | center_high = 75
32 | width_int = int(self.width * 50)
33 | img[:, center_low-width_int:center_low+width_int, :] = [0, 255, 0]
34 | img[:, center_high-width_int:center_high+width_int, :] = [0, 255, 0]
35 | # Render the last action in red.
36 | action = self.last_action[0]
37 | action = int((action + 1) * 50)
38 | img[:, action:action+1, :] = [255, 0, 0]
39 | return img
40 |
41 |
42 |
43 |
44 | def close(self):
45 | pass
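46 | 
47 | # Minimal usage sketch (illustrative only): the reward is 1 only when the sampled action lands
48 | # within `width` of +0.5 or -0.5, and the episode always terminates after a single step.
49 | if __name__ == '__main__':
50 |     env = BanditEnv()
51 |     obs = env.reset()
52 |     obs, reward, done, info = env.step(env.action_space.sample())
53 |     print(reward, done)    # done is always True for this one-step bandit
54 |     frame = env.render()   # (20, 100, 3) uint8 image showing the reward zones and last action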
--------------------------------------------------------------------------------
/common/envs/d4rl/antmaze_actions.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/fre/c9b5a1d8d88a69bbe25da0d88d39ef1a9c4cf39a/common/envs/d4rl/antmaze_actions.npy
--------------------------------------------------------------------------------
/common/envs/d4rl/d4rl_ant.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | matplotlib.use('Agg')
3 | from matplotlib import patches
4 |
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
9 | from functools import partial
10 | from mpl_toolkits.axes_grid1 import make_axes_locatable
11 |
12 | import os
13 | import os.path as osp
14 |
15 | import gym
16 | import d4rl
17 | import numpy as np
18 | import functools as ft
19 | import math
20 | import matplotlib.gridspec as gridspec
21 |
22 | from fre.common.envs.gc_utils import GCDataset
23 |
24 | class MazeWrapper(gym.Wrapper):
25 | def __init__(self, env_name):
26 | self.env = gym.make(env_name)
27 | self.env.render(mode='rgb_array', width=200, height=200)
28 | self.env_name = env_name
29 | self.inner_env = get_inner_env(self.env)
30 | if 'antmaze' in env_name:
31 | if 'medium' in env_name:
32 | self.env.viewer.cam.lookat[0] = 10
33 | self.env.viewer.cam.lookat[1] = 10
34 | self.env.viewer.cam.distance = 40
35 | self.env.viewer.cam.elevation = -90
36 | elif 'umaze' in env_name:
37 | self.env.viewer.cam.lookat[0] = 4
38 | self.env.viewer.cam.lookat[1] = 4
39 | self.env.viewer.cam.distance = 30
40 | self.env.viewer.cam.elevation = -90
41 | elif 'large' in env_name:
42 | self.env.viewer.cam.lookat[0] = 18
43 | self.env.viewer.cam.lookat[1] = 13
44 | self.env.viewer.cam.distance = 55
45 | self.env.viewer.cam.elevation = -90
46 | self.inner_env.goal_sampler = ft.partial(valid_goal_sampler, self.inner_env)
47 | elif 'maze2d' in env_name:
48 | if 'open' in env_name:
49 | pass
50 | elif 'large' in env_name:
51 | self.env.viewer.cam.lookat[0] = 5
52 | self.env.viewer.cam.lookat[1] = 6.5
53 | self.env.viewer.cam.distance = 15
54 | self.env.viewer.cam.elevation = -90
55 | self.env.viewer.cam.azimuth = 180
56 | self.draw_ant_maze = get_inner_env(gym.make('antmaze-large-diverse-v2'))
57 | self.action_space = self.env.action_space
58 |
59 | def render(self, *args, **kwargs):
60 | img = self.env.render(*args, **kwargs)
61 | if 'maze2d' in self.env_name:
62 | img = img[::-1]
63 | return img
64 |
65 | # ======== BELOW is helper stuff for drawing and visualizing ======== #
66 |
67 | def get_starting_boundary(self):
68 | if 'antmaze' in self.env_name:
69 | self = self.inner_env
70 | else:
71 | self = self.draw_ant_maze
72 | torso_x, torso_y = self._init_torso_x, self._init_torso_y
73 | S = self._maze_size_scaling
74 | return (0 - S / 2 + S - torso_x, 0 - S/2 + S - torso_y), (len(self._maze_map[0]) * S - torso_x - S/2 - S, len(self._maze_map) * S - torso_y - S/2 - S)
75 |
76 | def XY(self, n=20, m=30):
77 | bl, tr = self.get_starting_boundary()
78 | X = np.linspace(bl[0] + 0.04 * (tr[0] - bl[0]) , tr[0] - 0.04 * (tr[0] - bl[0]), m)
79 | Y = np.linspace(bl[1] + 0.04 * (tr[1] - bl[1]) , tr[1] - 0.04 * (tr[1] - bl[1]), n)
80 |
81 | X,Y = np.meshgrid(X,Y)
82 | states = np.array([X.flatten(), Y.flatten()]).T
83 | return states
84 |
85 | def four_goals(self):
86 | self = self.inner_env
87 |
88 | valid_cells = []
89 | goal_cells = []
90 |
91 | for i in range(len(self._maze_map)):
92 | for j in range(len(self._maze_map[0])):
93 | if self._maze_map[i][j] in [0, 'r', 'g']:
94 | valid_cells.append(self._rowcol_to_xy((i, j), add_random_noise=False))
95 |
96 | goals = []
97 | goals.append(max(valid_cells, key=lambda x: -x[0]-x[1]))
98 | goals.append(max(valid_cells, key=lambda x: x[0]-x[1]))
99 | goals.append(max(valid_cells, key=lambda x: x[0]+x[1]))
100 | goals.append(max(valid_cells, key=lambda x: -x[0] + x[1]))
101 | return goals
102 |
103 | def draw(self, ax=None, scale=1.0):
104 | if not ax: ax = plt.gca()
105 | if 'antmaze' in self.env_name:
106 | self = self.inner_env
107 | else:
108 | self = self.draw_ant_maze
109 | torso_x, torso_y = self._init_torso_x, self._init_torso_y
110 | S = self._maze_size_scaling
111 | if scale < 1.0:
112 | S *= 0.965
113 | torso_x -= 0.7
114 | torso_y -= 0.95
115 | for i in range(len(self._maze_map)):
116 | for j in range(len(self._maze_map[0])):
117 | struct = self._maze_map[i][j]
118 | if struct == 1:
119 | rect = patches.Rectangle((j *S - torso_x - S/ 2,
120 | i * S- torso_y - S/ 2),
121 | S,
122 | S, linewidth=1, edgecolor='none', facecolor='grey', alpha=1.0)
123 |
124 | ax.add_patch(rect)
125 | ax.set_xlim(0 - S /2 + 0.6 * S - torso_x, len(self._maze_map[0]) * S - torso_x - S/2 - S * 0.6)
126 | ax.set_ylim(0 - S/2 + 0.6 * S - torso_y, len(self._maze_map) * S - torso_y - S/2 - S * 0.6)
127 | ax.axis('off')
128 |
129 | class CenteredMaze(MazeWrapper):
130 | start_loc: str = "center"
131 |
132 | def __init__(self, env_name, start_loc="center"):
133 | super().__init__(env_name)
134 | self.start_loc = start_loc
135 | self.t = 0
136 |
137 | def step(self, action):
138 | next_obs, r, done, info = self.env.step(action)
139 | if 'antmaze' in self.env_name:
140 | info['x'], info['y'] = self.get_xy()
141 | self.t += 1
142 | done = self.t >= 2000
143 | return next_obs, r, done, info
144 |
145 | def reset(self, **kwargs):
146 | self.t = 0
147 | obs = self.env.reset(**kwargs)
148 | if 'maze2d' in self.env_name:
149 | if self.start_loc == 'center' or self.start_loc == 'center2':
150 | obs = self.env.reset_to_location([4, 5.8])
151 | elif self.start_loc == 'original':
152 | obs = self.env.reset_to_location([0.9, 0.9])
153 | else:
154 | raise NotImplementedError
155 | elif 'antmaze' in self.env_name:
156 | if self.start_loc == 'center' or self.start_loc == 'center2':
157 | self.env.set_xy([20, 15])
158 | obs[:2] = [20, 15]
159 | elif self.start_loc == 'original':
160 | pass
161 | else:
162 | raise NotImplementedError
163 | return obs
164 |
165 | class GoalReachingMaze(MazeWrapper):
166 | def __init__(self, env_name):
167 | super().__init__(env_name)
168 | self.observation_space = gym.spaces.Dict({
169 | 'observation': self.env.observation_space,
170 | 'goal': self.env.observation_space,
171 | })
172 |
173 | def step(self, action):
174 | next_obs, r, done, info = self.env.step(action)
175 |
176 | if 'antmaze' in self.env_name:
177 | achieved = self.get_xy()
178 | desired = self.target_goal
179 | elif 'maze2d' in self.env_name:
180 | achieved = next_obs[:2]
181 | desired = self.env.get_target()
182 | distance = np.linalg.norm(achieved - desired)
183 | info['x'], info['y'] = achieved
184 | info['achieved_goal'] = np.array(achieved)
185 | info['desired_goal'] = np.copy(desired)
186 | info['success'] = float(distance < 0.5)
187 | done = 'TimeLimit.truncated' in info or info['success']
188 |
189 | return self.get_obs(next_obs), r, done, info
190 |
191 | def get_obs(self, obs):
192 | if 'antmaze' in self.env_name:
193 | desired = self.target_goal
194 | elif 'maze2d' in self.env_name:
195 | desired = self.env.get_target()
196 | target_goal = obs.copy()
197 | target_goal[:2] = desired
198 | if 'antmaze' in self.env_name:
199 | obs = discretize_obs(obs)
200 | target_goal = discretize_obs(target_goal)
201 | return dict(observation=obs, goal=target_goal)
202 |
203 | def reset(self, **kwargs):
204 | obs = self.env.reset(**kwargs)
205 | if 'maze2d' in self.env_name:
206 | obs = self.env.reset_to_location([0.9, 0.9])
207 | return self.get_obs(obs)
208 |
209 | def get_normalized_score(self, score):
210 | return score
211 |
212 | # ===================================
213 | # HELPER FUNCTIONS FOR OB DISCRETIZATION
214 | # ===================================
215 |
216 | def discretize_obs(ob, num_bins=32, disc_type='tanh', disc_temperature=1.0):
217 | min_ob = np.array([0, 0])
218 | max_ob = np.array([35, 35])
219 | disc_dims = 2
220 |     bins = np.linspace(min_ob, max_ob, num_bins).T  # (disc_dims, num_bins) bin values from min_ob to max_ob
221 | bin_size = (max_ob - min_ob) / num_bins
222 | if disc_type == 'twohot':
223 | raise NotImplementedError
224 | elif disc_type == 'tanh':
225 | orig_ob = ob
226 | ob = np.expand_dims(ob, -1)
227 | # Convert each discretized dimension into num_bins dimensions. Value of each dimension is tanh of the distance from the bin center.
228 | bin_diff = ob[..., :disc_dims, :] - bins[:disc_dims]
229 | bin_diff_normalized = bin_diff / np.expand_dims(bin_size[:disc_dims], -1) * disc_temperature
230 | bin_tanh = np.tanh(bin_diff_normalized).reshape(*orig_ob.shape[:-1], -1)
231 | disc_ob = np.concatenate([bin_tanh, orig_ob[..., disc_dims:]], axis=-1)
232 | return disc_ob
233 | else:
234 | raise NotImplementedError
235 |
236 | # ===================================
237 | # HELPER FUNCTIONS FOR VISUALIZATION
238 | # ===================================
239 |
240 | def get_canvas_image(canvas):
241 | canvas.draw()
242 | out_image = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')
243 | out_image = out_image.reshape(canvas.get_width_height()[::-1] + (3,))
244 | return out_image
245 |
246 | def valid_goal_sampler(self, np_random):
247 | valid_cells = []
248 | goal_cells = []
249 | # print('Hello')
250 |
251 | for i in range(len(self._maze_map)):
252 | for j in range(len(self._maze_map[0])):
253 | if self._maze_map[i][j] in [0, 'r', 'g']:
254 | valid_cells.append((i, j))
255 |
256 | # If there is a 'goal' designated, use that. Otherwise, any valid cell can
257 | # be a goal.
258 | sample_choices = valid_cells
259 | cell = sample_choices[np_random.choice(len(sample_choices))]
260 | xy = self._rowcol_to_xy(cell, add_random_noise=True)
261 |
262 | random_x = np.random.uniform(low=0, high=0.5) * 0.25 * self._maze_size_scaling
263 | random_y = np.random.uniform(low=0, high=0.5) * 0.25 * self._maze_size_scaling
264 |
265 | xy = (max(xy[0] + random_x, 0), max(xy[1] + random_y, 0))
266 |
267 | return xy
268 |
269 |
270 | def get_inner_env(env):
271 | if hasattr(env, '_maze_size_scaling'):
272 | return env
273 | elif hasattr(env, 'env'):
274 | return get_inner_env(env.env)
275 | elif hasattr(env, 'wrapped_env'):
276 | return get_inner_env(env.wrapped_env)
277 | return env
278 |
279 |
280 | # ===================================
281 | # PLOT VALUE FUNCTION
282 | # ===================================
283 |
284 | def value_image(env, dataset, value_fn):
285 | """
286 | Visualize the value function.
287 | Args:
288 | env: The environment.
289 | value_fn: a function with signature value_fn([# states, state_dim]) -> [#states, 1]
290 | Returns:
291 | A numpy array of the image.
292 | """
293 | fig, axs = plt.subplots(2, 2, tight_layout=True)
294 | axs_flat = axs.flatten()
295 | canvas = FigureCanvas(fig)
296 | if type(dataset) is GCDataset:
297 | dataset = dataset.dataset
298 | if 'antmaze' in env.env_name:
299 | goals = env.four_goals()
300 | goal_states = dataset['observations'][0]
301 | goal_states = goal_states[-29:] # Remove discretized observations.
302 | goal_states = np.tile(goal_states, (len(goals), 1))
303 | goal_states[:, :2] = goals
304 | goal_states = discretize_obs(goal_states)
305 | elif 'maze2d' in env.env_name:
306 | goals = np.array([[0.8, 0.8], [1, 9.7], [6.8, 9], [6.8, 1]])
307 | goal_states = dataset['observations'][0]
308 | goal_states = np.tile(goal_states, (len(goals), 1))
309 | goal_states[:, :2] = goals
310 | for i in range(4):
311 | plot_value(goal_states[i], env, dataset, value_fn, axs_flat[i])
312 | image = get_canvas_image(canvas)
313 | plt.close(fig)
314 | return image
315 |
316 | def plot_value(goal_observation, env, dataset, value_fn, ax):
317 | N = 14
318 | M = 20
319 | ob_xy = env.XY(n=N, m=M)
320 |
321 | goal_observation = np.tile(goal_observation, (ob_xy.shape[0], 1)) # (N*M, 29)
322 |
323 | base_observation = np.copy(dataset['observations'][0])
324 | xy_observations = np.tile(base_observation, (ob_xy.shape[0], 1)) # (N*M, 29)
325 | if 'antmaze' in env.env_name:
326 | xy_observations = xy_observations[:, -29:] # Remove discretized observations.
327 | xy_observations[:, :2] = ob_xy # Set to XY.
328 | xy_observations = discretize_obs(xy_observations) # Discretize again.
329 | assert xy_observations.shape[1] == 91
330 | elif 'maze2d' in env.env_name:
331 | ob_xy_scaled = ob_xy / 3.5
332 | ob_xy_scaled = ob_xy_scaled[:, [1, 0]]
333 | xy_observations[:, :2] = ob_xy_scaled
334 | assert xy_observations.shape[1] == 4 # (x, y, vx, vy)
335 | values = value_fn(xy_observations, goal_observation) # (N*M, 1)
336 |
337 | x, y = ob_xy[:, 0], ob_xy[:, 1]
338 | x = x.reshape(N, M)
339 | y = y.reshape(N, M) * 0.975 + 0.7
340 | values = values.reshape(N, M)
341 | mesh = ax.pcolormesh(x, y, values, cmap='viridis')
342 |
343 | env.draw(ax, scale=0.95)
344 |
345 |
346 | # ===================================
347 | # PLOT TRAJECTORIES
348 | # ===================================
349 |
350 | # Makes an image of the trajectory the Ant follows.
351 | def trajectory_image(env, trajectories, **kwargs):
352 | fig = plt.figure(tight_layout=True)
353 | canvas = FigureCanvas(fig)
354 |
355 | plot_trajectories(env, trajectories, fig, plt.gca(), **kwargs)
356 |
357 | plt.tight_layout()
358 | image = get_canvas_image(canvas)
359 | plt.close(fig)
360 | return image
361 |
362 | # Helper that plots the XY coordinates as scatter plots.
363 | def plot_trajectories(env, trajectories, fig, ax, color_list=None):
364 | if color_list is None:
365 | from itertools import cycle
366 | color_cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']
367 | color_list = cycle(color_cycle)
368 |
369 | for color, trajectory in zip(color_list, trajectories):
370 | obs = np.array(trajectory['observation'])
371 |
372 | # convert back to xy?
373 | if 'ant' in env.env_name:
374 | all_x = []
375 | all_y = []
376 | for info in trajectory['info']:
377 | all_x.append(info['x'])
378 | all_y.append(info['y'])
379 | all_x = np.array(all_x)
380 | all_y = np.array(all_y)
381 | elif 'maze2d' in env.env_name:
382 | all_x = obs[:, 1] * 4 - 3.2
383 | all_y = obs[:, 0] * 4 - 3.2
384 | ax.scatter(all_x, all_y, s=5, c=color, alpha=0.2)
385 | ax.scatter(all_x[-1], all_y[-1], s=50, c=color, marker='*', alpha=1, edgecolors='black')
386 |
387 | env.draw(ax)
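388 | 
389 | # Shape sketch for discretize_obs above (illustrative): with num_bins=32 and disc_dims=2, the
390 | # first two (x, y) dimensions of a 29-dim AntMaze observation are each expanded into 32 tanh
391 | # bin-distance features, so the output has 2 * 32 + 27 = 91 dimensions, matching the
392 | # `xy_observations.shape[1] == 91` assertion in plot_value, e.g.:
393 | #   discretize_obs(np.zeros(29)).shape == (91,)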
--------------------------------------------------------------------------------
/common/envs/d4rl/d4rl_utils.py:
--------------------------------------------------------------------------------
1 | import d4rl
2 | import d4rl.gym_mujoco
3 | import gym
4 | import numpy as np
5 | from jax import tree_util
6 |
7 |
8 | import fre.common.envs.d4rl.d4rl_ant as d4rl_ant
9 | from fre.common.dataset import Dataset
10 |
11 |
12 | # Note on AntMaze: Reward = 1 at the goal, and Terminal = 1 at the goal.
13 | # Masks = 0 if the episode ends due to a true terminal state (cuts off bootstrapping), 1 otherwise.
14 | # Dones_float = 1 if the episode ends for any reason (terminal state OR time limit / trajectory break).
15 | def get_dataset(env: gym.Env, env_name: str, clip_to_eps: bool = True,
16 | eps: float = 1e-5, dataset=None, filter_terminals=False, obs_dtype=np.float32):
17 | if dataset is None:
18 | dataset = d4rl.qlearning_dataset(env)
19 |
20 | if clip_to_eps:
21 | lim = 1 - eps
22 | dataset['actions'] = np.clip(dataset['actions'], -lim, lim)
23 |
24 | # Mask everything that is marked as a terminal state.
25 | # For AntMaze, this should mask the end of each trajectory.
26 | masks = 1.0 - dataset['terminals']
27 |
28 | # In the AntMaze data, terminal is 1 when at the goal. But the episode doesn't end.
29 | # This just ensures that we treat AntMaze trajectories as non-ending.
30 | if "antmaze" in env_name or "maze2d" in env_name:
31 | dataset['terminals'] = np.zeros_like(dataset['terminals'])
32 |
33 | # if 'antmaze' in env_name:
34 | # print("Discretizing AntMaze observations.")
35 | # print("Raw observations looks like", dataset['observations'].shape[1:])
36 | # dataset['observations'] = d4rl_ant.discretize_obs(dataset['observations'])
37 | # dataset['next_observations'] = d4rl_ant.discretize_obs(dataset['next_observations'])
38 | # print("Discretized observations looks like", dataset['observations'].shape[1:])
39 |
40 |     # Compute dones if terminal OR observation jumps.
41 | dones_float = np.zeros_like(dataset['rewards'])
42 |
43 | imputed_next_observations = np.roll(dataset['observations'], -1, axis=0)
44 | same_obs = np.all(np.isclose(imputed_next_observations, dataset['next_observations'], atol=1e-5), axis=-1)
45 | dones_float = 1.0 - same_obs.astype(np.float32)
46 | dones_float += dataset['terminals']
47 | dones_float[-1] = 1.0
48 | dones_float = np.clip(dones_float, 0.0, 1.0)
49 |
50 | observations = dataset['observations'].astype(obs_dtype)
51 | next_observations = dataset['next_observations'].astype(obs_dtype)
52 |
53 | return Dataset.create(
54 | observations=observations,
55 | actions=dataset['actions'].astype(np.float32),
56 | rewards=dataset['rewards'].astype(np.float32),
57 | masks=masks.astype(np.float32),
58 | dones_float=dones_float.astype(np.float32),
59 | next_observations=next_observations,
60 | )
61 |
62 | def get_normalization(dataset):
63 | returns = []
64 | ret = 0
65 | for r, term in zip(dataset['rewards'], dataset['dones_float']):
66 | ret += r
67 | if term:
68 | returns.append(ret)
69 | ret = 0
70 | return (max(returns) - min(returns)) / 1000
71 |
72 | def normalize_dataset(env_name, dataset):
73 | print("Normalizing", env_name)
74 | if 'antmaze' in env_name or 'maze2d' in env_name:
75 | return dataset.copy({'rewards': dataset['rewards']- 1.0})
76 | else:
77 | normalizing_factor = get_normalization(dataset)
78 | print(f'Normalizing factor: {normalizing_factor}')
79 | dataset = dataset.copy({'rewards': dataset['rewards'] / normalizing_factor})
80 | return dataset
81 |
82 | # Flattens an environment with a dictionary observation of (observation, goal) into a single concatenated observation.
83 | class GoalReachingFlat(gym.Wrapper):
84 |     """A wrapper that flattens a dict observation {'observation', 'goal'} into one concatenated array."""
85 | def __init__(self, env):
86 | super().__init__(env)
87 | self.observation_space = gym.spaces.Box(
88 | low=-np.inf, high=np.inf, shape=(self.observation_space['observation'].shape[0] + self.observation_space['goal'].shape[0],), dtype=np.float32)
89 |
90 | def step(self, action):
91 | ob, reward, done, info = self.env.step(action)
92 | ob_flat = np.concatenate([ob['observation'], ob['goal']])
93 | return ob_flat, reward, done, info
94 |
95 | def reset(self, **kwargs):
96 | ob = self.env.reset(**kwargs)
97 | ob_flat = np.concatenate([ob['observation'], ob['goal']])
98 | return ob_flat
99 |
100 | def parse_trajectories(dataset):
101 | trajectory_ids = np.where(dataset['dones_float'] == 1)[0] + 1
102 | trajectory_ids = np.concatenate([[0], trajectory_ids])
103 | num_trajectories = trajectory_ids.shape[0] - 1
104 | print("There are {} trajectories. Some traj lens are {}".format(num_trajectories, [trajectory_ids[i + 1] - trajectory_ids[i] for i in range(min(5, num_trajectories))]))
105 | trajectories = []
106 | for i in range(len(trajectory_ids) - 1):
107 | trajectories.append(tree_util.tree_map(lambda arr: arr[trajectory_ids[i]:trajectory_ids[i + 1]], dataset._dict))
108 | return trajectories
109 |
110 | class KitchenRenderWrapper(gym.Wrapper):
111 | def render(self, *args, **kwargs):
112 | from dm_control.mujoco import engine
113 | camera = engine.MovableCamera(self.sim, 1920, 2560)
114 | camera.set_pose(distance=2.2, lookat=[-0.2, .5, 2.], azimuth=70, elevation=-35)
115 | img = camera.render()
116 | return img
117 |
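118 | # Worked example of the masks / dones_float convention documented at the top of this file (illustrative):
119 | #   terminals   = [0, 0, 1, 0, 0]                    # true environment terminations
120 | #   masks       = 1 - terminals = [1, 1, 0, 1, 1]    # zeroes out bootstrapping at true terminals
121 | #   dones_float = [0, 0, 1, 0, 1]                    # trajectory boundaries: terminal OR observation jump
122 | #                                                    #   OR end of dataset (used by parse_trajectories)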
--------------------------------------------------------------------------------
/common/envs/data_transforms.py:
--------------------------------------------------------------------------------
1 | ###############################
2 | #
3 | # Helpful utilities for processing actions, observations.
4 | #
5 | ###############################
6 |
7 | import numpy as np
8 | import jax.numpy as jnp
9 |
10 | class ActionTransform():
11 | pass
12 |
13 | class ActionDiscretizeBins(ActionTransform):
14 | def __init__(self, bins_per_dim, action_dim):
15 | self.bins_per_dim = bins_per_dim
16 | self.action_dim = action_dim
17 | self.bins = np.linspace(-1, 1, bins_per_dim + 1)
18 |
19 | # Assumes action is in [-1, 1].
20 | def action_to_ids(self, action):
21 | ids = np.digitize(action, self.bins) - 1
22 | ids = np.clip(ids, 0, self.bins_per_dim - 1)
23 | return ids
24 |
25 | def ids_to_action(self, ids):
26 | action = (self.bins[ids] + self.bins[ids + 1]) / 2
27 | return action
28 |
29 | class ActionDiscretizeCluster(ActionTransform):
30 | def __init__(self, num_clusters, data_actions):
31 | self.num_clusters = num_clusters
32 | assert len(data_actions.shape) == 2 # (data_size, action_dim)
33 | print("Clustering actions of shape", data_actions.shape)
34 |
35 | # Cluster the data.
36 | from sklearn.cluster import KMeans
37 | kmeans = KMeans(n_clusters=num_clusters, random_state=0).fit(data_actions)
38 | self.centers = kmeans.cluster_centers_
39 | # self.labels = kmeans.labels_
40 | self.centers = jnp.array(self.centers)
41 | print("Average cluster error is", kmeans.inertia_ / len(data_actions))
42 | print("Average cluster error per dimension is", (kmeans.inertia_ / len(data_actions)) / data_actions.shape[1])
43 | # print(self.centers.shape)
44 |
45 | def action_to_ids(self, action):
46 | if len(action.shape) == 1:
47 | action = action[None]
48 | assert len(action.shape) == 2 # (batch, action_dim,)
49 | # Find the closest cluster center.
50 | dists = jnp.linalg.norm(self.centers[None] - action[:, None], axis=-1)
51 | ids = jnp.argmin(dists, axis=-1)
52 | return ids
53 |
54 | def ids_to_action(self, ids):
55 | action = self.centers[ids]
56 | return action
57 |
58 | # Test
59 | # action_discretize_bins = ActionDiscretizeBins(32, 2)
60 | # action = np.array([-1, -0.999, -0.5, 0, 0.5, 0.999, 1])
61 | # ids = action_discretize_bins.action_to_ids(action)
62 | # print(ids)
63 | # action_recreate = action_discretize_bins.ids_to_action(ids)
64 | # print(action_recreate)
65 | # assert np.abs(action - action_recreate).max() < 0.1
66 |
67 | # action_discretize_cluster = ActionDiscretizeCluster(32, np.random.uniform(low=-1, high=1, size=(10000, 1)))
68 | # action = np.array([-1, -0.999, -0.5, 0, 0.5, 0.999, 1])[:, None] # [7, 1]
69 | # ids = action_discretize_cluster.action_to_ids(action)
70 | # print(ids)
71 | # action_recreate = action_discretize_cluster.ids_to_action(ids)
72 | # print(action_recreate)
73 | # assert np.abs(action - action_recreate).max() < 0.1
--------------------------------------------------------------------------------
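The commented-out test above can be exercised directly; here is a minimal, runnable version of the bin round-trip (the 0.1 tolerance is generous, since 32 bins over [-1, 1] give a half-bin error of about 0.031). The import path is an assumption based on the repository layout.

import numpy as np
from fre.common.envs.data_transforms import ActionDiscretizeBins  # assumed package path

discretizer = ActionDiscretizeBins(bins_per_dim=32, action_dim=1)
action = np.array([-1.0, -0.999, -0.5, 0.0, 0.5, 0.999, 1.0])
ids = discretizer.action_to_ids(action)        # integer bin indices in [0, 31]
recreated = discretizer.ids_to_action(ids)     # bin centers
assert np.abs(action - recreated).max() < 0.1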
/common/envs/dmc/__init__.py:
--------------------------------------------------------------------------------
1 | import gym
2 | from gym.envs.registration import register
3 |
4 |
5 | def make(
6 | domain_name,
7 | task_name,
8 | seed=1,
9 | visualize_reward=True,
10 | from_pixels=False,
11 | height=84,
12 | width=84,
13 | camera_id=0,
14 | frame_skip=1,
15 | episode_length=1000,
16 | environment_kwargs=None,
17 | time_limit=None,
18 | channels_first=True
19 | ):
20 | env_id = 'dmc_%s_%s_%s-v1' % (domain_name, task_name, seed)
21 |
22 | if from_pixels:
23 | assert not visualize_reward, 'cannot use visualize reward when learning from pixels'
24 |
25 | # shorten episode length
26 | max_episode_steps = (episode_length + frame_skip - 1) // frame_skip
27 |
28 |     if env_id not in gym.envs.registry.env_specs:
29 | task_kwargs = {}
30 | if seed is not None:
31 | task_kwargs['random'] = seed
32 | if time_limit is not None:
33 | task_kwargs['time_limit'] = time_limit
34 | register(
35 | id=env_id,
36 | # entry_point='dmc2gym.wrappers:DMCWrapper',
37 | entry_point='fre.common.envs.dmc.wrappers:DMCWrapper',
38 | kwargs=dict(
39 | domain_name=domain_name,
40 | task_name=task_name,
41 | task_kwargs=task_kwargs,
42 | environment_kwargs=environment_kwargs,
43 | visualize_reward=visualize_reward,
44 | from_pixels=from_pixels,
45 | height=height,
46 | width=width,
47 | camera_id=camera_id,
48 | frame_skip=frame_skip,
49 | channels_first=channels_first,
50 | ),
51 | max_episode_steps=max_episode_steps,
52 | )
53 | return gym.make(env_id)
54 |
--------------------------------------------------------------------------------
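A short usage sketch for the factory above; it assumes the repository is importable as the `fre` package, matching the registered entry_point.

import fre.common.envs.dmc as dmc2gym

env = dmc2gym.make(domain_name='walker', task_name='walk', seed=1,
                   frame_skip=2, visualize_reward=False)
obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())
print(obs.shape, reward, done)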
/common/envs/dmc/jaco.py:
--------------------------------------------------------------------------------
1 | """A task where the goal is to move the hand close to a target prop or site."""
2 |
3 | import collections
4 |
5 | from dm_control import composer
6 | from dm_control.composer import initializers
7 | from dm_control.composer.variation import distributions
8 | from dm_control.entities import props
9 | from dm_control.manipulation.shared import arenas
10 | from dm_control.manipulation.shared import cameras
11 | from dm_control.manipulation.shared import constants
12 | from dm_control.manipulation.shared import observations
13 | from dm_control.manipulation.shared import robots
14 | from dm_control.manipulation.shared import workspaces
15 | from dm_control.utils import rewards
16 | from dm_env import specs
17 | import numpy as np
18 |
19 | _ReachWorkspace = collections.namedtuple(
20 | '_ReachWorkspace', ['target_bbox', 'tcp_bbox', 'arm_offset'])
21 |
22 | # Ensures that the props are not touching the table before settling.
23 | _PROP_Z_OFFSET = 0.001
24 |
25 | _DUPLO_WORKSPACE = _ReachWorkspace(
26 | target_bbox=workspaces.BoundingBox(lower=(-0.1, -0.1, _PROP_Z_OFFSET),
27 | upper=(0.1, 0.1, _PROP_Z_OFFSET)),
28 | tcp_bbox=workspaces.BoundingBox(lower=(-0.1, -0.1, 0.2),
29 | upper=(0.1, 0.1, 0.4)),
30 | arm_offset=robots.ARM_OFFSET)
31 |
32 | _SITE_WORKSPACE = _ReachWorkspace(
33 | target_bbox=workspaces.BoundingBox(lower=(-0.2, -0.2, 0.02),
34 | upper=(0.2, 0.2, 0.4)),
35 | tcp_bbox=workspaces.BoundingBox(lower=(-0.2, -0.2, 0.02),
36 | upper=(0.2, 0.2, 0.4)),
37 | arm_offset=robots.ARM_OFFSET)
38 |
39 | _TARGET_RADIUS = 0.05
40 | _TIME_LIMIT = 10.
41 |
42 | TASKS = [('reach_top_left', np.array([-0.09, 0.09, _PROP_Z_OFFSET])),
43 | ('reach_top_right', np.array([0.09, 0.09, _PROP_Z_OFFSET])),
44 | ('reach_bottom_left', np.array([-0.09, -0.09, _PROP_Z_OFFSET])),
45 | ('reach_bottom_right', np.array([0.09, -0.09, _PROP_Z_OFFSET]))]
46 |
47 |
48 | def make(task_id, obs_type, seed):
49 | obs_settings = observations.VISION if obs_type == 'pixels' else observations.PERFECT_FEATURES
50 | task = _reach(task_id, obs_settings=obs_settings, use_site=True)
51 | return composer.Environment(task,
52 | time_limit=_TIME_LIMIT,
53 | random_state=seed)
54 |
55 |
56 | class MultiTaskReach(composer.Task):
57 | """Bring the hand close to a target prop or site."""
58 |
59 | def __init__(self, task_id, arena, arm, hand, prop, obs_settings,
60 | workspace, control_timestep):
61 | """Initializes a new `Reach` task.
62 | Args:
63 | arena: `composer.Entity` instance.
64 | arm: `robot_base.RobotArm` instance.
65 | hand: `robot_base.RobotHand` instance.
66 | prop: `composer.Entity` instance specifying the prop to reach to, or None
67 | in which case the target is a fixed site whose position is specified by
68 | the workspace.
69 | obs_settings: `observations.ObservationSettings` instance.
70 | workspace: `_ReachWorkspace` specifying the placement of the prop and TCP.
71 | control_timestep: Float specifying the control timestep in seconds.
72 | """
73 | self._arena = arena
74 | self._arm = arm
75 | self._hand = hand
76 | self._arm.attach(self._hand)
77 | self._arena.attach_offset(self._arm, offset=workspace.arm_offset)
78 | self.control_timestep = control_timestep
79 | self._tcp_initializer = initializers.ToolCenterPointInitializer(
80 | self._hand,
81 | self._arm,
82 | position=distributions.Uniform(*workspace.tcp_bbox),
83 | quaternion=workspaces.DOWN_QUATERNION)
84 |
85 | # Add custom camera observable.
86 | self._task_observables = cameras.add_camera_observables(
87 | arena, obs_settings, cameras.FRONT_CLOSE)
88 |
89 | if task_id == 'reach_multitask':
90 | self._targets = [target for (_, target) in TASKS]
91 | else:
92 | self._targets = [
93 | target for (task, target) in TASKS if task == task_id
94 | ]
95 | assert len(self._targets) > 0
96 |
97 | #target_pos_distribution = distributions.Uniform(*TASKS[task_id])
98 | self._prop = prop
99 | if prop:
100 | # The prop itself is used to visualize the target location.
101 | self._make_target_site(parent_entity=prop, visible=False)
102 | self._target = self._arena.add_free_entity(prop)
103 | self._prop_placer = initializers.PropPlacer(
104 | props=[prop],
105 | position=target_pos_distribution,
106 | quaternion=workspaces.uniform_z_rotation,
107 | settle_physics=True)
108 | else:
109 | if len(self._targets) == 1:
110 | self._target = self._make_target_site(parent_entity=arena,
111 | visible=True)
112 |
113 | #obs = observable.MJCFFeature('pos', self._target)
114 | # obs.configure(**obs_settings.prop_pose._asdict())
115 | #self._task_observables['target_position'] = obs
116 |
117 | # Add sites for visualizing the prop and target bounding boxes.
118 | workspaces.add_bbox_site(body=self.root_entity.mjcf_model.worldbody,
119 | lower=workspace.tcp_bbox.lower,
120 | upper=workspace.tcp_bbox.upper,
121 | rgba=constants.GREEN,
122 | name='tcp_spawn_area')
123 | workspaces.add_bbox_site(body=self.root_entity.mjcf_model.worldbody,
124 | lower=workspace.target_bbox.lower,
125 | upper=workspace.target_bbox.upper,
126 | rgba=constants.BLUE,
127 | name='target_spawn_area')
128 |
129 | def _make_target_site(self, parent_entity, visible):
130 | return workspaces.add_target_site(
131 | body=parent_entity.mjcf_model.worldbody,
132 | radius=_TARGET_RADIUS,
133 | visible=visible,
134 | rgba=constants.RED,
135 | name='target_site')
136 |
137 | @property
138 | def root_entity(self):
139 | return self._arena
140 |
141 | @property
142 | def arm(self):
143 | return self._arm
144 |
145 | @property
146 | def hand(self):
147 | return self._hand
148 |
149 | def get_reward_spec(self):
150 | n = len(self._targets)
151 | return specs.Array(shape=(n,), dtype=np.float32, name='reward')
152 |
153 | @property
154 | def task_observables(self):
155 | return self._task_observables
156 |
157 | def get_reward(self, physics):
158 | hand_pos = physics.bind(self._hand.tool_center_point).xpos
159 | rews = []
160 | for target_pos in self._targets:
161 | distance = np.linalg.norm(hand_pos - target_pos)
162 | reward = rewards.tolerance(distance,
163 | bounds=(0, _TARGET_RADIUS),
164 | margin=_TARGET_RADIUS)
165 | rews.append(reward)
166 | rews = np.array(rews).astype(np.float32)
167 | if len(self._targets) == 1:
168 | return rews[0]
169 | return rews
170 |
171 | def initialize_episode(self, physics, random_state):
172 | self._hand.set_grasp(physics, close_factors=random_state.uniform())
173 | self._tcp_initializer(physics, random_state)
174 | if self._prop:
175 | self._prop_placer(physics, random_state)
176 | else:
177 | if len(self._targets) == 1:
178 | physics.bind(self._target).pos = self._targets[0]
179 |
180 |
181 | def _reach(task_id, obs_settings, use_site):
182 | """Configure and instantiate a `Reach` task.
183 | Args:
184 | obs_settings: An `observations.ObservationSettings` instance.
185 | use_site: Boolean, if True then the target will be a fixed site, otherwise
186 | it will be a moveable Duplo brick.
187 | Returns:
188 | An instance of `reach.Reach`.
189 | """
190 | arena = arenas.Standard()
191 | arm = robots.make_arm(obs_settings=obs_settings)
192 | hand = robots.make_hand(obs_settings=obs_settings)
193 | if use_site:
194 | workspace = _SITE_WORKSPACE
195 | prop = None
196 | else:
197 | workspace = _DUPLO_WORKSPACE
198 | prop = props.Duplo(observable_options=observations.make_options(
199 | obs_settings, observations.FREEPROP_OBSERVABLES))
200 | task = MultiTaskReach(task_id,
201 | arena=arena,
202 | arm=arm,
203 | hand=hand,
204 | prop=prop,
205 | obs_settings=obs_settings,
206 | workspace=workspace,
207 | control_timestep=constants.CONTROL_TIMESTEP)
208 | return task
--------------------------------------------------------------------------------
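When `task_id == 'reach_multitask'`, `get_reward` returns one value per target, so the reward spec above is a length-4 vector. A small sketch, assuming dm_control and its manipulation assets are installed and that the composer environment surfaces the task's custom reward spec:

import fre.common.envs.dmc.jaco as jaco

env = jaco.make(task_id='reach_multitask', obs_type='states', seed=0)
print(env.reward_spec())   # expected: Array(shape=(4,), dtype=float32, name='reward')
time_step = env.reset()
time_step = env.step(env.action_spec().generate_value())
print(time_step.reward)    # one tolerance-shaped reward per reach target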
/common/envs/dmc/wrappers.py:
--------------------------------------------------------------------------------
1 | from gym import core, spaces
2 | from dm_control import suite
3 | from dm_env import specs
4 | import numpy as np
5 |
6 |
7 | def _spec_to_box(spec, dtype):
8 | def extract_min_max(s):
9 | assert s.dtype == np.float64 or s.dtype == np.float32
10 | dim = int(np.prod(s.shape))
11 | if type(s) == specs.Array:
12 | bound = np.inf * np.ones(dim, dtype=np.float32)
13 | return -bound, bound
14 | elif type(s) == specs.BoundedArray:
15 | zeros = np.zeros(dim, dtype=np.float32)
16 | return s.minimum + zeros, s.maximum + zeros
17 |
18 | mins, maxs = [], []
19 | for s in spec:
20 | mn, mx = extract_min_max(s)
21 | mins.append(mn)
22 | maxs.append(mx)
23 | low = np.concatenate(mins, axis=0).astype(dtype)
24 | high = np.concatenate(maxs, axis=0).astype(dtype)
25 | assert low.shape == high.shape
26 | return spaces.Box(low, high, dtype=dtype)
27 |
28 |
29 | def _flatten_obs(obs):
30 | obs_pieces = []
31 | for v in obs.values():
32 | flat = np.array([v]) if np.isscalar(v) else v.ravel()
33 | obs_pieces.append(flat)
34 | return np.concatenate(obs_pieces, axis=0)
35 |
36 |
37 | class DMCWrapper(core.Env):
38 | def __init__(
39 | self,
40 | domain_name,
41 | task_name,
42 | task_kwargs=None,
43 | visualize_reward={},
44 | from_pixels=False,
45 | height=84,
46 | width=84,
47 | camera_id=0,
48 | frame_skip=1,
49 | environment_kwargs=None,
50 | channels_first=True
51 | ):
52 | assert 'random' in task_kwargs, 'please specify a seed, for deterministic behaviour'
53 | self._from_pixels = from_pixels
54 | self._height = height
55 | self._width = width
56 | self._camera_id = camera_id
57 | self._frame_skip = frame_skip
58 | self._channels_first = channels_first
59 |
60 | # create task
61 | if domain_name == 'jaco':
62 | import fre.common.envs.dmc.jaco as jaco
63 | self._env = jaco.make(task_id=task_name, obs_type=jaco.observations.PERFECT_FEATURES, seed=1)
64 | else:
65 | self._env = suite.load(
66 | domain_name=domain_name,
67 | task_name=task_name,
68 | task_kwargs=task_kwargs,
69 | visualize_reward=visualize_reward,
70 | environment_kwargs=environment_kwargs
71 | )
72 |
73 | # true and normalized action spaces
74 | self._true_action_space = _spec_to_box([self._env.action_spec()], np.float32)
75 | self._norm_action_space = spaces.Box(
76 | low=-1.0,
77 | high=1.0,
78 | shape=self._true_action_space.shape,
79 | dtype=np.float32
80 | )
81 |
82 | # create observation space
83 | if from_pixels:
84 | shape = [3, height, width] if channels_first else [height, width, 3]
85 | self._observation_space = spaces.Box(
86 | low=0, high=255, shape=shape, dtype=np.uint8
87 | )
88 | else:
89 | self._observation_space = _spec_to_box(
90 | self._env.observation_spec().values(),
91 | np.float64
92 | )
93 |
94 | self._state_space = _spec_to_box(
95 | self._env.observation_spec().values(),
96 | np.float64
97 | )
98 |
99 | self.current_state = None
100 |
101 | # set seed
102 | self.seed(seed=task_kwargs.get('random', 1))
103 |
104 | def __getattr__(self, name):
105 | return getattr(self._env, name)
106 |
107 | def _get_obs(self, time_step):
108 | if self._from_pixels:
109 | obs = self.render(
110 | height=self._height,
111 | width=self._width,
112 | camera_id=self._camera_id
113 | )
114 | if self._channels_first:
115 | obs = obs.transpose(2, 0, 1).copy()
116 | else:
117 | obs = _flatten_obs(time_step.observation)
118 |
119 | return obs
120 |
121 | def _convert_action(self, action):
122 | action = action.astype(np.float32)
123 | true_delta = self._true_action_space.high - self._true_action_space.low
124 | norm_delta = self._norm_action_space.high - self._norm_action_space.low
125 | action = (action - self._norm_action_space.low) / norm_delta
126 | action = action * true_delta + self._true_action_space.low
127 | action = action.astype(np.float32)
128 | return action
129 |
130 | @property
131 | def observation_space(self):
132 | return self._observation_space
133 |
134 | @property
135 | def state_space(self):
136 | return self._state_space
137 |
138 | @property
139 | def action_space(self):
140 | return self._norm_action_space
141 |
142 | @property
143 | def reward_range(self):
144 | return 0, self._frame_skip
145 |
146 | def seed(self, seed):
147 | self._true_action_space.seed(seed)
148 | self._norm_action_space.seed(seed)
149 | self._observation_space.seed(seed)
150 |
151 | def step(self, action):
152 | assert self._norm_action_space.contains(action)
153 | action = self._convert_action(action)
154 | assert self._true_action_space.contains(action)
155 | reward = 0
156 | extra = {'internal_state': self._env.physics.get_state().copy()}
157 |
158 | for _ in range(self._frame_skip):
159 | time_step = self._env.step(action)
160 | reward += time_step.reward or 0
161 | done = time_step.last()
162 | if done:
163 | break
164 | obs = self._get_obs(time_step)
165 | self.current_state = _flatten_obs(time_step.observation)
166 | extra['discount'] = time_step.discount
167 | return obs, reward, done, extra
168 |
169 | def reset(self):
170 | time_step = self._env.reset()
171 | self.current_state = _flatten_obs(time_step.observation)
172 | obs = self._get_obs(time_step)
173 | return obs
174 |
175 | def render(self, mode='rgb_array', height=None, width=None, camera_id=0):
176 | assert mode == 'rgb_array', 'only support rgb_array mode, given %s' % mode
177 | height = height or self._height
178 | width = width or self._width
179 | camera_id = camera_id or self._camera_id
180 | return self._env.physics.render(
181 | height=height, width=width, camera_id=camera_id
182 | )
183 |
--------------------------------------------------------------------------------
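The action handling in `_convert_action` is just an affine rescale from the normalized [-1, 1] box to the true [low, high] box; a standalone sketch of the same arithmetic:

import numpy as np

def rescale(action, low, high):
    # Mirror of DMCWrapper._convert_action: fraction of the [-1, 1] range,
    # then mapped onto [low, high].
    frac = (action - (-1.0)) / 2.0
    return low + frac * (high - low)

low, high = np.array([-2.0, 0.0]), np.array([2.0, 1.0])
print(rescale(np.array([-1.0, -1.0]), low, high))   # -> [-2.   0. ]
print(rescale(np.array([ 0.0,  0.0]), low, high))   # -> [ 0.   0.5]
print(rescale(np.array([ 1.0,  1.0]), low, high))   # -> [ 2.   1. ]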
/common/envs/env_helper.py:
--------------------------------------------------------------------------------
1 | ###############################
2 | #
3 | # Helper that initializes environments with the proper imports.
4 | # Returns an environment that is:
5 | # - Action normalized.
6 | # - Video rendering works.
7 | # - Episode monitor attached.
8 | #
9 | ###############################
10 |
11 | import matplotlib.pyplot as plt
12 | import numpy as np
13 | import os
14 | import os.path as osp
15 | import gym
16 | import numpy as np
17 | import functools as ft
18 | from fre.common.train_state import NormalizeActionWrapper
19 | from fre.common.envs.wrappers import EpisodeMonitor
20 |
21 |
22 | # Supported envs:
23 | env_list = [
24 | # From Gym
25 | 'HalfCheetah-v2',
26 | 'Hopper-v2',
27 | 'Walker2d-v2',
28 | 'Pendulum-v1',
29 | 'CartPole-v1',
30 | 'Acrobot-v1',
31 | 'MountainCar-v0',
32 | 'MountainCarContinuous-v0',
33 | # From DMC
34 | 'pendulum_swingup',
35 | 'acrobot_swingup',
36 | 'acrobot_swingup_sparse',
37 | 'cartpole_swingup', # has exorl dataset.
38 | 'cartpole_swingup_sparse',
39 | 'pointmass_easy',
40 | 'reacher_easy',
41 | 'reacher_hard',
42 | 'cheetah_run', # has exorl dataset.
43 | 'hopper_hop',
44 | 'walker_stand', # has exorl dataset.
45 | 'walker_walk', # has exorl dataset.
46 | 'walker_run', # has exorl dataset.
47 | 'quadruped_walk', # has exorl dataset.
48 | 'quadruped_run', # has exorl dataset.
49 | 'humanoid_stand',
50 | 'humanoid_run',
51 | 'jaco_reach_top_left', # has exorl dataset.
52 | 'jaco_reach_bottom_right', # has exorl dataset.
53 | # TODO: Atari games
54 | # Offline D4RL envs
55 | 'antmaze-large-diverse-v2', # Start in the corner, goal is in the top corner.
56 | 'gc-antmaze-large-diverse-v2', # Start in the corner, goal is in the top corner.
57 | 'center-antmaze-large-diverse-v2', # Start in the center, goal is UNDEFINED (this is for RARE rewards).
58 | 'maze2d-large-v1',
59 | 'gc-maze2d-large-v1',
60 | 'center-maze2d-large-v1',
61 | # D4RL mujoco
62 | 'halfcheetah-expert-v2',
63 | 'walker2d-expert-v2',
64 | 'hopper-expert-v2',
65 | 'kitchen-complete-v0', # broken
66 | 'kitchen-mixed-v0', # broken
67 | ]
68 |
69 | # Making an environment.
70 | def make_env(env_name, **kwargs):
71 | if 'exorl' in env_name:
72 | import os
73 | os.environ['DISPLAY'] = ':0'
74 | import fre.common.envs.exorl.dmc as dmc
75 | _, env_name, task_name = env_name.split('_', 2)
76 | def make_env(env_name, task_name):
77 | env = dmc.make(f'{env_name}_{task_name}', obs_type='states', frame_stack=1, action_repeat=1, seed=0)
78 | env = dmc.DMCWrapper(env, 0)
79 | return env
80 | env = make_env(env_name, task_name)
81 | env.reset()
82 | elif '_' in env_name: # DMC Control
83 | import fre.common.envs.dmc as dmc2gym
84 | import os
85 | os.environ['DISPLAY'] = ':0'
86 | suite, task = env_name.split('_', 1)
87 | print(suite, task)
88 | if suite == 'pointmass':
89 | suite = 'point_mass'
90 | frame_skip = kwargs['frame_skip'] if 'frame_skip' in kwargs else 2
91 | visualize_reward = kwargs['visualize_reward'] if 'visualize_reward' in kwargs else False
92 | env = dmc2gym.make(
93 | domain_name=suite,
94 | task_name=task, seed=1,
95 | frame_skip=frame_skip,
96 | visualize_reward=visualize_reward)
97 | env = NormalizeActionWrapper(env)
98 | elif 'antmaze' in env_name:
99 | from fre.common.envs.d4rl.d4rl_ant import CenteredMaze, GoalReachingMaze, MazeWrapper
100 | if 'gc-antmaze' in env_name:
101 | env = GoalReachingMaze('antmaze-large-diverse-v2')
102 | elif 'center-antmaze' in env_name:
103 | env = CenteredMaze('antmaze-large-diverse-v2')
104 | else:
105 | env = MazeWrapper('antmaze-large-diverse-v2')
106 | elif 'maze2d' in env_name:
107 | from fre.common.envs.d4rl.d4rl_ant import CenteredMaze, GoalReachingMaze, MazeWrapper
108 | if 'gc-maze2d' in env_name:
109 | env = GoalReachingMaze('maze2d-large-v1')
110 | elif 'center-maze2d' in env_name:
111 | env = CenteredMaze('maze2d-large-v1')
112 | else:
113 | env = CenteredMaze('maze2d-large-v1', start_loc='original')
114 | elif 'halfcheetah-' in env_name or 'walker2d-' in env_name or 'hopper-' in env_name: # D4RL Mujoco
115 | import d4rl
116 | import d4rl.gym_mujoco
117 | env = gym.make(env_name)
118 | elif 'kitchen' in env_name: # This doesn't work yet.
119 | import os
120 | os.environ['DISPLAY'] = ':0'
121 | from fre.common.envs.d4rl.d4rl_utils import KitchenRenderWrapper
122 | env = KitchenRenderWrapper(gym.make(env_name))
123 | elif 'bandit' in env_name:
124 | from fre.common.envs.bandit.bandit import BanditEnv
125 | env = BanditEnv()
126 | else:
127 | env = gym.make(env_name)
128 | env = EpisodeMonitor(env)
129 | return env
130 |
131 | # For getting offline data.
132 | def get_dataset(env, env_name, **kwargs):
133 | if 'exorl' in env_name:
134 | from fre.common.envs.exorl.exorl_utils import get_dataset
135 | env_name_short = env_name.split('_', 1)[1]
136 | return get_dataset(env, env_name_short, **kwargs)
137 | elif 'ant' in env_name or 'maze2d' in env_name or 'kitchen' in env_name or 'halfcheetah' in env_name or 'walker2d' in env_name or 'hopper' in env_name:
138 | from fre.common.envs.d4rl.d4rl_utils import get_dataset, normalize_dataset
139 | dataset = get_dataset(env, env_name, **kwargs)
140 | dataset = normalize_dataset(env_name, dataset)
141 | return dataset
142 | elif 'cartpole' in env_name or 'cheetah' in env_name or 'jaco' in env_name or 'quadruped' in env_name or 'walker' in env_name:
143 | from fre.common.envs.exorl.exorl_utils import get_dataset
144 | return get_dataset(env, env_name, **kwargs)
145 |
146 | def make_vec_env(env_name, num_envs, **kwargs):
147 | from gym.vector import SyncVectorEnv
148 | envs = [lambda : make_env(env_name, **kwargs) for _ in range(num_envs)]
149 | env = SyncVectorEnv(envs)
150 | return env
--------------------------------------------------------------------------------
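An end-to-end sketch of the helpers above: build a supported environment and load its offline dataset (assumes D4RL is installed and its dataset files are available locally).

from fre.common.envs.env_helper import make_env, get_dataset

env = make_env('halfcheetah-expert-v2')          # EpisodeMonitor-wrapped gym env
dataset = get_dataset(env, 'halfcheetah-expert-v2')
print(env.observation_space, env.action_space)
print(dataset['observations'].shape, dataset['rewards'].shape)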
/common/envs/exorl/custom_dmc_tasks/__init__.py:
--------------------------------------------------------------------------------
1 | import typing as tp
2 | from . import cheetah
3 | from . import walker
4 | from . import hopper
5 | from . import quadruped
6 | from . import jaco
7 |
8 |
9 | def make(domain, task,
10 | task_kwargs=None,
11 | environment_kwargs=None,
12 | visualize_reward: bool = False):
13 |
14 | if domain == 'cheetah':
15 | return cheetah.make(task,
16 | task_kwargs=task_kwargs,
17 | environment_kwargs=environment_kwargs,
18 | visualize_reward=visualize_reward)
19 | elif domain == 'walker':
20 | return walker.make(task,
21 | task_kwargs=task_kwargs,
22 | environment_kwargs=environment_kwargs,
23 | visualize_reward=visualize_reward)
24 | elif domain == 'hopper':
25 | return hopper.make(task,
26 | task_kwargs=task_kwargs,
27 | environment_kwargs=environment_kwargs,
28 | visualize_reward=visualize_reward)
29 | elif domain == 'quadruped':
30 | return quadruped.make(task,
31 | task_kwargs=task_kwargs,
32 | environment_kwargs=environment_kwargs,
33 | visualize_reward=visualize_reward)
34 |     elif domain == 'point_mass_maze':
35 |         # point_mass_maze is not bundled with these custom tasks; raise a clear
36 |         # error instead of a NameError on the missing module.
37 |         raise NotImplementedError(
38 |             'point_mass_maze tasks are not included in custom_dmc_tasks')
39 | 
40 |     else:
41 |         raise ValueError(f'{domain} not found')
42 | 
43 | 
44 | 
45 |
46 | def make_jaco(task, obs_type, seed) -> tp.Any:
47 | return jaco.make(task, obs_type, seed)
48 |
--------------------------------------------------------------------------------
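A brief sketch of calling the dispatcher above directly; the import path assumes the repository is installed as `fre`.

from fre.common.envs.exorl import custom_dmc_tasks

# Builds the backward-run cheetah task registered in cheetah.py below.
env = custom_dmc_tasks.make('cheetah', 'run_backward', task_kwargs={'random': 0})
time_step = env.reset()
time_step = env.step(env.action_spec().generate_value())
print(time_step.reward)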
/common/envs/exorl/custom_dmc_tasks/cheetah.py:
--------------------------------------------------------------------------------
1 | """Cheetah Domain."""
2 |
3 | import collections
4 | import os
5 | import typing as tp
6 | from typing import Any, Tuple
7 |
8 | from dm_control import mujoco
9 | from dm_control.rl import control
10 | from dm_control.suite import base
11 | from dm_control.suite import common
12 | from dm_control.utils import containers
13 | from dm_control.utils import rewards
14 | from dm_control.utils import io as resources
15 |
16 | _DEFAULT_TIME_LIMIT: int
17 | _RUN_SPEED: int
18 | _SPIN_SPEED: int
19 |
20 | # How long the simulation will run, in seconds.
21 | _DEFAULT_TIME_LIMIT = 10
22 |
23 | # Running speed above which reward is 1.
24 | _RUN_SPEED = 10
25 | _WALK_SPEED = 2
26 | _SPIN_SPEED = 5
27 |
28 | SUITE = containers.TaggedTasks()
29 |
30 |
31 | def make(task,
32 | task_kwargs=None,
33 | environment_kwargs=None,
34 | visualize_reward: bool = False):
35 | task_kwargs = task_kwargs or {}
36 | if environment_kwargs is not None:
37 | task_kwargs = task_kwargs.copy()
38 | task_kwargs['environment_kwargs'] = environment_kwargs
39 | env = SUITE[task](**task_kwargs)
40 | env.task.visualize_reward = visualize_reward
41 | return env
42 |
43 |
44 | def get_model_and_assets() -> Tuple[Any, Any]:
45 | """Returns a tuple containing the model XML string and a dict of assets."""
46 | root_dir = os.path.dirname(os.path.dirname(__file__))
47 | xml = resources.GetResource(
48 | os.path.join(root_dir, 'custom_dmc_tasks', 'cheetah.xml'))
49 | return xml, common.ASSETS
50 |
51 |
52 | @SUITE.add('benchmarking')
53 | def walk(time_limit: int = _DEFAULT_TIME_LIMIT,
54 | random=None,
55 | environment_kwargs=None):
56 |     """Returns the walk task."""
57 | physics = Physics.from_xml_string(*get_model_and_assets())
58 | task = Cheetah(move_speed=_WALK_SPEED, forward=True, flip=False, random=random)
59 | environment_kwargs = environment_kwargs or {}
60 | return control.Environment(physics,
61 | task,
62 | time_limit=time_limit,
63 | **environment_kwargs)
64 |
65 |
66 | @SUITE.add('benchmarking')
67 | def walk_backward(time_limit: int = _DEFAULT_TIME_LIMIT,
68 | random=None,
69 | environment_kwargs=None):
70 |     """Returns the backward-walk task."""
71 | physics = Physics.from_xml_string(*get_model_and_assets())
72 | task = Cheetah(move_speed=_WALK_SPEED, forward=False, flip=False, random=random)
73 | environment_kwargs = environment_kwargs or {}
74 | return control.Environment(physics,
75 | task,
76 | time_limit=time_limit,
77 | **environment_kwargs)
78 |
79 |
80 | @SUITE.add('benchmarking')
81 | def run_backward(time_limit: int = _DEFAULT_TIME_LIMIT,
82 | random=None,
83 | environment_kwargs=None):
84 |     """Returns the backward-run task."""
85 | physics = Physics.from_xml_string(*get_model_and_assets())
86 | task = Cheetah(forward=False, flip=False, random=random)
87 | environment_kwargs = environment_kwargs or {}
88 | return control.Environment(physics,
89 | task,
90 | time_limit=time_limit,
91 | **environment_kwargs)
92 |
93 |
94 | @SUITE.add('benchmarking')
95 | def flip(time_limit: int = _DEFAULT_TIME_LIMIT,
96 | random=None,
97 | environment_kwargs=None):
98 |     """Returns the flip task."""
99 | physics = Physics.from_xml_string(*get_model_and_assets())
100 | task = Cheetah(move_speed=_WALK_SPEED, forward=True, flip=True, random=random)
101 | environment_kwargs = environment_kwargs or {}
102 | return control.Environment(physics,
103 | task,
104 | time_limit=time_limit,
105 | **environment_kwargs)
106 |
107 |
108 | @SUITE.add('benchmarking')
109 | def flip_backward(time_limit: int = _DEFAULT_TIME_LIMIT,
110 | random=None,
111 | environment_kwargs=None):
112 |     """Returns the backward-flip task."""
113 | physics = Physics.from_xml_string(*get_model_and_assets())
114 | task = Cheetah(move_speed=_WALK_SPEED, forward=False, flip=True, random=random)
115 | environment_kwargs = environment_kwargs or {}
116 | return control.Environment(physics,
117 | task,
118 | time_limit=time_limit,
119 | **environment_kwargs)
120 |
121 |
122 | class Physics(mujoco.Physics):
123 | """Physics simulation with additional features for the Cheetah domain."""
124 |
125 | def speed(self) -> Any:
126 | """Returns the horizontal speed of the Cheetah."""
127 | return self.named.data.sensordata['torso_subtreelinvel'][0]
128 |
129 | def angmomentum(self) -> Any:
130 |         """Returns the angular momentum of the Cheetah's torso about the Y axis."""
131 | return self.named.data.subtree_angmom['torso'][1]
132 |
133 |
134 | class Cheetah(base.Task):
135 | """A `Task` to train a running Cheetah."""
136 |
137 | def __init__(self, move_speed=_RUN_SPEED, forward=True, flip=False, random=None) -> None:
138 | self._move_speed = move_speed
139 | self._forward = 1 if forward else -1
140 | self._flip = flip
141 | super(Cheetah, self).__init__(random=random)
142 | self._timeout_progress = 0
143 |
144 | def initialize_episode(self, physics) -> None:
145 | """Sets the state of the environment at the start of each episode."""
146 | # The indexing below assumes that all joints have a single DOF.
147 | assert physics.model.nq == physics.model.njnt
148 | is_limited = physics.model.jnt_limited == 1
149 | lower, upper = physics.model.jnt_range[is_limited].T
150 | physics.data.qpos[is_limited] = self.random.uniform(lower, upper)
151 |
152 | # Stabilize the model before the actual simulation.
153 | for _ in range(200):
154 | physics.step()
155 |
156 | physics.data.time = 0
157 | self._timeout_progress = 0
158 | super().initialize_episode(physics)
159 |
160 | def get_observation(self, physics) -> tp.Dict[str, Any]:
161 | """Returns an observation of the state, ignoring horizontal position."""
162 | obs = collections.OrderedDict()
163 | # Ignores horizontal position to maintain translational invariance.
164 | obs['position'] = physics.data.qpos[1:].copy()
165 | obs['velocity'] = physics.velocity()
166 | return obs
167 |
168 | def get_reward(self, physics) -> Any:
169 | """Returns a reward to the agent."""
170 | if self._flip:
171 | reward = rewards.tolerance(self._forward * physics.angmomentum(),
172 | bounds=(_SPIN_SPEED, float('inf')),
173 | margin=_SPIN_SPEED,
174 | value_at_margin=0,
175 | sigmoid='linear')
176 |
177 | else:
178 | reward = rewards.tolerance(self._forward * physics.speed(),
179 | bounds=(self._move_speed, float('inf')),
180 | margin=self._move_speed,
181 | value_at_margin=0,
182 | sigmoid='linear')
183 | return reward
184 |
--------------------------------------------------------------------------------
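All of the Cheetah rewards are shaped with `dm_control.utils.rewards.tolerance`; a quick numeric illustration of the linear shaping used for forward running (reward 1 at or above `_RUN_SPEED`, decaying linearly to 0 at zero speed):

from dm_control.utils import rewards

_RUN_SPEED = 10
for speed in [0.0, 2.5, 5.0, 10.0, 15.0]:
    r = rewards.tolerance(speed,
                          bounds=(_RUN_SPEED, float('inf')),
                          margin=_RUN_SPEED,
                          value_at_margin=0,
                          sigmoid='linear')
    print(speed, float(r))   # 0.0, 0.25, 0.5, 1.0, 1.0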
/common/envs/exorl/custom_dmc_tasks/cheetah.xml:
--------------------------------------------------------------------------------
(cheetah.xml: MuJoCo XML model for the cheetah domain; the markup was stripped in this text export.)
--------------------------------------------------------------------------------
/common/envs/exorl/custom_dmc_tasks/hopper.py:
--------------------------------------------------------------------------------
1 | """Hopper domain."""
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import collections
8 | import os
9 | import typing as tp
10 | from typing import Any, Tuple
11 |
12 | import numpy as np
13 | from dm_control import mujoco
14 | from dm_control.rl import control
15 | from dm_control.suite import base
16 | from dm_control.suite import common
17 | from dm_control.suite.utils import randomizers
18 | from dm_control.utils import containers
19 | from dm_control.utils import rewards
20 | from dm_control.utils import io as resources
21 |
22 | _CONTROL_TIMESTEP: float
23 | _DEFAULT_TIME_LIMIT: int
24 | _HOP_SPEED: int
25 | _SPIN_SPEED: int
26 | _STAND_HEIGHT: float
27 |
28 | SUITE = containers.TaggedTasks()
29 |
30 | _CONTROL_TIMESTEP = .02 # (Seconds)
31 |
32 | # Default duration of an episode, in seconds.
33 | _DEFAULT_TIME_LIMIT = 20
34 |
35 | # Minimal height of torso over foot above which stand reward is 1.
36 | _STAND_HEIGHT = 0.6
37 |
38 | # Hopping speed above which hop reward is 1.
39 | _HOP_SPEED = 2
40 | _SPIN_SPEED = 5
41 |
42 |
43 | def make(task,
44 | task_kwargs=None,
45 | environment_kwargs=None,
46 | visualize_reward: bool = False):
47 | task_kwargs = task_kwargs or {}
48 | if environment_kwargs is not None:
49 | task_kwargs = task_kwargs.copy()
50 | task_kwargs['environment_kwargs'] = environment_kwargs
51 | env = SUITE[task](**task_kwargs)
52 | env.task.visualize_reward = visualize_reward
53 | return env
54 |
55 |
56 | def get_model_and_assets() -> Tuple[Any, Any]:
57 | """Returns a tuple containing the model XML string and a dict of assets."""
58 | root_dir = os.path.dirname(os.path.dirname(__file__))
59 | xml = resources.GetResource(
60 | os.path.join(root_dir, 'custom_dmc_tasks', 'hopper.xml'))
61 | return xml, common.ASSETS
62 |
63 |
64 | @SUITE.add('benchmarking')
65 | def hop_backward(time_limit: int = _DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
66 |     """Returns a Hopper that strives to hop backward."""
67 | physics = Physics.from_xml_string(*get_model_and_assets())
68 | task = Hopper(hopping=True, forward=False, flip=False, random=random)
69 | environment_kwargs = environment_kwargs or {}
70 | return control.Environment(physics,
71 | task,
72 | time_limit=time_limit,
73 | control_timestep=_CONTROL_TIMESTEP,
74 | **environment_kwargs)
75 |
76 |
77 | @SUITE.add('benchmarking')
78 | def flip(time_limit: int = _DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
79 |     """Returns a Hopper that strives to flip forward."""
80 | physics = Physics.from_xml_string(*get_model_and_assets())
81 | task = Hopper(hopping=True, forward=True, flip=True, random=random)
82 | environment_kwargs = environment_kwargs or {}
83 | return control.Environment(physics,
84 | task,
85 | time_limit=time_limit,
86 | control_timestep=_CONTROL_TIMESTEP,
87 | **environment_kwargs)
88 |
89 |
90 | @SUITE.add('benchmarking')
91 | def flip_backward(time_limit: int = _DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
92 |     """Returns a Hopper that strives to flip backward."""
93 | physics = Physics.from_xml_string(*get_model_and_assets())
94 | task = Hopper(hopping=True, forward=False, flip=True, random=random)
95 | environment_kwargs = environment_kwargs or {}
96 | return control.Environment(physics,
97 | task,
98 | time_limit=time_limit,
99 | control_timestep=_CONTROL_TIMESTEP,
100 | **environment_kwargs)
101 |
102 |
103 | class Physics(mujoco.Physics):
104 | """Physics simulation with additional features for the Hopper domain."""
105 |
106 | def height(self) -> Any:
107 | """Returns height of torso with respect to foot."""
108 | return (self.named.data.xipos['torso', 'z'] -
109 | self.named.data.xipos['foot', 'z'])
110 |
111 | def speed(self) -> Any:
112 | """Returns horizontal speed of the Hopper."""
113 | return self.named.data.sensordata['torso_subtreelinvel'][0]
114 |
115 | def touch(self) -> Any:
116 | """Returns the signals from two foot touch sensors."""
117 | return np.log1p(self.named.data.sensordata[['touch_toe',
118 | 'touch_heel']])
119 |
120 | def angmomentum(self) -> Any:
121 |         """Returns the angular momentum of the Hopper's torso about the Y axis."""
122 | return self.named.data.subtree_angmom['torso'][1]
123 |
124 |
125 | class Hopper(base.Task):
126 | """A Hopper's `Task` to train a standing and a jumping Hopper."""
127 |
128 | def __init__(self, hopping, forward=True, flip=False, random=None) -> None:
129 | """Initialize an instance of `Hopper`.
130 |
131 | Args:
132 | hopping: Boolean, if True the task is to hop forwards, otherwise it is to
133 | balance upright.
134 | random: Optional, either a `numpy.random.RandomState` instance, an
135 | integer seed for creating a new `RandomState`, or None to select a seed
136 | automatically (default).
137 | """
138 | self._hopping = hopping
139 | self._forward = 1 if forward else -1
140 | self._flip = flip
141 | self._timeout_progress = 0
142 | super(Hopper, self).__init__(random=random)
143 |
144 | def initialize_episode(self, physics) -> None:
145 | """Sets the state of the environment at the start of each episode."""
146 | randomizers.randomize_limited_and_rotational_joints(
147 | physics, self.random)
148 | self._timeout_progress = 0
149 | super(Hopper, self).initialize_episode(physics)
150 |
151 | def get_observation(self, physics) -> tp.Dict[str, Any]:
152 | """Returns an observation of positions, velocities and touch sensors."""
153 | obs = collections.OrderedDict()
154 | # Ignores horizontal position to maintain translational invariance:
155 | obs['position'] = physics.data.qpos[1:].copy()
156 | obs['velocity'] = physics.velocity()
157 | obs['touch'] = physics.touch()
158 | return obs
159 |
160 | def get_reward(self, physics) -> Any:
161 | """Returns a reward applicable to the performed task."""
162 | standing = rewards.tolerance(physics.height(), (_STAND_HEIGHT, 2))
163 | assert self._hopping
164 | if self._flip:
165 | hopping = rewards.tolerance(self._forward * physics.angmomentum(),
166 | bounds=(_SPIN_SPEED, float('inf')),
167 | margin=_SPIN_SPEED,
168 | value_at_margin=0,
169 | sigmoid='linear')
170 | else:
171 | hopping = rewards.tolerance(self._forward * physics.speed(),
172 | bounds=(_HOP_SPEED, float('inf')),
173 | margin=_HOP_SPEED / 2,
174 | value_at_margin=0.5,
175 | sigmoid='linear')
176 | return standing * hopping
177 |
--------------------------------------------------------------------------------
/common/envs/exorl/custom_dmc_tasks/hopper.xml:
--------------------------------------------------------------------------------
(hopper.xml: MuJoCo XML model for the hopper domain; the markup was stripped in this text export.)
--------------------------------------------------------------------------------
/common/envs/exorl/custom_dmc_tasks/jaco.py:
--------------------------------------------------------------------------------
1 | """A task where the goal is to move the hand close to a target prop or site."""
2 |
3 | import collections
4 |
5 | from dm_control import composer
6 | from dm_control.composer import initializers
7 | from dm_control.composer.variation import distributions
8 | from dm_control.entities import props
9 | from dm_control.manipulation.shared import arenas
10 | from dm_control.manipulation.shared import cameras
11 | from dm_control.manipulation.shared import constants
12 | from dm_control.manipulation.shared import observations
13 | from dm_control.manipulation.shared import robots
14 | from dm_control.manipulation.shared import workspaces
15 | from dm_control.utils import rewards
16 | from dm_env import specs
17 | import numpy as np
18 |
19 | _ReachWorkspace = collections.namedtuple(
20 | '_ReachWorkspace', ['target_bbox', 'tcp_bbox', 'arm_offset'])
21 |
22 | # Ensures that the props are not touching the table before settling.
23 | _PROP_Z_OFFSET = 0.001
24 |
25 | _DUPLO_WORKSPACE = _ReachWorkspace(
26 | target_bbox=workspaces.BoundingBox(lower=(-0.1, -0.1, _PROP_Z_OFFSET),
27 | upper=(0.1, 0.1, _PROP_Z_OFFSET)),
28 | tcp_bbox=workspaces.BoundingBox(lower=(-0.1, -0.1, 0.2),
29 | upper=(0.1, 0.1, 0.4)),
30 | arm_offset=robots.ARM_OFFSET)
31 |
32 | _SITE_WORKSPACE = _ReachWorkspace(
33 | target_bbox=workspaces.BoundingBox(lower=(-0.2, -0.2, 0.02),
34 | upper=(0.2, 0.2, 0.4)),
35 | tcp_bbox=workspaces.BoundingBox(lower=(-0.2, -0.2, 0.02),
36 | upper=(0.2, 0.2, 0.4)),
37 | arm_offset=robots.ARM_OFFSET)
38 |
39 | _TARGET_RADIUS = 0.05
40 | _TIME_LIMIT = 10.
41 |
42 | TASKS = [('reach_top_left', np.array([-0.09, 0.09, _PROP_Z_OFFSET])),
43 | ('reach_top_right', np.array([0.09, 0.09, _PROP_Z_OFFSET])),
44 | ('reach_bottom_left', np.array([-0.09, -0.09, _PROP_Z_OFFSET])),
45 | ('reach_bottom_right', np.array([0.09, -0.09, _PROP_Z_OFFSET]))]
46 |
47 |
48 | def make(task_id, obs_type, seed):
49 | obs_settings = observations.VISION if obs_type == 'pixels' else observations.PERFECT_FEATURES
50 | task = _reach(task_id, obs_settings=obs_settings, use_site=True)
51 | return composer.Environment(task,
52 | time_limit=_TIME_LIMIT,
53 | random_state=seed)
54 |
55 |
56 | class MultiTaskReach(composer.Task):
57 | """Bring the hand close to a target prop or site."""
58 |
59 | def __init__(self, task_id, arena, arm, hand, prop, obs_settings,
60 | workspace, control_timestep):
61 | """Initializes a new `Reach` task.
62 |
63 | Args:
64 | arena: `composer.Entity` instance.
65 | arm: `robot_base.RobotArm` instance.
66 | hand: `robot_base.RobotHand` instance.
67 | prop: `composer.Entity` instance specifying the prop to reach to, or None
68 | in which case the target is a fixed site whose position is specified by
69 | the workspace.
70 | obs_settings: `observations.ObservationSettings` instance.
71 | workspace: `_ReachWorkspace` specifying the placement of the prop and TCP.
72 | control_timestep: Float specifying the control timestep in seconds.
73 | """
74 | self._arena = arena
75 | self._arm = arm
76 | self._hand = hand
77 | self._arm.attach(self._hand)
78 | self._arena.attach_offset(self._arm, offset=workspace.arm_offset)
79 | self.control_timestep = control_timestep
80 | self._tcp_initializer = initializers.ToolCenterPointInitializer(
81 | self._hand,
82 | self._arm,
83 | position=distributions.Uniform(*workspace.tcp_bbox),
84 | quaternion=workspaces.DOWN_QUATERNION)
85 |
86 | # Add custom camera observable.
87 | self._task_observables = cameras.add_camera_observables(
88 | arena, obs_settings, cameras.FRONT_CLOSE)
89 |
90 | if task_id == 'reach_multitask':
91 | self._targets = [target for (_, target) in TASKS]
92 | else:
93 | self._targets = [
94 | target for (task, target) in TASKS if task == task_id
95 | ]
96 | assert len(self._targets) > 0
97 |
98 | #target_pos_distribution = distributions.Uniform(*TASKS[task_id])
99 | self._prop = prop
100 | if prop:
101 | # The prop itself is used to visualize the target location.
102 | self._make_target_site(parent_entity=prop, visible=False)
103 | self._target = self._arena.add_free_entity(prop)
104 | self._prop_placer = initializers.PropPlacer(
105 | props=[prop],
106 | position=target_pos_distribution,
107 | quaternion=workspaces.uniform_z_rotation,
108 | settle_physics=True)
109 | else:
110 | if len(self._targets) == 1:
111 | self._target = self._make_target_site(parent_entity=arena,
112 | visible=True)
113 |
114 | #obs = observable.MJCFFeature('pos', self._target)
115 | # obs.configure(**obs_settings.prop_pose._asdict())
116 | #self._task_observables['target_position'] = obs
117 |
118 | # Add sites for visualizing the prop and target bounding boxes.
119 | workspaces.add_bbox_site(body=self.root_entity.mjcf_model.worldbody,
120 | lower=workspace.tcp_bbox.lower,
121 | upper=workspace.tcp_bbox.upper,
122 | rgba=constants.GREEN,
123 | name='tcp_spawn_area')
124 | workspaces.add_bbox_site(body=self.root_entity.mjcf_model.worldbody,
125 | lower=workspace.target_bbox.lower,
126 | upper=workspace.target_bbox.upper,
127 | rgba=constants.BLUE,
128 | name='target_spawn_area')
129 |
130 | def _make_target_site(self, parent_entity, visible):
131 | return workspaces.add_target_site(
132 | body=parent_entity.mjcf_model.worldbody,
133 | radius=_TARGET_RADIUS,
134 | visible=visible,
135 | rgba=constants.RED,
136 | name='target_site')
137 |
138 | @property
139 | def root_entity(self):
140 | return self._arena
141 |
142 | @property
143 | def arm(self):
144 | return self._arm
145 |
146 | @property
147 | def hand(self):
148 | return self._hand
149 |
150 | def get_reward_spec(self):
151 | n = len(self._targets)
152 | return specs.Array(shape=(n,), dtype=np.float32, name='reward')
153 |
154 | @property
155 | def task_observables(self):
156 | return self._task_observables
157 |
158 | def get_reward(self, physics):
159 | hand_pos = physics.bind(self._hand.tool_center_point).xpos
160 | rews = []
161 | for target_pos in self._targets:
162 | distance = np.linalg.norm(hand_pos - target_pos)
163 | reward = rewards.tolerance(distance,
164 | bounds=(0, _TARGET_RADIUS),
165 | margin=_TARGET_RADIUS)
166 | rews.append(reward)
167 | rews = np.array(rews).astype(np.float32)
168 | if len(self._targets) == 1:
169 | return rews[0]
170 | return rews
171 |
172 | def initialize_episode(self, physics, random_state):
173 | self._hand.set_grasp(physics, close_factors=random_state.uniform())
174 | self._tcp_initializer(physics, random_state)
175 | if self._prop:
176 | self._prop_placer(physics, random_state)
177 | else:
178 | if len(self._targets) == 1:
179 | physics.bind(self._target).pos = self._targets[0]
180 |
181 |
182 | def _reach(task_id, obs_settings, use_site):
183 | """Configure and instantiate a `Reach` task.
184 |
185 | Args:
186 | obs_settings: An `observations.ObservationSettings` instance.
187 | use_site: Boolean, if True then the target will be a fixed site, otherwise
188 | it will be a moveable Duplo brick.
189 |
190 | Returns:
191 | An instance of `reach.Reach`.
192 | """
193 | arena = arenas.Standard()
194 | arm = robots.make_arm(obs_settings=obs_settings)
195 | hand = robots.make_hand(obs_settings=obs_settings)
196 | if use_site:
197 | workspace = _SITE_WORKSPACE
198 | prop = None
199 | else:
200 | workspace = _DUPLO_WORKSPACE
201 | prop = props.Duplo(observable_options=observations.make_options(
202 | obs_settings, observations.FREEPROP_OBSERVABLES))
203 | task = MultiTaskReach(task_id,
204 | arena=arena,
205 | arm=arm,
206 | hand=hand,
207 | prop=prop,
208 | obs_settings=obs_settings,
209 | workspace=workspace,
210 | control_timestep=constants.CONTROL_TIMESTEP)
211 | return task
212 |
--------------------------------------------------------------------------------
/common/envs/exorl/custom_dmc_tasks/quadruped.xml:
--------------------------------------------------------------------------------
(quadruped.xml: MuJoCo XML model for the quadruped domain; the markup was stripped in this text export.)
--------------------------------------------------------------------------------
/common/envs/exorl/custom_dmc_tasks/walker.py:
--------------------------------------------------------------------------------
1 | """Planar Walker Domain."""
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import collections
8 | from typing import Any, Tuple
9 | import typing as tp
10 | import os
11 |
12 | from dm_control import mujoco
13 | from dm_control.rl import control
14 | from dm_control.suite import base
15 | from dm_control.suite import common
16 | from dm_control.suite.utils import randomizers
17 | from dm_control.utils import containers
18 | from dm_control.utils import rewards
19 | from dm_control.utils import io as resources
20 |
21 | _CONTROL_TIMESTEP: float
22 | _DEFAULT_TIME_LIMIT: int
23 | _RUN_SPEED: int
24 | _SPIN_SPEED: int
25 | _STAND_HEIGHT: float
26 | _WALK_SPEED: int
27 | # from dm_control import suite # TODO useless?
28 |
29 | _DEFAULT_TIME_LIMIT = 25
30 | _CONTROL_TIMESTEP = .025
31 |
32 | # Minimal height of torso over foot above which stand reward is 1.
33 | _STAND_HEIGHT = 1.2
34 |
35 | # Horizontal speeds (meters/second) above which move reward is 1.
36 | _WALK_SPEED = 1
37 | _RUN_SPEED = 8
38 | _SPIN_SPEED = 5
39 |
40 | SUITE = containers.TaggedTasks()
41 |
42 |
43 | def make(task,
44 | task_kwargs=None,
45 | environment_kwargs=None,
46 | visualize_reward: bool = False):
47 | task_kwargs = task_kwargs or {}
48 | if environment_kwargs is not None:
49 | task_kwargs = task_kwargs.copy()
50 | task_kwargs['environment_kwargs'] = environment_kwargs
51 | env = SUITE[task](**task_kwargs)
52 | env.task.visualize_reward = visualize_reward
53 | return env
54 |
55 |
56 | def get_model_and_assets() -> Tuple[Any, Any]:
57 | """Returns a tuple containing the model XML string and a dict of assets."""
58 | root_dir = os.path.dirname(os.path.dirname(__file__))
59 | xml = resources.GetResource(os.path.join(root_dir, 'custom_dmc_tasks',
60 | 'walker.xml'))
61 | return xml, common.ASSETS
62 |
63 |
64 | @SUITE.add('benchmarking')
65 | def flip(time_limit: int = _DEFAULT_TIME_LIMIT,
66 | random=None,
67 | environment_kwargs=None):
68 |     """Returns the Flip task."""
69 | physics = Physics.from_xml_string(*get_model_and_assets())
70 | task = PlanarWalker(move_speed=_RUN_SPEED,
71 | forward=True,
72 | flip=True,
73 | random=random)
74 | environment_kwargs = environment_kwargs or {}
75 | return control.Environment(physics,
76 | task,
77 | time_limit=time_limit,
78 | control_timestep=_CONTROL_TIMESTEP,
79 | **environment_kwargs)
80 |
81 |
82 | class Physics(mujoco.Physics):
83 | """Physics simulation with additional features for the Walker domain."""
84 |
85 | def torso_upright(self) -> Any:
86 | """Returns projection from z-axes of torso to the z-axes of world."""
87 | return self.named.data.xmat['torso', 'zz']
88 |
89 | def torso_height(self) -> Any:
90 | """Returns the height of the torso."""
91 | return self.named.data.xpos['torso', 'z']
92 |
93 | def horizontal_velocity(self) -> Any:
94 | """Returns the horizontal velocity of the center-of-mass."""
95 | return self.named.data.sensordata['torso_subtreelinvel'][0]
96 |
97 | def orientations(self) -> Any:
98 | """Returns planar orientations of all bodies."""
99 | return self.named.data.xmat[1:, ['xx', 'xz']].ravel()
100 |
101 | def angmomentum(self) -> Any:
102 |         """Returns the angular momentum of the Walker's torso about the Y axis."""
103 | return self.named.data.subtree_angmom['torso'][1]
104 |
105 |
106 | class PlanarWalker(base.Task):
107 | """A planar walker task."""
108 |
109 | def __init__(self, move_speed, forward=True, flip=False, random=None) -> None:
110 | """Initializes an instance of `PlanarWalker`.
111 |
112 | Args:
113 | move_speed: A float. If this value is zero, reward is given simply for
114 | standing up. Otherwise this specifies a target horizontal velocity for
115 | the walking task.
116 | random: Optional, either a `numpy.random.RandomState` instance, an
117 | integer seed for creating a new `RandomState`, or None to select a seed
118 | automatically (default).
119 | """
120 | self._move_speed = move_speed
121 | self._forward = 1 if forward else -1
122 | self._flip = flip
123 | super(PlanarWalker, self).__init__(random=random)
124 |
125 | def initialize_episode(self, physics) -> None:
126 | """Sets the state of the environment at the start of each episode.
127 |
128 | In 'standing' mode, use initial orientation and small velocities.
129 | In 'random' mode, randomize joint angles and let fall to the floor.
130 |
131 | Args:
132 | physics: An instance of `Physics`.
133 |
134 | """
135 | randomizers.randomize_limited_and_rotational_joints(
136 | physics, self.random)
137 | super(PlanarWalker, self).initialize_episode(physics)
138 |
139 | def get_observation(self, physics) -> tp.Dict[str, Any]:
140 | """Returns an observation of body orientations, height and velocites."""
141 | obs = collections.OrderedDict()
142 | obs['orientations'] = physics.orientations()
143 | obs['height'] = physics.torso_height()
144 | obs['velocity'] = physics.velocity()
145 | return obs
146 |
147 | def get_reward(self, physics) -> Any:
148 | """Returns a reward to the agent."""
149 | standing = rewards.tolerance(physics.torso_height(),
150 | bounds=(_STAND_HEIGHT, float('inf')),
151 | margin=_STAND_HEIGHT / 2)
152 | upright = (1 + physics.torso_upright()) / 2
153 | stand_reward = (3 * standing + upright) / 4
154 |
155 | if self._flip:
156 | move_reward = rewards.tolerance(self._forward *
157 | physics.angmomentum(),
158 | bounds=(_SPIN_SPEED, float('inf')),
159 | margin=_SPIN_SPEED,
160 | value_at_margin=0,
161 | sigmoid='linear')
162 | else:
163 | move_reward = rewards.tolerance(
164 | self._forward * physics.horizontal_velocity(),
165 | bounds=(self._move_speed, float('inf')),
166 | margin=self._move_speed / 2,
167 | value_at_margin=0.5,
168 | sigmoid='linear')
169 |
170 | return stand_reward * (5 * move_reward + 1) / 6
171 |
--------------------------------------------------------------------------------
/common/envs/exorl/custom_dmc_tasks/walker.xml:
--------------------------------------------------------------------------------
(walker.xml: MuJoCo XML model for the planar walker domain; the markup was stripped in this text export.)
--------------------------------------------------------------------------------
/common/envs/exorl/exorl_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | import glob
4 | import tqdm
5 | import numpy as np
6 | from collections import defaultdict
7 |
8 | # from fre.common.envs.d4rl import d4rl_utils
9 | from fre.common.dataset import Dataset
10 |
11 | # get path relative to 'fre' package
12 | data_path = os.path.dirname(os.path.abspath(__file__))
13 | data_path = Path(data_path).parents[2]
14 | data_path = os.path.join(data_path, 'data/exorl/')
15 | print("Path to exorl data is", data_path)
16 |
17 | def get_dataset(env, env_name, method='rnd', dmc_dataset_size=10000000, use_task_reward=True):
18 |
19 | # dmc_dataset_size /= 10
20 | # print("WARNING: Only using 10 percent of exorl data.")
21 |
22 | domain_name, task_name = env_name.split('_', 1)
23 |
24 | path = os.path.join(data_path, domain_name, method)
25 | if not os.path.exists(path):
26 | print("Downloading exorl data.")
27 | os.makedirs(path)
28 | url = "https://dl.fbaipublicfiles.com/exorl/" + domain_name + "/" + method + ".zip"
29 | print("Downloading from", url)
30 | os.system("wget " + url + " -P " + path)
31 | os.system("unzip " + path + "/" + method + ".zip -d " + path)
32 |
33 | # process data into Dataset object.
34 | path = os.path.join(data_path, domain_name, method, 'buffer')
35 | npzs = sorted(glob.glob(f'{path}/*.npz'))
36 | dataset_npy = os.path.join(data_path, domain_name, method, task_name + '.npy')
37 | if os.path.exists(dataset_npy):
38 | dataset = np.load(dataset_npy, allow_pickle=True).item()
39 | else:
40 | print("Calculating exorl rewards.")
41 | dataset = defaultdict(list)
42 | num_steps = 0
43 | for i, npz in tqdm.tqdm(enumerate(npzs)):
44 | traj_data = dict(np.load(npz))
45 | dataset['observations'].append(traj_data['observation'][:-1, :])
46 | dataset['next_observations'].append(traj_data['observation'][1:, :])
47 | dataset['actions'].append(traj_data['action'][1:, :])
48 | dataset['physics'].append(traj_data['physics'][1:, :]) # Note that this corresponds to next_observations (i.e., r(s, a, s') = r(s') -- following the original DMC rewards)
49 |
50 | if use_task_reward:
51 | # TODO: make this faster and sanity check it
52 | rewards = []
53 | reward_spec = env.reward_spec()
54 | states = traj_data['physics']
55 | for j in range(states.shape[0]):
56 | with env.physics.reset_context():
57 | env.physics.set_state(states[j])
58 | reward = env.task.get_reward(env.physics)
59 | reward = np.full(reward_spec.shape, reward, reward_spec.dtype)
60 | rewards.append(reward)
61 | traj_data['reward'] = np.array(rewards, dtype=reward_spec.dtype)
62 | dataset['rewards'].append(traj_data['reward'][1:])
63 | else:
64 | dataset['rewards'].append(traj_data['reward'][1:, 0])
65 |
66 | terminals = np.full((len(traj_data['observation']) - 1,), False)
67 | dataset['terminals'].append(terminals)
68 | num_steps += len(traj_data['observation']) - 1
69 | if num_steps >= dmc_dataset_size:
70 | break
71 | print("Loaded {} steps".format(num_steps))
72 | for k, v in dataset.items():
73 | dataset[k] = np.concatenate(v, axis=0)
74 | np.save(dataset_npy, dataset)
75 |
76 |
77 |
78 | # Processing
79 | masks = 1.0 - dataset['terminals']
80 | dones_float = dataset['terminals']
81 |
82 | return Dataset.create(
83 | observations=dataset['observations'],
84 | actions=dataset['actions'],
85 | rewards=dataset['rewards'],
86 | masks=masks,
87 | dones_float=dones_float,
88 | next_observations=dataset['next_observations'],
89 | )
--------------------------------------------------------------------------------
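
A minimal sketch of loading an ExORL buffer with get_dataset above. It assumes the walker 'rnd' data is present under data/exorl/ (or downloadable), and uses dm_control's standard suite.load as a stand-in for however the repo's scripts construct the environment; the env is only needed to recompute task rewards from the stored physics states.

    from dm_control import suite
    from fre.common.envs.exorl.exorl_utils import get_dataset

    env = suite.load(domain_name="walker", task_name="walk")
    dataset = get_dataset(env, "walker_walk", method="rnd", dmc_dataset_size=100_000)
    print(dataset["observations"].shape, dataset["rewards"].shape)
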
/common/envs/gc_utils.py:
--------------------------------------------------------------------------------
1 | from fre.common.dataset import Dataset
2 | from flax.core.frozen_dict import FrozenDict
3 | from flax.core import freeze
4 | import dataclasses
5 | import numpy as np
6 | import jax
7 | import ml_collections
8 |
9 | @dataclasses.dataclass
10 | class GCDataset:
11 | dataset: Dataset
12 | p_randomgoal: float
13 | p_trajgoal: float
14 | p_currgoal: float
15 | geom_sample: int
16 | discount: float
17 | terminal_key: str = 'dones_float'
18 | reward_scale: float = 1.0
19 | reward_shift: float = -1.0
20 | mask_terminal: int = 1
21 |
22 | @staticmethod
23 | def get_default_config():
24 | return ml_collections.ConfigDict({
25 | 'p_randomgoal': 0.3,
26 | 'p_trajgoal': 0.5,
27 | 'p_currgoal': 0.2,
28 | 'geom_sample': 1,
29 | 'discount': 0.99,
30 | 'reward_scale': 1.0,
31 | 'reward_shift': -1.0,
32 | 'mask_terminal': 1,
33 | })
34 |
35 | def __post_init__(self):
36 | self.terminal_locs, = np.nonzero(self.dataset[self.terminal_key] > 0)
37 | assert np.isclose(self.p_randomgoal + self.p_trajgoal + self.p_currgoal, 1.0)
38 |
39 | def sample_goals(self, indx, p_randomgoal=None, p_trajgoal=None, p_currgoal=None):
40 | if p_randomgoal is None:
41 | p_randomgoal = self.p_randomgoal
42 | if p_trajgoal is None:
43 | p_trajgoal = self.p_trajgoal
44 | if p_currgoal is None:
45 | p_currgoal = self.p_currgoal
46 |
47 | batch_size = len(indx)
48 | # Random goals
49 | goal_indx = np.random.randint(self.dataset.size, size=batch_size)
50 |
51 | # Goals from the same trajectory
52 | final_state_indx = self.terminal_locs[np.searchsorted(self.terminal_locs, indx)]
53 |
54 | distance = np.random.rand(batch_size)
55 | if self.geom_sample:
56 | us = np.random.rand(batch_size)
57 | middle_goal_indx = np.minimum(indx + np.ceil(np.log(1 - us) / np.log(self.discount)).astype(int), final_state_indx)
58 | else:
59 | middle_goal_indx = np.round((np.minimum(indx + 1, final_state_indx) * distance + final_state_indx * (1 - distance))).astype(int)
60 |
61 | goal_indx = np.where(np.random.rand(batch_size) < p_trajgoal / (1.0 - p_currgoal), middle_goal_indx, goal_indx)
62 |
63 | # Goals at the current state
64 | goal_indx = np.where(np.random.rand(batch_size) < p_currgoal, indx, goal_indx)
65 |
66 | return goal_indx
67 |
68 | def sample(self, batch_size: int, indx=None):
69 | if indx is None:
70 | indx = np.random.randint(self.dataset.size-1, size=batch_size)
71 |
72 | batch = self.dataset.sample(batch_size, indx)
73 | goal_indx = self.sample_goals(indx)
74 |
75 | success = (indx == goal_indx)
76 | batch['rewards'] = success.astype(float) * self.reward_scale + self.reward_shift
77 | batch['goals'] = jax.tree_map(lambda arr: arr[goal_indx], self.dataset['observations'])
78 |
79 | if self.mask_terminal:
80 | batch['masks'] = 1.0 - success.astype(float)
81 | else:
82 | batch['masks'] = np.ones(batch_size)
83 |
84 | return batch
85 |
86 | def sample_traj_random(self, batch_size, num_traj_states, num_random_states, num_random_states_decode):
87 | indx = np.random.randint(self.dataset.size-1, size=batch_size)
88 | batch = self.dataset.sample(batch_size, indx)
89 | indx_expand = np.repeat(indx, num_traj_states-1) # (batch_size * num_traj_states)
90 | traj_indx = self.sample_goals(indx_expand, p_randomgoal=0.0, p_trajgoal=1.0, p_currgoal=0.0)
91 | traj_indx = traj_indx.reshape(batch_size, num_traj_states-1) # (batch_size, num_traj_states)
92 | batch['traj_states'] = jax.tree_map(lambda arr: arr[traj_indx], self.dataset['observations'])
93 | batch['traj_states'] = np.concatenate([batch['observations'][:,None,:], batch['traj_states']], axis=1)
94 |
95 | rand_indx = np.random.randint(self.dataset.size-1, size=batch_size * num_random_states)
96 | rand_indx = rand_indx.reshape(batch_size, num_random_states)
97 | batch['random_states'] = jax.tree_map(lambda arr: arr[rand_indx], self.dataset['observations'])
98 |
99 | rand_indx_decode = np.random.randint(self.dataset.size-1, size=batch_size * num_random_states_decode)
100 | rand_indx_decode = rand_indx_decode.reshape(batch_size, num_random_states_decode)
101 | batch['random_states_decode'] = jax.tree_map(lambda arr: arr[rand_indx_decode], self.dataset['observations'])
102 | return batch
103 |
104 | def flatten_obgoal(obgoal):
105 | return np.concatenate([obgoal['observation'], obgoal['goal']], axis=-1)
--------------------------------------------------------------------------------
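
A minimal sketch of goal-conditioned sampling with GCDataset, built on a tiny synthetic Dataset. It assumes Dataset.create accepts the same fields the loaders above pass to it and that the result is the dict-style container GCDataset expects; all array contents are placeholders.

    import numpy as np
    from fre.common.dataset import Dataset
    from fre.common.envs.gc_utils import GCDataset

    N, obs_dim, act_dim = 1000, 29, 8
    dones = np.zeros(N, dtype=np.float32)
    dones[199::200] = 1.0                       # an episode boundary every 200 steps
    dataset = Dataset.create(
        observations=np.random.randn(N, obs_dim).astype(np.float32),
        actions=np.random.randn(N, act_dim).astype(np.float32),
        rewards=np.zeros(N, dtype=np.float32),
        masks=1.0 - dones,
        dones_float=dones,
        next_observations=np.random.randn(N, obs_dim).astype(np.float32),
    )

    gc_dataset = GCDataset(dataset=dataset, **GCDataset.get_default_config().to_dict())
    batch = gc_dataset.sample(64)               # relabels rewards/masks and attaches sampled 'goals'
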
/common/envs/wrappers.py:
--------------------------------------------------------------------------------
1 | ###############################
2 | #
3 | # Wrappers on top of gym environments
4 | #
5 | ###############################
6 |
7 | from typing import Dict
8 | import gym
9 | import numpy as np
10 | import time
11 |
12 | class EpisodeMonitor(gym.ActionWrapper):
13 | """A class that computes episode returns and lengths."""
14 |
15 | def __init__(self, env: gym.Env):
16 | super().__init__(env)
17 | self._reset_stats()
18 | self.total_timesteps = 0
19 |
20 | def _reset_stats(self):
21 | self.reward_sum = 0.0
22 | self.episode_length = 0
23 | self.start_time = time.time()
24 |
25 | def step(self, action: np.ndarray):
26 | observation, reward, done, info = self.env.step(action)
27 |
28 | self.reward_sum += reward
29 | self.episode_length += 1
30 | self.total_timesteps += 1
31 | info["total"] = {"timesteps": self.total_timesteps}
32 |
33 | if done:
34 | info["episode"] = {}
35 | info["episode"]["return"] = self.reward_sum
36 | info["episode"]["length"] = self.episode_length
37 | info["episode"]["duration"] = time.time() - self.start_time
38 |
39 | if hasattr(self, "get_normalized_score"):
40 | info["episode"]["normalized_return"] = (
41 | self.get_normalized_score(info["episode"]["return"]) * 100.0
42 | )
43 |
44 | return observation, reward, done, info
45 |
46 | def reset(self, **kwargs) -> np.ndarray:
47 | self._reset_stats()
48 | return self.env.reset(**kwargs)
49 |
50 | class RewardOverride(gym.ActionWrapper):
51 | def __init__(self, env: gym.Env):
52 | super().__init__(env)
53 | self.reward_fn = None
54 |
55 | def step(self, action: np.ndarray):
56 | observation, reward, done, info = self.env.step(action)
57 |
58 |         if self.env.observation_space.shape[0] == 24:  # DMC walker obs: append (xvel, upright, height) aux features
59 | horizontal_velocity = self.env.physics.horizontal_velocity()
60 | torso_upright = self.env.physics.torso_upright()
61 | torso_height = self.env.physics.torso_height()
62 | aux = np.array([horizontal_velocity, torso_upright, torso_height])
63 | observation_aux = np.concatenate([observation, aux])
64 | reward = self.reward_fn(observation_aux)
65 |         elif self.env.observation_space.shape[0] == 17:  # DMC cheetah obs: append horizontal speed
66 | horizontal_velocity = self.env.physics.speed()
67 | aux = np.array([horizontal_velocity])
68 | observation_aux = np.concatenate([observation, aux])
69 | reward = self.reward_fn(observation_aux)
70 | else:
71 | reward = self.reward_fn(observation)
72 | return observation, reward, done, info
73 |
74 | def reset(self, **kwargs) -> np.ndarray:
75 | return self.env.reset(**kwargs)
76 |
77 | class TruncateObservation(gym.ObservationWrapper):
78 | def __init__(self, env: gym.Env, truncate_size: int):
79 | super().__init__(env)
80 | self.truncate_size = truncate_size
81 |
82 | def observation(self, observation: np.ndarray) -> np.ndarray:
83 | return observation[:self.truncate_size]
84 |
85 | class GoalWrapper(gym.ObservationWrapper):
86 | def __init__(self, env: gym.Env):
87 | super().__init__(env)
88 | self.custom_goal = None
89 |
90 | def observation(self, observation: np.ndarray) -> np.ndarray:
91 | if self.custom_goal is not None:
92 | return np.concatenate([observation, self.custom_goal])
93 | else:
94 | return observation
95 |
--------------------------------------------------------------------------------
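
A small sketch of composing these wrappers. Pendulum-v1 is only a stand-in (its 3-dimensional observation falls through to the generic reward_fn branch of RewardOverride), and the custom reward is illustrative; wrapping EpisodeMonitor on the outside lets it track the overridden reward.

    import gym
    import numpy as np
    from fre.common.envs.wrappers import EpisodeMonitor, RewardOverride

    env = RewardOverride(gym.make("Pendulum-v1"))
    env.reward_fn = lambda obs: -np.linalg.norm(obs).item()   # custom r(s)
    env = EpisodeMonitor(env)

    obs, done = env.reset(), False
    while not done:
        obs, reward, done, info = env.step(env.action_space.sample())
    print(info["episode"]["return"], info["episode"]["length"])
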
/common/evaluation.py:
--------------------------------------------------------------------------------
1 | ###############################
2 | #
3 | # Tools for evaluating policies in environments.
4 | #
5 | ###############################
6 |
7 |
8 | from typing import Dict
9 | import jax
10 | import gym
11 | import numpy as np
12 | from collections import defaultdict
13 | import time
14 | import wandb
15 |
16 |
17 | def flatten(d, parent_key="", sep="."):
18 | """
19 | Helper function that flattens a dictionary of dictionaries into a single dictionary.
20 | E.g: flatten({'a': {'b': 1}}) -> {'a.b': 1}
21 | """
22 | items = []
23 | for k, v in d.items():
24 | new_key = parent_key + sep + k if parent_key else k
25 | if hasattr(v, "items"):
26 | items.extend(flatten(v, new_key, sep=sep).items())
27 | else:
28 | items.append((new_key, v))
29 | return dict(items)
30 |
31 |
32 | def add_to(dict_of_lists, single_dict):
33 | for k, v in single_dict.items():
34 | dict_of_lists[k].append(v)
35 |
36 |
37 | def evaluate(policy_fn, env: gym.Env, num_episodes: int, record_video : bool = False,
38 | return_trajectories=False, clip_return_at_goal=False, binary_return=False, use_discrete_xy=False, clip_margin=0):
39 | print("Clip return at goal is", clip_return_at_goal)
40 | stats = defaultdict(list)
41 | frames = []
42 | trajectories = []
43 | for i in range(num_episodes):
44 | now = time.time()
45 | trajectory = defaultdict(list)
46 | ob_list = []
47 | ac_list = []
48 | observation, done = env.reset(), False
49 | ob_list.append(observation)
50 | while not done:
51 | if use_discrete_xy:
52 | import fre.common.envs.d4rl.d4rl_ant as d4rl_ant
53 | ob_input = d4rl_ant.discretize_obs(observation)
54 | else:
55 | ob_input = observation
56 | action = policy_fn(ob_input)
57 | action = np.array(action)
58 | next_observation, r, done, info = env.step(action)
59 | add_to(stats, flatten(info))
60 |
61 | if type(observation) is dict:
62 | obs_pure = observation['observation']
63 | next_obs_pure = next_observation['observation']
64 | else:
65 | obs_pure = observation
66 | next_obs_pure = next_observation
67 | transition = dict(
68 | observation=obs_pure,
69 | next_observation=next_obs_pure,
70 | action=action,
71 | reward=r,
72 | done=done,
73 | info=info,
74 | )
75 | observation = next_observation
76 | ob_list.append(observation)
77 | ac_list.append(action)
78 | add_to(trajectory, transition)
79 |
80 | if i <= 3 and record_video:
81 | frames.append(env.render(mode="rgb_array"))
82 | add_to(stats, flatten(info, parent_key="final"))
83 | trajectories.append(trajectory)
84 | print("Finished Episode", i, "in", time.time() - now, "seconds")
85 |
86 | if clip_return_at_goal and 'episode.return' in stats:
87 | print("Episode finished. Return is {}. Length is {}.".format(stats['episode.return'], stats['episode.length']))
88 | stats['episode.return'] = np.clip(np.array(stats['episode.length']) + np.array(stats['episode.return']) - clip_margin, 0, 1) # Goal is a binary indicator.
89 | print("Clipped return is {}.".format(stats['episode.return']))
90 | elif binary_return and 'episode.return' in stats:
91 | # Assume that the reward is either 0 or 1 at each timestep.
92 | print("Episode finished. Return is {}. Length is {}.".format(stats['episode.return'], stats['episode.length']))
93 | stats['episode.return'] = np.clip(np.array(stats['episode.return']), 0, 1)
94 | print("Clipped return is {}.".format(stats['episode.return']))
95 |
96 | if 'episode.return' in stats:
97 | print("Episode finished. Return is {}. Length is {}.".format(stats['episode.return'], stats['episode.length']))
98 |
99 | for k, v in stats.items():
100 | stats[k] = np.mean(v)
101 |
102 | if record_video:
103 | stacked = np.stack(frames)
104 | stacked = stacked.transpose(0, 3, 1, 2)
105 | while stacked.shape[2] > 160:
106 | stacked = stacked[:, :, ::2, ::2]
107 | stats['video'] = wandb.Video(stacked, fps=60)
108 |
109 | if return_trajectories:
110 | return stats, trajectories
111 | else:
112 | return stats
--------------------------------------------------------------------------------
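
A minimal sketch of evaluate with a random policy. EpisodeMonitor provides the episode.* entries that end up in the averaged stats, and Pendulum-v1 is just a placeholder environment; a trained agent's action function would normally stand in for the lambda.

    import gym
    from fre.common.envs.wrappers import EpisodeMonitor
    from fre.common.evaluation import evaluate

    env = EpisodeMonitor(gym.make("Pendulum-v1"))
    policy_fn = lambda obs: env.action_space.sample()   # stand-in for a trained policy
    stats = evaluate(policy_fn, env, num_episodes=2)
    print(stats["episode.return"], stats["episode.length"])
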
/common/networks/basic.py:
--------------------------------------------------------------------------------
1 | ###############################
2 | #
3 | # Common Flax Networks.
4 | #
5 | ###############################
6 |
7 | from fre.common.typing import *
8 |
9 | import distrax
10 | import flax.linen as nn
11 | import jax.numpy as jnp
12 | from dataclasses import field
16 |
17 | ###############################
18 | #
19 | # Common Networks
20 | #
21 | ###############################
22 |
23 | def mish(x):
24 | return x * jnp.tanh(nn.softplus(x))
25 |
26 | def default_init(scale: Optional[float] = 1.0):
27 | return nn.initializers.variance_scaling(scale, "fan_avg", "uniform")
28 |
29 | class MLP(nn.Module):
30 | hidden_dims: Sequence[int]
31 | activations: Callable[[jnp.ndarray], jnp.ndarray] = mish
32 |     activate_final: bool = False
33 | use_layer_norm: bool = True
34 | kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_init()
35 |
36 | def setup(self):
37 | self.layers = [
38 | nn.Dense(size, kernel_init=self.kernel_init) for size in self.hidden_dims
39 | ]
40 | if self.use_layer_norm:
41 | self.layer_norms = [nn.LayerNorm() for _ in self.hidden_dims]
42 |
43 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
44 | for i, layer in enumerate(self.layers):
45 | x = layer(x)
46 | if i + 1 < len(self.layers) and self.use_layer_norm:
47 | x = self.layer_norms[i](x)
48 | if i + 1 < len(self.layers) or self.activate_final:
49 | x = self.activations(x)
50 | return x
51 |
52 | ###############################
53 | #
54 | # Common RL Networks
55 | #
56 | ###############################
57 |
58 |
59 | # DQN-style critic.
60 | class DiscreteCritic(nn.Module):
61 | hidden_dims: Sequence[int]
62 | n_actions: int
63 | mlp_kwargs: Dict[str, Any] = field(default_factory=dict)
64 |
65 | @nn.compact
66 | def __call__(self, observations: jnp.ndarray) -> jnp.ndarray:
67 | return MLP((*self.hidden_dims, self.n_actions), **self.mlp_kwargs)(
68 | observations
69 | )
70 |
71 | # Q(s,a) critic.
72 | class Critic(nn.Module):
73 | hidden_dims: Sequence[int]
74 | mlp_kwargs: Dict[str, Any] = field(default_factory=dict)
75 |
76 | @nn.compact
77 | def __call__(self, observations: jnp.ndarray, actions: jnp.ndarray) -> jnp.ndarray:
78 | inputs = jnp.concatenate([observations, actions], -1)
79 | critic = MLP((*self.hidden_dims, 1), **self.mlp_kwargs)(inputs)
80 | return jnp.squeeze(critic, -1)
81 |
82 | # V(s) critic.
83 | class ValueCritic(nn.Module):
84 | hidden_dims: Sequence[int]
85 | mlp_kwargs: Dict[str, Any] = field(default_factory=dict)
86 |
87 | @nn.compact
88 | def __call__(self, observations: jnp.ndarray) -> jnp.ndarray:
89 | critic = MLP((*self.hidden_dims, 1), **self.mlp_kwargs)(observations)
90 | return jnp.squeeze(critic, -1)
91 |
92 | # pi(a|s). Returns a distrax distribution.
93 | class Policy(nn.Module):
94 | hidden_dims: Sequence[int]
95 | action_dim: int
96 | mlp_kwargs: Dict[str, Any] = field(default_factory=dict)
97 |
98 | is_discrete: bool = False
99 | log_std_min: Optional[float] = -20
100 | log_std_max: Optional[float] = 2
101 | mean_min: Optional[float] = -5
102 | mean_max: Optional[float] = 5
103 | tanh_squash_distribution: bool = False
104 | state_dependent_std: bool = True
105 | final_fc_init_scale: float = 1e-2
106 |
107 | @nn.compact
108 | def __call__(
109 | self, observations: jnp.ndarray, temperature: float = 1.0
110 | ) -> distrax.Distribution:
111 | outputs = MLP(
112 | self.hidden_dims,
113 | activate_final=True,
114 | **self.mlp_kwargs
115 | )(observations)
116 |
117 | if self.is_discrete:
118 | logits = nn.Dense(
119 | self.action_dim, kernel_init=default_init(self.final_fc_init_scale)
120 | )(outputs)
121 | distribution = distrax.Categorical(logits=logits / jnp.maximum(1e-6, temperature))
122 | else:
123 | means = nn.Dense(
124 | self.action_dim, kernel_init=default_init(self.final_fc_init_scale)
125 | )(outputs)
126 | if self.state_dependent_std:
127 | log_stds = nn.Dense(
128 | self.action_dim, kernel_init=default_init(self.final_fc_init_scale)
129 | )(outputs)
130 | else:
131 | log_stds = self.param("log_stds", nn.initializers.zeros, (self.action_dim,))
132 |
133 | log_stds = jnp.clip(log_stds, self.log_std_min, self.log_std_max)
134 | means = jnp.clip(means, self.mean_min, self.mean_max)
135 |
136 | distribution = distrax.MultivariateNormalDiag(
137 | loc=means, scale_diag=jnp.exp(log_stds) * temperature
138 | )
139 | if self.tanh_squash_distribution:
140 | distribution = TransformedWithMode(
141 | distribution, distrax.Block(distrax.Tanh(), ndims=1)
142 | )
143 | return distribution
144 |
145 | ###############################
146 | #
147 | # Helper Things
148 | #
149 | ###############################
150 |
151 |
152 | class TransformedWithMode(distrax.Transformed):
153 | def mode(self) -> jnp.ndarray:
154 | return self.bijector.forward(self.distribution.mode())
155 |
156 | def ensemblize(cls, num_qs, out_axes=0, **kwargs):
157 | """
158 | Useful for making ensembles of Q functions (e.g. double Q in SAC).
159 |
160 | Usage:
161 |
162 | critic_def = ensemblize(Critic, 2)(hidden_dims=hidden_dims)
163 |
164 | """
165 | return nn.vmap(
166 | cls,
167 | variable_axes={"params": 0},
168 | split_rngs={"params": True},
169 | in_axes=None,
170 | out_axes=out_axes,
171 | axis_size=num_qs,
172 | **kwargs
173 | )
--------------------------------------------------------------------------------
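
A short sketch of ensemblize in action, following the usage note in its docstring; the observation and action dimensions are arbitrary.

    import jax
    import jax.numpy as jnp
    from fre.common.networks.basic import Critic, ensemblize

    critic_def = ensemblize(Critic, num_qs=2)(hidden_dims=(256, 256))
    obs, acts = jnp.ones((1, 17)), jnp.ones((1, 6))
    params = critic_def.init(jax.random.PRNGKey(0), obs, acts)["params"]
    qs = critic_def.apply({"params": params}, obs, acts)
    print(qs.shape)   # (2, 1): one Q-value per ensemble member
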
/common/networks/transformer.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Optional, Tuple, Type
2 |
3 | import flax.linen as nn
4 | import jax.numpy as jnp
5 |
6 | Array = Any
7 | PRNGKey = Any
8 | Shape = Tuple[int]
9 | Dtype = Any
10 |
11 |
12 | class IdentityLayer(nn.Module):
13 | """Identity layer, convenient for giving a name to an array."""
14 |
15 | @nn.compact
16 | def __call__(self, x):
17 | return x
18 |
19 |
20 | class AddPositionEmbs(nn.Module):
21 |     # Adds learned position embeddings to the input sequence.
22 | posemb_init: Callable[[PRNGKey, Shape, Dtype], Array]
23 |
24 | @nn.compact
25 | def __call__(self, inputs):
26 | """
27 | inputs.shape is (batch_size, timesteps, emb_dim).
28 | Output tensor with shape `(batch_size, timesteps, in_dim)`.
29 | """
30 | assert inputs.ndim == 3, ('Number of dimensions should be 3, but it is: %d' % inputs.ndim)
31 |
32 | position_ids = jnp.arange(inputs.shape[1])[None] # (1, timesteps)
33 | pos_embeddings = nn.Embed(
34 | 128, # Max Positional Embeddings
35 | inputs.shape[2],
36 | embedding_init=self.posemb_init,
37 | dtype=inputs.dtype,
38 | )(position_ids)
39 |         print("For Input Shape {}, Pos Embed Shape is {}".format(inputs.shape, pos_embeddings.shape))
40 | return inputs + pos_embeddings
41 |
42 | # pos_emb_shape = (1, inputs.shape[1], inputs.shape[2])
43 | # pe = self.param('pos_embedding', self.posemb_init, pos_emb_shape)
44 | # return inputs + pe
45 |
46 |
47 | class MlpBlock(nn.Module):
48 | """Transformer MLP / feed-forward block."""
49 |
50 | mlp_dim: int
51 | dtype: Dtype = jnp.float32
52 | out_dim: Optional[int] = None
53 |     dropout_rate: float = 0.0
54 | kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.xavier_uniform()
55 | bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.normal(stddev=1e-6)
56 |
57 | @nn.compact
58 | def __call__(self, inputs, *, deterministic):
59 | """It's just an MLP, so the input shape is (batch, len, emb)."""
60 | actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
61 | x = nn.Dense(
62 | features=self.mlp_dim,
63 | dtype=self.dtype,
64 | kernel_init=self.kernel_init,
65 | bias_init=self.bias_init)(inputs)
66 | x = nn.gelu(x)
67 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
68 | output = nn.Dense(
69 | features=actual_out_dim,
70 | dtype=self.dtype,
71 | kernel_init=self.kernel_init,
72 | bias_init=self.bias_init)(x)
73 | output = nn.Dropout(
74 | rate=self.dropout_rate)(output, deterministic=deterministic)
75 | return output
76 |
77 |
78 | class Encoder1DBlock(nn.Module):
79 | """Transformer encoder layer.
80 | Given a sequence, it passes it through an attention layer, then through a mlp layer.
81 | In each case it is a residual block with a layer norm.
82 | """
83 |
84 | mlp_dim: int
85 | num_heads: int
86 | causal: bool
87 | dropout_rate: float
88 | attention_dropout_rate: float
89 | dtype: Dtype = jnp.float32
90 |
91 | @nn.compact
92 | def __call__(self, inputs, *, deterministic, train=True):
93 |
94 | if self.causal:
95 | causal_mask = nn.make_causal_mask(jnp.ones((inputs.shape[0], inputs.shape[1]),
96 | dtype="bool"), dtype="bool")
97 | print("Using Causal Mask with shape", causal_mask.shape, "and inputs shape", inputs.shape, ".")
98 | else:
99 | causal_mask = None
100 |
101 | # Attention block.
102 | assert inputs.ndim == 3, f'Expected (batch, seq, hidden) got {inputs.shape}'
103 | x = nn.LayerNorm(dtype=self.dtype)(inputs)
104 | x = nn.MultiHeadDotProductAttention(
105 | dtype=self.dtype,
106 | kernel_init=nn.initializers.xavier_uniform(),
107 | broadcast_dropout=False,
108 | deterministic=deterministic,
109 | dropout_rate=self.attention_dropout_rate,
110 | decode=False,
111 | num_heads=self.num_heads)(x, x, causal_mask)
112 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
113 | x = x + inputs
114 |
115 | # MLP block. This does NOT change the embedding dimension!
116 | y = nn.LayerNorm(dtype=self.dtype)(x)
117 | y = MlpBlock(mlp_dim=self.mlp_dim, dtype=self.dtype, dropout_rate=self.dropout_rate)(y, deterministic=deterministic)
118 |
119 | return x + y
120 |
121 |
122 | class Transformer(nn.Module):
123 | """Transformer Model Encoder for sequence to sequence translation.
124 | """
125 |
126 | num_layers: int
127 | emb_dim: int
128 | mlp_dim: int
129 | num_heads: int
130 | dropout_rate: float
131 | attention_dropout_rate: float
132 | causal: bool = True
133 |
134 | @nn.compact
135 | def __call__(self, x, *, train):
136 | assert x.ndim == 3 # (batch, len, emb)
137 | assert x.shape[-1] == self.emb_dim
138 |
139 | # Input Encoder. Each layer processes x, but the shape of x does not change.
140 | for lyr in range(self.num_layers):
141 | x = Encoder1DBlock(
142 | mlp_dim=self.mlp_dim,
143 | dropout_rate=self.dropout_rate,
144 | attention_dropout_rate=self.attention_dropout_rate,
145 | name=f'encoderblock_{lyr}',
146 | causal=self.causal,
147 | num_heads=self.num_heads)(
148 | x, deterministic=not train, train=train)
149 | encoded = nn.LayerNorm(name='encoder_norm')(x)
150 |
151 | return encoded
152 |
153 | def get_default_config():
154 | import ml_collections
155 |
156 | config = ml_collections.ConfigDict({
157 | 'num_layers': 4,
158 | 'emb_dim': 256,
159 | 'mlp_dim': 256,
160 | 'num_heads': 4,
161 | 'dropout_rate': 0.0,
162 | 'attention_dropout_rate': 0.0,
163 | 'causal': True,
164 | })
165 | return config
--------------------------------------------------------------------------------
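
A minimal sketch of running the encoder on a dummy sequence. The small dimensions are arbitrary; train=False makes the dropout layers deterministic, so no dropout RNG is needed.

    import jax
    import jax.numpy as jnp
    from fre.common.networks.transformer import Transformer

    model = Transformer(num_layers=2, emb_dim=64, mlp_dim=128, num_heads=4,
                        dropout_rate=0.0, attention_dropout_rate=0.0, causal=True)
    x = jnp.ones((2, 10, 64))                   # (batch, timesteps, emb_dim)
    params = model.init(jax.random.PRNGKey(0), x, train=False)["params"]
    out = model.apply({"params": params}, x, train=False)
    print(out.shape)                            # (2, 10, 64)
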
/common/train_state.py:
--------------------------------------------------------------------------------
1 | ###############################
2 | #
3 | # Structures for managing training of flax networks.
4 | #
5 | ###############################
6 |
7 | from fre.common.typing import *
8 | import flax
9 | import flax.linen as nn
10 | import jax
11 | import jax.numpy as jnp
12 | from jax import tree_util
13 | import optax
14 | import functools
15 |
16 | import gym
17 | import numpy as np
18 | nonpytree_field = functools.partial(flax.struct.field, pytree_node=False)
19 |
20 |
21 | def shard_batch(batch):
22 | d = jax.local_device_count()
23 |
24 | def reshape(x):
25 | assert (
26 | x.shape[0] % d == 0
27 | ), f"Batch size needs to be divisible by # devices, got {x.shape[0]} and {d}"
28 | return x.reshape((d, x.shape[0] // d, *x.shape[1:]))
29 |
30 | return tree_util.tree_map(reshape, batch)
31 |
32 |
33 | def target_update(
34 | model: "TrainState", target_model: "TrainState", tau: float
35 | ) -> "TrainState":
36 | new_target_params = jax.tree_map(
37 | lambda p, tp: p * tau + tp * (1 - tau), model.params, target_model.params
38 | )
39 | return target_model.replace(params=new_target_params)
40 |
41 |
42 | class TrainState(flax.struct.PyTreeNode):
43 | """
44 | Core abstraction of a model in this repository.
45 |
46 | Creation:
47 | ```
48 | model_def = nn.Dense(12) # or any other flax.linen Module
49 | params = model_def.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))['params']
50 | model = TrainState.create(model_def, params, tx=None) # Optionally, pass in an optax optimizer
51 | ```
52 |
53 | Usage:
54 | ```
55 | y = model(jnp.ones((1, 4))) # By default, uses the `__call__` method of the model_def and params stored in TrainState
56 | y = model(jnp.ones((1, 4)), params=params) # You can pass in params (useful for gradient computation)
57 | y = model(jnp.ones((1, 4)), method=method) # You can apply a different method as well
58 | ```
59 |
60 | More complete example:
61 | ```
62 | def loss(params):
63 | y_pred = model(x, params=params)
64 | return jnp.mean((y - y_pred) ** 2)
65 |
66 | grads = jax.grad(loss)(model.params)
67 | new_model = model.apply_gradients(grads=grads) # Alternatively, new_model = model.apply_loss_fn(loss_fn=loss)
68 | ```
69 | """
70 |
71 | step: int
72 | apply_fn: Callable[..., Any] = nonpytree_field()
73 | model_def: Any = nonpytree_field()
74 | params: Params
75 | tx: Optional[optax.GradientTransformation] = nonpytree_field()
76 | opt_state: Optional[optax.OptState] = None
77 |
78 | @classmethod
79 | def create(
80 | cls,
81 | model_def: nn.Module,
82 | params: Params,
83 | tx: Optional[optax.GradientTransformation] = None,
84 | **kwargs,
85 | ) -> "TrainState":
86 | if tx is not None:
87 | opt_state = tx.init(params)
88 | else:
89 | opt_state = None
90 |
91 | return cls(
92 | step=1,
93 | apply_fn=model_def.apply,
94 | model_def=model_def,
95 | params=params,
96 | tx=tx,
97 | opt_state=opt_state,
98 | **kwargs,
99 | )
100 |
101 | def __call__(
102 | self,
103 | *args,
104 | params=None,
105 | extra_variables: dict = None,
106 | method: ModuleMethod = None,
107 | **kwargs,
108 | ):
109 | """
110 | Internally calls model_def.apply_fn with the following logic:
111 |
112 | Arguments:
113 | params: If not None, use these params instead of the ones stored in the model.
114 | extra_variables: Additional variables to pass into apply_fn
115 | method: If None, use the `__call__` method of the model_def. If a string, uses
116 | the method of the model_def with that name (e.g. 'encode' -> model_def.encode).
117 | If a function, uses that function.
118 |
119 | """
120 | if params is None:
121 | params = self.params
122 |
123 | variables = {"params": params}
124 |
125 | if extra_variables is not None:
126 | variables = {**variables, **extra_variables}
127 |
128 | if isinstance(method, str):
129 | method = getattr(self.model_def, method)
130 |
131 | return self.apply_fn(variables, *args, method=method, **kwargs)
132 |
133 | def do(self, method):
134 | return functools.partial(self, method=method)
135 |
136 | def apply_gradients(self, *, grads, **kwargs):
137 | """Updates `step`, `params`, `opt_state` and `**kwargs` in return value.
138 |
139 | Note that internally this function calls `.tx.update()` followed by a call
140 | to `optax.apply_updates()` to update `params` and `opt_state`.
141 |
142 | Args:
143 | grads: Gradients that have the same pytree structure as `.params`.
144 | **kwargs: Additional dataclass attributes that should be `.replace()`-ed.
145 |
146 | Returns:
147 | An updated instance of `self` with `step` incremented by one, `params`
148 | and `opt_state` updated by applying `grads`, and additional attributes
149 | replaced as specified by `kwargs`.
150 | """
151 | updates, new_opt_state = self.tx.update(grads, self.opt_state, self.params)
152 | new_params = optax.apply_updates(self.params, updates)
153 |
154 | return self.replace(
155 | step=self.step + 1,
156 | params=new_params,
157 | opt_state=new_opt_state,
158 | **kwargs,
159 | )
160 |
161 | def apply_loss_fn(self, *, loss_fn, pmap_axis=None, has_aux=False):
162 | """
163 | Takes a gradient step towards minimizing `loss_fn`. Internally, this calls
164 | `jax.grad` followed by `TrainState.apply_gradients`. If pmap_axis is provided,
165 | additionally it averages gradients (and info) across devices before performing update.
166 | """
167 | if has_aux:
168 | grads, info = jax.grad(loss_fn, has_aux=has_aux)(self.params)
169 | if pmap_axis is not None:
170 | grads = jax.lax.pmean(grads, axis_name=pmap_axis)
171 | info = jax.lax.pmean(info, axis_name=pmap_axis)
172 |
173 | return self.apply_gradients(grads=grads), info
174 |
175 | else:
176 | grads = jax.grad(loss_fn, has_aux=has_aux)(self.params)
177 | if pmap_axis is not None:
178 | grads = jax.lax.pmean(grads, axis_name=pmap_axis)
179 | return self.apply_gradients(grads=grads)
180 |
181 | class NormalizeActionWrapper(gym.Wrapper):
182 |     """A wrapper that maps actions from [-1, 1] to [low, high]."""
183 | def __init__(self, env):
184 | super().__init__(env)
185 | self.active = type(env.action_space) == gym.spaces.Box
186 | if self.active:
187 | self.action_low = env.action_space.low
188 | self.action_high = env.action_space.high
189 | self.action_scale = (self.action_high - self.action_low) * 0.5
190 | self.action_mid = (self.action_high + self.action_low) * 0.5
191 | print("Normalizing Action Space from [{}, {}] to [-1, 1]".format(self.action_low[0], self.action_high[0]))
192 | def step(self, action):
193 | if self.active:
194 | action = np.clip(action, -1, 1)
195 | action = action * self.action_scale
196 | action = action + self.action_mid
197 | return self.env.step(action)
198 |
199 | def reset(self, **kwargs):
200 | return self.env.reset(**kwargs)
201 |
202 |
--------------------------------------------------------------------------------
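
A compact sketch that exercises apply_loss_fn and target_update on a toy regression model; the learning rate and tau values are illustrative.

    import jax
    import jax.numpy as jnp
    import flax.linen as nn
    import optax
    from fre.common.train_state import TrainState, target_update

    model_def = nn.Dense(1)
    x, y = jnp.ones((8, 4)), jnp.zeros((8, 1))
    params = model_def.init(jax.random.PRNGKey(0), x)["params"]
    model = TrainState.create(model_def, params, tx=optax.adam(1e-3))
    target = TrainState.create(model_def, params)

    def loss_fn(p):
        pred = model(x, params=p)
        loss = jnp.mean((pred - y) ** 2)
        return loss, {"mse": loss}

    model, info = model.apply_loss_fn(loss_fn=loss_fn, has_aux=True)
    target = target_update(model, target, tau=0.005)
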
/common/typing.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
2 | import numpy as np
3 | import jax.numpy as jnp
4 | import flax
5 |
6 | PRNGKey = Any
7 | Params = flax.core.FrozenDict[str, Any]
9 | Shape = Sequence[int]
10 | Dtype = Any # this could be a real type?
11 | InfoDict = Dict[str, float]
12 | Array = Union[np.ndarray, jnp.ndarray]
13 | Data = Union[Array, Dict[str, "Data"]]
14 | Batch = Dict[str, Data]
15 | ModuleMethod = Union[
16 | str, Callable, None
17 | ] # A method to be passed into TrainState.__call__
18 |
--------------------------------------------------------------------------------
/common/utils.py:
--------------------------------------------------------------------------------
1 | ###############################
2 | #
3 | # Some shared utility functions
4 | #
5 | ###############################
6 |
7 | import jax
8 |
9 | def supply_rng(f, rng=jax.random.PRNGKey(0)):
10 | """
11 | Wraps a function to supply jax rng. It will remember the rng state for that function.
12 | """
13 | def wrapped(*args, **kwargs):
14 | nonlocal rng
15 | rng, key = jax.random.split(rng)
16 | return f(*args, seed=key, **kwargs)
17 |
18 | return wrapped
19 |
20 |
--------------------------------------------------------------------------------
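
A tiny sketch of supply_rng. The function noisy_identity is hypothetical; the only requirement is that the wrapped function accepts a seed keyword, which is the calling convention supply_rng uses.

    import jax
    import jax.numpy as jnp
    from fre.common.utils import supply_rng

    def noisy_identity(x, *, seed):
        return x + 0.1 * jax.random.normal(seed, x.shape)

    sample = supply_rng(noisy_identity)
    print(sample(jnp.zeros(3)))   # a fresh key is split off internally,
    print(sample(jnp.zeros(3)))   # so repeated calls give different noise
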
/common/wandb.py:
--------------------------------------------------------------------------------
1 | """WandB logging helpers.
2 |
3 | Run setup_wandb(hyperparam_dict, ...) to initialize wandb logging.
4 | See default_wandb_config() for a list of available configurations.
5 |
6 | We recommend the following workflow (see examples/mujoco/d4rl_iql.py for a more full example):
7 |
8 | from ml_collections import config_flags
9 | from jaxrl_m.wandb import setup_wandb, default_wandb_config
10 | import wandb
11 |
12 | # This line allows us to change wandb config flags from the command line
13 | config_flags.DEFINE_config_dict('wandb', default_wandb_config(), lock_config=False)
14 |
15 | ...
16 | def main(argv):
17 | hyperparams = ...
18 | setup_wandb(hyperparams, **FLAGS.wandb)
19 |
20 | # Log metrics as you wish now
21 | wandb.log({'metric': 0.0}, step=0)
22 |
23 |
24 | With the following setup, you may set wandb configurations from the command line, e.g.
25 | python main.py --wandb.project=my_project --wandb.group=my_group --wandb.offline
26 | """
27 | import wandb
28 |
29 | import tempfile
30 | import absl.flags as flags
31 | import ml_collections
32 | from ml_collections.config_dict import FieldReference
33 | import datetime
35 | import time
36 | import numpy as np
37 | import os
38 |
39 |
40 | def get_flag_dict():
41 | flag_dict = {k: getattr(flags.FLAGS, k) for k in flags.FLAGS}
42 | for k in flag_dict:
43 | if isinstance(flag_dict[k], ml_collections.ConfigDict):
44 | flag_dict[k] = flag_dict[k].to_dict()
45 | return flag_dict
46 |
47 |
48 | def default_wandb_config():
49 | config = ml_collections.ConfigDict()
50 | config.offline = False # Syncs online or not?
51 | config.project = "jaxrl_m" # WandB Project Name
52 | config.entity = FieldReference(None, field_type=str) # Which entity to log as (default: your own user)
53 |
54 | group_name = FieldReference(None, field_type=str) # Group name
55 | config.exp_prefix = group_name # Group name (deprecated, but kept for backwards compatibility)
56 | config.group = group_name # Group name
57 |
58 | experiment_name = FieldReference(None, field_type=str) # Experiment name
59 | config.name = experiment_name # Run name (will be formatted with flags / variant)
60 | config.exp_descriptor = experiment_name # Run name (deprecated, but kept for backwards compatibility)
61 |
62 | config.unique_identifier = "" # Unique identifier for run (will be automatically generated unless provided)
63 | config.random_delay = 0 # Random delay for wandb.init (in seconds)
64 | return config
65 |
66 |
67 | def setup_wandb(
68 | hyperparam_dict,
69 | entity=None,
70 | project="jaxrl_m",
71 | group=None,
72 | name=None,
73 | unique_identifier="",
74 | offline=False,
75 | random_delay=0,
76 | **additional_init_kwargs,
77 | ):
78 | """
79 | Utility for setting up wandb logging (based on Young's simplesac):
80 |
81 | Arguments:
82 | - hyperparam_dict: dict of hyperparameters for experiment
83 | - offline: bool, whether to sync online or not
84 | - project: str, wandb project name
85 | - entity: str, wandb entity name (default is your user)
86 | - group: str, Group name for wandb
87 | - name: str, Experiment name for wandb (formatted with FLAGS & hyperparameter_dict)
88 | - unique_identifier: str, Unique identifier for wandb (default is timestamp)
89 | - random_delay: float, Random delay for wandb.init (in seconds) to avoid collisions
90 | - additional_init_kwargs: dict, additional kwargs to pass to wandb.init
91 | Returns:
92 | - wandb.run
93 |
94 | """
95 | if "exp_descriptor" in additional_init_kwargs:
96 | # Remove deprecated exp_descriptor
97 | additional_init_kwargs.pop("exp_descriptor")
98 |         additional_init_kwargs.pop("exp_prefix", None)
99 |
100 | if not unique_identifier:
101 | if random_delay:
102 | time.sleep(np.random.uniform(0, random_delay))
103 | unique_identifier = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
104 | unique_identifier += f"_{np.random.randint(0, 1000000):06d}"
105 | flag_dict = get_flag_dict()
106 | if 'seed' in flag_dict:
107 | unique_identifier += f"_{flag_dict['seed']:02d}"
108 |
109 | if name is not None:
110 | name = name.format(**{**get_flag_dict(), **hyperparam_dict})
111 |
112 | if group is not None and name is not None:
113 | experiment_id = f"{name}_{unique_identifier}"
114 | elif name is not None:
115 | experiment_id = f"{name}_{unique_identifier}"
116 | else:
117 | experiment_id = None
118 |
119 |     # wandb writes its local files into a temporary directory.
120 | wandb_output_dir = tempfile.mkdtemp()
121 | tags = [group] if group is not None else None
122 |
123 | init_kwargs = dict(
124 | config=hyperparam_dict,
125 | project=project,
126 | entity=entity,
127 | tags=tags,
128 | group=group,
129 | dir=wandb_output_dir,
130 | id=experiment_id,
131 | name=name,
132 | settings=wandb.Settings(
133 | start_method="thread",
134 | _disable_stats=False,
135 | ),
136 | mode="offline" if offline else "online",
137 | save_code=True,
138 | )
139 |
140 | init_kwargs.update(additional_init_kwargs)
141 | run = wandb.init(**init_kwargs)
142 |
143 | wandb.config.update(get_flag_dict())
144 |
145 | wandb_config = dict(
146 | exp_prefix=group,
147 | exp_descriptor=name,
148 | experiment_id=experiment_id,
149 | )
150 | wandb.config.update(wandb_config)
151 | return run
152 |
--------------------------------------------------------------------------------
/deps/base_container.def:
--------------------------------------------------------------------------------
1 | Bootstrap: docker
2 | From: nvidia/cuda:11.8.0-devel-ubuntu22.04
3 |
4 | # Copy the conda env file into the container for installation
5 | %files
6 | environment.yml /contained/setup/environment.yml
7 | requirements.txt /contained/setup/requirements.txt
8 |
9 | %post -c /bin/bash
10 | apt-get update && apt-get install -y wget
11 | apt-get install -y libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf
12 | wget https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh -P /contained/
13 | sh /contained/Miniforge3-Linux-x86_64.sh -b -p /contained/miniconda
14 | ls /contained/
15 | source /contained/miniconda/etc/profile.d/conda.sh
16 | source /contained/miniconda/etc/profile.d/mamba.sh
17 |
18 | export MUJOCO_PY_MJKEY_PATH='/contained/software/mujoco/mjkey.txt'
19 | export MUJOCO_PY_MUJOCO_PATH='/contained/software/mujoco/mujoco210'
20 | export MJKEY_PATH='/contained/software/mujoco/mjkey.txt'
21 | export MJLIB_PATH='/contained/software/mujoco/mujoco210/bin/libmujoco210.so'
22 | export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/contained/software/mujoco/mujoco210/bin"
23 | export D4RL_SUPPRESS_IMPORT_ERROR=1
24 |
25 | mkdir /contained/software
26 | mkdir /contained/software/mujoco
27 | wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz
28 | tar -C /contained/software/mujoco -zxvf mujoco210-linux-x86_64.tar.gz --no-same-owner
29 |
30 | export WANDB_API_KEY=''
31 |
32 | cp -r /contained/software/mujoco /root/.mujoco
33 | CONDA_OVERRIDE_CUDA="11.8" mamba env create -f /contained/setup/environment.yml
34 |
35 | mamba clean --all
36 |
37 | # Trigger mujoco-py build
38 | mamba activate $(cat /contained/setup/environment.yml | egrep "name: .+$" | sed -e 's/^name:[ \t]*//')
39 | python -c 'import gym; gym.make("HalfCheetah-v2")'
40 |
41 | chmod -R 777 /contained
42 |
43 | %environment
44 | # Activate conda environment
45 | source /contained/miniconda/etc/profile.d/conda.sh
46 | source /contained/miniconda/etc/profile.d/mamba.sh
47 | conda activate $(cat /contained/setup/environment.yml | egrep "name: .+$" | sed -e 's/^name:[ \t]*//')
48 |
49 | export MUJOCO_PY_MJKEY_PATH='/contained/software/mujoco/mjkey.txt'
50 | export MUJOCO_PY_MUJOCO_PATH='/contained/software/mujoco/mujoco210'
51 | export MJKEY_PATH='/contained/software/mujoco/mjkey.txt'
52 | export MJLIB_PATH='/contained/software/mujoco/mujoco210/bin/libmujoco210.so'
53 | export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/contained/software/mujoco/mujoco210/bin"
54 | export D4RL_SUPPRESS_IMPORT_ERROR=1
55 |
56 | export WANDB_API_KEY=''
57 |
58 | # Set up python path for the research project, put your home dir here.
59 | export PYTHONPATH="$PYTHONPATH:"
60 |
61 | %runscript
62 | #! /bin/bash
63 | python -m "$@"
64 | # Entry point for singularity run
65 |
--------------------------------------------------------------------------------
/deps/environment.yml:
--------------------------------------------------------------------------------
1 | name: project-brc
2 | channels:
3 | - defaults
4 | dependencies:
5 | - python=3.9
6 | - pip
7 | - numpy
8 | - scipy
9 | - h5py
10 | - matplotlib
11 | - scikit-learn
12 | - jupyter
13 | - tqdm
14 | - seaborn
15 | - mesalib
16 | - glew
17 | - glfw
18 | - Cython
19 | - conda-forge::opencv
20 | - conda-forge::jax=0.4.14
21 | - conda-forge::jaxlib=0.4.14=*cuda*
22 | - pip:
23 | - -r requirements.txt
--------------------------------------------------------------------------------
/deps/requirements.txt:
--------------------------------------------------------------------------------
1 | numba
2 | opt-einsum
3 | numpy
4 | absl-py
5 | termcolor
6 | matplotlib
7 | mujoco
8 | mujoco-py
9 | ml-collections
10 | cython<3
11 | wandb
12 | imageio
13 | moviepy
14 | opensimplex
15 | pygame
16 | libtmux
17 | threadpoolctl==3.1.0
18 | plotly
19 |
20 | tensorflow-probability==0.19.0
21 | d4rl==1.1
22 | dm_control==1.0.15
23 | dm-env==1.6
24 | dm-tree==0.1.8
25 | gym==0.23.1
26 |
27 | flax==0.7.4
28 | optax==0.1.7
29 | orbax==0.1.9
30 | distrax==0.1.4
31 | chex==0.1.82
--------------------------------------------------------------------------------
/experiment/ant_helper.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
4 |
5 | def get_canvas_image(canvas):
6 | canvas.draw()
7 | out_image = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')
8 | out_image = out_image.reshape(canvas.get_width_height()[::-1] + (3,))
9 | return out_image
10 |
11 | def value_image(env, dataset, value_fn, mask, clip=False):
12 | """
13 | Visualize the value function.
14 | Args:
15 | env: The environment.
16 | value_fn: a function with signature value_fn([# states, state_dim]) -> [#states, 1]
17 | Returns:
18 | A numpy array of the image.
19 | """
20 | fig = plt.figure(tight_layout=True)
21 | canvas = FigureCanvas(fig)
22 | plot_value(env, dataset, value_fn, mask, fig, plt.gca(), clip=clip)
23 | image = get_canvas_image(canvas)
24 | plt.close(fig)
25 | return image
26 |
27 | def plot_value(env, dataset, value_fn, mask, fig, ax, title=None, clip=True):
28 | N = 14
29 | M = 20
30 | ob_xy = env.XY(n=N, m=M)
31 |
32 | base_observation = np.copy(dataset['observations'][0])
33 | base_observations = np.tile(base_observation, (5, ob_xy.shape[0], 1))
34 | base_observations[:, :, :2] = ob_xy
35 | base_observations[:, :, 15:17] = 0.0
36 | base_observations[0, :, 15] = 1.0
37 | base_observations[1, :, 16] = 1.0
38 | base_observations[2, :, 15] = -1.0
39 | base_observations[3, :, 16] = -1.0
40 | print("Base observations, ", base_observations.shape)
41 |
42 |
43 | values = []
44 | for i in range(5):
45 | values.append(value_fn(base_observations[i]))
46 | values = np.stack(values, axis=0)
47 | print("Values", values.shape)
48 |
49 | x, y = ob_xy[:, 0], ob_xy[:, 1]
50 | x = x.reshape(N, M)
51 | y = y.reshape(N, M) * 0.975 + 0.7
52 | values = values.reshape(5, N, M)
53 | values[-1, 10, 0] = np.min(values[-1]) + 0.3 # Hack to make the scaling not show small errors.
54 | print("Clip:", clip)
55 | if clip:
56 | mesh = ax.pcolormesh(x, y, values[-1], cmap='viridis', vmin=-0.1, vmax=1.0)
57 | else:
58 | mesh = ax.pcolormesh(x, y, values[-1], cmap='viridis')
59 |
60 | v = (values[1] - values[3]) / 2
61 | u = (values[0] - values[2]) / 2
62 | uv_dist = np.sqrt(u**2 + v**2) + 1e-6
63 | # Normalize u,v
64 | un = u / uv_dist
65 | vn = v / uv_dist
66 | un[uv_dist < 0.1] = 0
67 | vn[uv_dist < 0.1] = 0
68 |
69 | plt.quiver(x, y, un, vn, color='r', pivot='mid', scale=0.75, scale_units='xy')
70 |
71 | if mask is not None and type(mask) == np.ndarray:
72 | # mask = NxM array of things to unmask.
73 | from matplotlib.colors import LinearSegmentedColormap
74 | colors = [(0,0,0,c) for c in np.linspace(0,1,100)]
75 | cmapred = LinearSegmentedColormap.from_list('mycmap', colors, N=5)
76 | mask_mesh_ax = ax.pcolormesh(x, y, mask, cmap=cmapred)
77 | elif mask is not None and type(mask) is list:
78 | maskmesh = np.ones((N, M))
79 | for xy in mask:
80 | for xi in range(N):
81 | for yi in range(M):
82 | if np.linalg.norm(np.array(xy) - np.array([x[xi, yi], y[xi, yi]])) < 1.4:
83 | # print(xy, x[xi, yi], y[xi, yi])
84 | maskmesh[xi,yi] = 0
85 | from matplotlib.colors import LinearSegmentedColormap
86 | colors = [(0,0,0,c) for c in np.linspace(0,1,100)]
87 | cmapred = LinearSegmentedColormap.from_list('mycmap', colors, N=5)
88 | mask_mesh_ax = ax.pcolormesh(x, y, maskmesh, cmap=cmapred)
89 |
90 | env.draw(ax, scale=0.95)
91 |
92 |
93 |
94 | # env.draw(ax, scale=1.0)
95 |
96 | # divider = make_axes_locatable(ax)
97 | # cax = divider.append_axes('right', size='5%', pad=0.05)
98 | # fig.colorbar(mesh, cax=cax, orientation='vertical')
99 |
100 | if title:
101 | ax.set_title(title)
--------------------------------------------------------------------------------
/experiment/rewards_eval.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tqdm
3 | import opensimplex
4 | import jax
5 | import jax.numpy as jnp
6 | from functools import partial
7 |
8 | from fre.experiment.rewards_unsupervised import RewardFunction
9 |
10 |
11 | class VelocityRewardFunction(RewardFunction):
12 | def __init__(self):
13 | pass
14 |
15 | # Select an XY velocity from a future state in the trajectory.
16 | def generate_params_and_pairs(self, traj_states, random_states, random_states_decode):
17 | batch_size = traj_states.shape[0]
18 | selected_traj_state_idx = np.random.randint(traj_states.shape[1], size=(batch_size,))
19 | selected_traj_state = traj_states[np.arange(batch_size), selected_traj_state_idx] # (batch_size, obs_dim)
20 | params = selected_traj_state[:, 15:17] # (batch_size, 2)
21 | params[:batch_size//4] = np.random.uniform(-1, 1, size=(batch_size//4, 2)) # Randomize 25% of the time.
22 | params = params / np.linalg.norm(params, axis=-1, keepdims=True) # Normalize XY
23 |
24 | encode_pairs = np.concatenate([traj_states, random_states], axis=1)
25 | encode_rewards = self.compute_reward(encode_pairs, params[:, None, :])[:, :, None]
26 | encode_pairs = np.concatenate([encode_pairs, encode_rewards], axis=-1)
27 |
28 | decode_pairs = random_states_decode
29 | decode_rewards = self.compute_reward(decode_pairs, params[:, None, :])[:, :, None]
30 | decode_pairs = np.concatenate([random_states_decode, decode_rewards], axis=-1)
31 |
32 | rewards = encode_pairs[:, 0, -1] # (batch_size,)
33 | masks = np.ones_like(rewards) # (batch_size,)
34 |
35 | return params, encode_pairs, decode_pairs, rewards, masks
36 |
37 | def compute_reward(self, states, params):
38 | assert len(states.shape) == len(params.shape), states.shape # (batch_size, obs_dim) OR (batch_size, num_pairs, obs_dim)
39 | xy_vels = states[..., 15:17] * 0.33820298
40 | return np.sum(xy_vels * params, axis=-1) # (batch_size,)
41 |
42 | def make_encoder_pairs_testing(self, params, random_states):
43 | assert len(params.shape) == 2, params.shape # (batch_size, 2)
44 | assert len(random_states.shape) == 3, random_states.shape # (batch_size, num_pairs, obs_dim)
45 |
46 | reward_pair_rews = self.compute_reward(random_states, params[:, None, :])[..., None]
47 | reward_pairs = np.concatenate([random_states, reward_pair_rews], axis=-1)
48 | return reward_pairs # (batch_size, reward_pairs, obs_dim + 1)
49 |
50 | class TestRewMatrix(RewardFunction):
51 | def __init__(self):
52 | self.pos = np.zeros((36, 24))
53 | self.xvel = np.zeros((36, 24))
54 | self.yvel = np.zeros((36, 24))
55 |
56 | def generate_params_and_pairs(self, traj_states, random_states, random_states_decode):
57 | batch_size = traj_states.shape[0]
58 | params = np.zeros((batch_size, 1)) # (batch_size, 1)
59 |
60 | encode_pairs = np.concatenate([traj_states, random_states], axis=1)
61 | encode_rewards = self.compute_reward(encode_pairs, params[:, None, :])[:, :, None]
62 | encode_pairs = np.concatenate([encode_pairs, encode_rewards], axis=-1)
63 |
64 | decode_pairs = random_states_decode
65 | decode_rewards = self.compute_reward(decode_pairs, params[:, None, :])[:, :, None]
66 | decode_pairs = np.concatenate([random_states_decode, decode_rewards], axis=-1)
67 |
68 | rewards = encode_pairs[:, 0, -1] # (batch_size,)
69 | masks = np.ones_like(rewards) # (batch_size,)
70 |
71 | return params, encode_pairs, decode_pairs, rewards, masks
72 |
73 | def compute_reward(self, s, params):
74 | rews = np.zeros_like(s[..., 0]) # (batch, examples)
75 | # XY Vel Reward
76 | xy_vels = s[..., 15:17] * 0.33820298
77 |
78 | x = s[..., 0].astype(int).clip(0, 35)
79 | y = s[..., 1].astype(int).clip(0, 23)
80 | simplex = self.pos[x, y]
81 | simplex_xvel = self.xvel[x, y]
82 | simplex_yvel = self.yvel[x, y]
83 | rews = (simplex > 0.3).astype(float) * 0.5
84 | rews += xy_vels[...,0] * simplex_xvel + xy_vels[...,1] * simplex_yvel
85 |
86 | return rews
87 |
88 | def make_encoder_pairs_testing(self, params, random_states):
89 |         assert len(params.shape) == 2, params.shape # (batch_size, param_dim)
90 | assert len(random_states.shape) == 3, random_states.shape # (batch_size, num_pairs, obs_dim)
91 | batch_size = random_states.shape[0]
92 |
93 | # TODO: Be smarter about the states to use here.
94 |
95 | reward_pair_rews = self.compute_reward(random_states, params[:, None, :])[..., None]
96 | reward_pairs = np.concatenate([random_states, reward_pair_rews], axis=-1)
97 | return reward_pairs # (batch_size, reward_pairs, obs_dim + 1)
98 |
99 | class SimplexRewardFunction(RewardFunction):
100 | def __init__(self, num_simplex):
101 | self.simplex_size = num_simplex
102 | self.simplex_seeds_pos = np.zeros((self.simplex_size, 36, 24))
103 | self.simplex_seeds_xvel = np.zeros((self.simplex_size, 36, 24))
104 | self.simplex_seeds_yvel = np.zeros((self.simplex_size, 36, 24))
105 | self.simplex_best_xy = np.zeros((self.simplex_size, 10, 2))
106 | print("Generating simplex seeds")
107 | xi = np.arange(36)
108 | yi = np.arange(24)
109 | for r in tqdm.tqdm(range(self.simplex_size)):
110 | opensimplex.seed(r)
111 | self.simplex_seeds_pos[r] = opensimplex.noise2array(x=xi/20.0, y=yi/20.0).T
112 | opensimplex.seed(r + self.simplex_size)
113 | self.simplex_seeds_xvel[r] = opensimplex.noise2array(x=xi/20.0, y=yi/20.0).T
114 | opensimplex.seed(r + self.simplex_size * 2)
115 | self.simplex_seeds_yvel[r] = opensimplex.noise2array(x=xi/20.0, y=yi/20.0).T
116 |
117 | best_topn = np.argpartition(self.simplex_seeds_pos[r].flatten(), -10)[-10:] # (10,)
118 | best_xy = np.array(np.unravel_index(best_topn, self.simplex_seeds_pos[r].shape)).T # (10, 2)
119 | self.simplex_best_xy[r] = best_xy
120 | self.simplex_seeds_xvel[np.abs(self.simplex_seeds_xvel) < 0.5] = 0
121 | self.simplex_seeds_yvel[np.abs(self.simplex_seeds_yvel) < 0.5] = 0
122 |
123 | # Select an XY velocity from a future state in the trajectory.
124 | def generate_params_and_pairs(self, traj_states, random_states, random_states_decode):
125 | batch_size = traj_states.shape[0]
126 | params = np.random.randint(self.simplex_size, size=(batch_size, 1)) # (batch_size, 1)
127 |
128 | encode_pairs = np.concatenate([traj_states, random_states], axis=1)
129 | encode_rewards = self.compute_reward(encode_pairs, params[:, None, :])[:, :, None]
130 | encode_pairs = np.concatenate([encode_pairs, encode_rewards], axis=-1)
131 |
132 | decode_pairs = random_states_decode
133 | decode_rewards = self.compute_reward(decode_pairs, params[:, None, :])[:, :, None]
134 | decode_pairs = np.concatenate([random_states_decode, decode_rewards], axis=-1)
135 |
136 | rewards = encode_pairs[:, 0, -1] # (batch_size,)
137 | masks = np.ones_like(rewards) # (batch_size,)
138 |
139 | return params, encode_pairs, decode_pairs, rewards, masks
140 |
141 | def compute_reward(self, states, params):
142 | assert len(states.shape) == len(params.shape), states.shape # (batch_size, obs_dim) OR (batch_size, num_pairs, obs_dim)
143 |
144 | simplex_id = params[..., 0].astype(int)
145 | x = states[..., 0].astype(int).clip(0, 35)
146 | y = states[..., 1].astype(int).clip(0, 23)
147 | simplex = self.simplex_seeds_pos[simplex_id, x, y]
148 | simplex_xvel = self.simplex_seeds_xvel[simplex_id, x, y]
149 | simplex_yvel = self.simplex_seeds_yvel[simplex_id, x, y]
150 | rews = -1 + (simplex > 0.3).astype(float) * 0.5
151 | xy_vels = states[..., 15:17] * 0.33820298
152 | rews += xy_vels[...,0] * simplex_xvel + xy_vels[...,1] * simplex_yvel
153 | return rews # (batch_size,)
154 |
155 | def make_encoder_pairs_testing(self, params, random_states):
156 |         assert len(params.shape) == 2, params.shape # (batch_size, param_dim)
157 | assert len(random_states.shape) == 3, random_states.shape # (batch_size, num_pairs, obs_dim)
158 | batch_size = random_states.shape[0]
159 |
160 | # For simplex rewards, make sure to include the top 4 best points.
161 | simplex_id = params[..., 0].astype(int)
162 | random_best_4 = np.random.randint(0, 10, size=(batch_size, 4))
163 | random_states[:, :4, :2] = self.simplex_best_xy[simplex_id[:, None], random_best_4, :]
164 |
165 | reward_pair_rews = self.compute_reward(random_states, params[:, None, :])[..., None]
166 | reward_pairs = np.concatenate([random_states, reward_pair_rews], axis=-1)
167 | return reward_pairs # (batch_size, reward_pairs, obs_dim + 1)
168 |
169 | class TestRewMatrixEdges(TestRewMatrix):
170 | def __init__(self):
171 | super().__init__()
172 | self.pos[:3, :] = 1
173 | self.pos[-3:, :] = 1
174 | self.pos[:, :3] = 1
175 | self.pos[:, -3:] = 1
176 |
177 | class TestRewLoop(TestRewMatrix):
178 | def __init__(self):
179 | super().__init__()
180 | self.pos[22:33, 14:18] = 1
181 | self.xvel[22:33, 14:18] = -1
182 |
183 | self.pos[21:, 0:3] = 1
184 | self.xvel[21:, 0:3] = 1
185 |
186 | self.pos[33:, 3:18] = 1
187 | self.yvel[33:, 3:18] = 1
188 |
189 | self.pos[18:21, 0:7] = 1
190 | self.yvel[18:21, 0:7] = -1
191 |
192 | class TestRewPath(TestRewMatrix):
193 | def __init__(self):
194 | super().__init__()
195 | self.pos[3:21, 7:10] = 1
196 | self.xvel[3:21, 7:10] = -1
197 |
198 | self.pos[0:3, 3:10] = 1
199 | self.yvel[0:3, 3:10] = -1
200 |
201 | self.pos[0:18, 0:3] = 1
202 | self.xvel[0:18, 0:3] = 1
203 |
204 | class TestRewLoop2(TestRewMatrix):
205 | def __init__(self):
206 | super().__init__()
207 | self.pos[22:33, 14:18] = 1
208 | self.pos[21:, 0:3] = 1
209 | self.pos[33:, 3:18] = 1
210 | self.pos[18:21, 0:7] = 1
211 |
212 | class TestRewPath2(TestRewMatrix):
213 | def __init__(self):
214 | super().__init__()
215 | self.pos[3:21, 7:10] = 1
216 | self.pos[0:3, 3:10] = 1
217 | self.pos[0:18, 0:3] = 1
218 |
219 |
220 | # =================== For DMC
221 |
222 | class VelocityRewardFunctionWalker(RewardFunction):
223 | def __init__(self):
224 | pass
225 |
226 | def generate_params_and_pairs(self, traj_states, random_states, random_states_decode):
227 | batch_size = traj_states.shape[0]
228 | params = np.random.uniform(low=0, high=8, size=(batch_size, 1)) # (batch_size, 1)
229 |
230 | encode_pairs = np.concatenate([traj_states, random_states], axis=1)
231 | encode_rewards = self.compute_reward(encode_pairs, params[:, None, :])[:, :, None]
232 | encode_pairs = np.concatenate([encode_pairs, encode_rewards], axis=-1)
233 |
234 | decode_pairs = random_states_decode
235 | decode_rewards = self.compute_reward(decode_pairs, params[:, None, :])[:, :, None]
236 | decode_pairs = np.concatenate([random_states_decode, decode_rewards], axis=-1)
237 |
238 | rewards = encode_pairs[:, 0, -1] # (batch_size,)
239 | masks = np.ones_like(rewards) # (batch_size,)
240 |
241 | return params, encode_pairs, decode_pairs, rewards, masks
242 |
243 | def _sigmoids(self, x, value_at_1, sigmoid):
244 | if sigmoid == 'gaussian':
245 | scale = np.sqrt(-2 * np.log(value_at_1))
246 | return np.exp(-0.5 * (x*scale)**2)
247 |
248 | elif sigmoid == 'linear':
249 | scale = 1-value_at_1
250 | scaled_x = x*scale
251 | return np.where(abs(scaled_x) < 1, 1 - scaled_x, 0.0)
252 |
253 | def tolerance(self, x, lower, upper, margin=0.0, sigmoid='gaussian', value_at_margin=0.1):
254 | in_bounds = np.logical_and(lower <= x, x <= upper)
255 | d = np.where(x < lower, lower - x, x - upper) / margin
256 | value = np.where(in_bounds, 1.0, self._sigmoids(d, value_at_margin, sigmoid))
257 | return value
258 |
259 | def compute_reward(self, states, params):
260 | assert len(states.shape) == len(params.shape), states.shape # (batch_size, obs_dim) OR (batch_size, num_pairs, obs_dim)
261 |
262 | _STAND_HEIGHT = 1.2
263 | horizontal_velocity = states[..., 24:25]
264 | torso_upright = states[..., 25:26]
265 | torso_height = states[..., 26:27]
266 | standing = self.tolerance(torso_height, lower=_STAND_HEIGHT, upper=float('inf'), margin=_STAND_HEIGHT/2)
267 | upright = (1 + torso_upright) / 2
268 | stand_reward = (3*standing + upright) / 4
269 | move_reward = self.tolerance(horizontal_velocity,
270 | lower=params,
271 | upper=float('inf'),
272 | margin=params/2,
273 | value_at_margin=0.5,
274 | sigmoid='linear')
275 | # move_reward[params == 0] = stand_reward[params == 0]
276 | rew = stand_reward * (5*move_reward + 1) / 6
277 | return rew[..., 0]
278 |
279 | def make_encoder_pairs_testing(self, params, random_states):
280 | assert len(params.shape) == 2, params.shape # (batch_size, 2)
281 | assert len(random_states.shape) == 3, random_states.shape # (batch_size, num_pairs, obs_dim)
282 |
283 | reward_pair_rews = self.compute_reward(random_states, params[:, None, :])[..., None]
284 | reward_pairs = np.concatenate([random_states, reward_pair_rews], axis=-1)
285 | return reward_pairs # (batch_size, reward_pairs, obs_dim + 1)
286 |
287 | class VelocityRewardFunctionCheetah(RewardFunction):
288 | def __init__(self):
289 | pass
290 |
291 | def generate_params_and_pairs(self, traj_states, random_states, random_states_decode):
292 | batch_size = traj_states.shape[0]
293 | params = np.random.uniform(low=-10, high=10, size=(batch_size, 1)) # (batch_size, 1)
294 |
295 | encode_pairs = np.concatenate([traj_states, random_states], axis=1)
296 | encode_rewards = self.compute_reward(encode_pairs, params[:, None, :])[:, :, None]
297 | encode_pairs = np.concatenate([encode_pairs, encode_rewards], axis=-1)
298 |
299 | decode_pairs = random_states_decode
300 | decode_rewards = self.compute_reward(decode_pairs, params[:, None, :])[:, :, None]
301 | decode_pairs = np.concatenate([random_states_decode, decode_rewards], axis=-1)
302 |
303 | rewards = encode_pairs[:, 0, -1] # (batch_size,)
304 | masks = np.ones_like(rewards) # (batch_size,)
305 |
306 | return params, encode_pairs, decode_pairs, rewards, masks
307 |
308 | def _sigmoids(self, x, value_at_1, sigmoid):
309 | if sigmoid == 'linear':
310 | scale = 1-value_at_1
311 | scaled_x = x*scale
312 | return np.where(abs(scaled_x) < 1, 1 - scaled_x, 0.0)
313 | else:
314 | raise NotImplementedError
315 |
316 | def tolerance(self, x, lower, upper, margin=0.0, sigmoid='linear', value_at_margin=0):
317 | in_bounds = np.logical_and(lower <= x, x <= upper)
318 | d = np.where(x < lower, lower - x, x - upper) / margin
319 | value = np.where(in_bounds, 1.0, self._sigmoids(d, value_at_margin, sigmoid))
320 | return value
321 |
322 | def compute_reward(self, states, params):
323 | assert len(states.shape) == len(params.shape), states.shape # (batch_size, obs_dim) OR (batch_size, num_pairs, obs_dim)
324 |
325 | horizontal_velocity = states[..., 17:18]
326 | sign_of_param = np.sign(params)
327 | horizontal_velocity = horizontal_velocity * sign_of_param
328 | rew = self.tolerance(horizontal_velocity,
329 | lower=np.abs(params),
330 | upper=float('inf'),
331 | margin=np.abs(params),
332 | value_at_margin=0,
333 | sigmoid='linear')
334 | return rew[..., 0]
335 |
336 | def make_encoder_pairs_testing(self, params, random_states):
337 | assert len(params.shape) == 2, params.shape # (batch_size, 2)
338 | assert len(random_states.shape) == 3, random_states.shape # (batch_size, num_pairs, obs_dim)
339 |
340 | reward_pair_rews = self.compute_reward(random_states, params[:, None, :])[..., None]
341 | reward_pairs = np.concatenate([random_states, reward_pair_rews], axis=-1)
342 | return reward_pairs # (batch_size, reward_pairs, obs_dim + 1)
343 |
344 | # =================== For Kitchen
345 |
346 | class SingleTaskRewardFunction(RewardFunction):
347 | def __init__(self):
348 | self.obs_element_indices = {
349 | "bottom left burner": np.array([11, 12]),
350 | "top left burner": np.array([15, 16]),
351 | "light switch": np.array([17, 18]),
352 | "slide cabinet": np.array([19]),
353 | "hinge cabinet": np.array([20, 21]),
354 | "microwave": np.array([22]),
355 | "kettle": np.array([23, 24, 25, 26, 27, 28, 29]),
356 | }
357 | self.obs_element_goals = {
358 | "bottom left burner": np.array([-0.88, -0.01]),
359 | "top left burner": np.array([-0.92, -0.01]),
360 | "light switch": np.array([-0.69, -0.05]),
361 | "slide cabinet": np.array([0.37]),
362 | "hinge cabinet": np.array([0.0, 1.45]),
363 | "microwave": np.array([-0.75]),
364 | "kettle": np.array([-0.23, 0.75, 1.62, 0.99, 0.0, 0.0, -0.06]),
365 | }
366 | self.dist_thresh = 0.3
367 | self.num_tasks = len(self.obs_element_indices)
368 |
369 | def generate_params_and_pairs(self, traj_states, random_states, random_states_decode):
370 | batch_size = traj_states.shape[0]
371 | params = np.random.randint(self.num_tasks, size=(batch_size, 1)) # (batch_size, 1)
372 | params = np.eye(self.num_tasks)[params[:, 0]] # (batch_size, num_tasks)
373 |
374 | encode_pairs = np.concatenate([traj_states, random_states], axis=1)
375 | encode_rewards = self.compute_reward(encode_pairs, params[:, None, :])[:, :, None]
376 | encode_pairs = np.concatenate([encode_pairs, encode_rewards], axis=-1)
377 |
378 | decode_pairs = random_states_decode
379 | decode_rewards = self.compute_reward(decode_pairs, params[:, None, :])[:, :, None]
380 | decode_pairs = np.concatenate([random_states_decode, decode_rewards], axis=-1)
381 |
382 | rewards = encode_pairs[:, 0, -1] # (batch_size,)
383 | masks = np.ones_like(rewards) # (batch_size,)
384 |
385 | return params, encode_pairs, decode_pairs, rewards, masks
386 |
387 | def compute_reward(self, states, params):
388 | assert len(states.shape) == len(params.shape), states.shape # (batch_size, obs_dim) OR (batch_size, num_pairs, obs_dim)
389 | task_rewards = []
390 | for task, target_indices in self.obs_element_indices.items():
391 | task_dists = np.linalg.norm(states[..., target_indices] - self.obs_element_goals[task], axis=-1)
392 | task_completes = (task_dists < self.dist_thresh).astype(float)
393 | task_rewards.append(task_completes)
394 | task_rewards = np.stack(task_rewards, axis=-1)
395 |
396 | return np.sum(task_rewards * params, axis=-1)
397 |
398 | def make_encoder_pairs_testing(self, params, random_states):
399 |         assert len(params.shape) == 2, params.shape # (batch_size, num_tasks)
400 | assert len(random_states.shape) == 3, random_states.shape # (batch_size, num_pairs, obs_dim)
401 |
402 | reward_pair_rews = self.compute_reward(random_states, params[:, None, :])[..., None]
403 | reward_pairs = np.concatenate([random_states, reward_pair_rews], axis=-1)
404 | return reward_pairs # (batch_size, reward_pairs, obs_dim + 1)
405 |
406 |
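407 | # --------------------------------------------------------------------------
408 | # Hedged usage sketch (illustrative only; not part of the original evaluation
409 | # flow). It shows the shape contract of SingleTaskRewardFunction.compute_reward:
410 | # `params` is a one-hot task vector, and the reward is 1.0 exactly when the
411 | # selected kitchen element sits within `dist_thresh` of its goal configuration.
412 | # The 30-dim observation and the microwave joint at index 22 follow the element
413 | # indices defined above; every other number below is a placeholder.
414 | if __name__ == '__main__':
415 |     rf = SingleTaskRewardFunction()
416 |     params = np.eye(rf.num_tasks)[[5]]        # one-hot "microwave" task, shape (1, 7)
417 |     states = np.zeros((1, 30))                # dummy kitchen observation
418 |     states[:, 22] = -0.75                     # microwave joint at its goal value
419 |     print(rf.compute_reward(states, params))  # -> [1.]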
--------------------------------------------------------------------------------
/experiment/rewards_unsupervised.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tqdm
3 | import opensimplex
4 | import jax
5 | import jax.numpy as jnp
6 | from functools import partial
7 |
8 | class RewardFunction():
9 |     # Given a batch of trajectory states, random states, and random decoder states, generate a reward function.
10 |     # Return the labelled state-reward pairs, shaped (batch_size, num_pairs, obs_dim + 1).
11 |     def generate_params_and_pairs(self, traj_states, random_states, random_states_decode):
12 | raise NotImplementedError
13 |
14 | # Given a batch of states and a batch of parameters, compute the reward.
15 | def compute_reward(self, states, params):
16 | raise NotImplementedError
17 |
18 | class GoalReachingRewardFunction(RewardFunction):
19 | def __init__(self):
20 | self.p_current = 0.2
21 | self.p_trajectory = 0.5
22 | self.p_random = 0.3
23 |
24 | # TODO: If this is slow, we can try and JIT it.
25 | # Select a random goal from the provided states.
26 | def generate_params_and_pairs(self, traj_states, random_states, random_states_decode):
27 | all_states = np.concatenate([traj_states, random_states], axis=1)
28 | batch_size = all_states.shape[0]
29 | p_trajectory_normalized = self.p_trajectory / traj_states.shape[1]
30 | p_random_normalized = self.p_random / random_states.shape[1]
31 | probabilities = [self.p_current] + [p_trajectory_normalized] * (traj_states.shape[1]-1) \
32 | + [p_random_normalized] * random_states.shape[1]
33 | probabilities = np.array(probabilities) / np.sum(probabilities)
34 | selected_goal_idx = np.random.choice(len(probabilities), size=(batch_size,), p=probabilities)
35 | selected_goal = all_states[np.arange(batch_size), selected_goal_idx]
36 |
37 | params = selected_goal # (batch_size, obs_dim)
38 | encode_pairs = np.concatenate([traj_states, random_states], axis=1)
39 | encode_rewards = self.compute_reward(encode_pairs, params[:, None, :])[:, :, None]
40 | encode_pairs = np.concatenate([encode_pairs, encode_rewards], axis=-1)
41 |
42 | decode_pairs = random_states_decode
43 | decode_pairs[:, 0] = params # Decode the goal state too.
44 | decode_rewards = self.compute_reward(decode_pairs, params[:, None, :])[:, :, None]
45 | decode_pairs = np.concatenate([random_states_decode, decode_rewards], axis=-1)
46 |
47 | rewards = encode_pairs[:, 0, -1] # (batch_size,)
48 |         masks = -rewards # rew=-1 (goal not reached) -> mask=1; rew=0 (goal reached) -> mask=0
49 |
50 | return params, encode_pairs, decode_pairs, rewards, masks
51 |
52 | def compute_reward(self, states, params, delta=False):
53 | assert len(states.shape) == len(params.shape), states.shape # (batch_size, obs_dim) OR (batch_size, num_pairs, obs_dim)
54 | if states.shape[-1] == 29: # AntMaze
55 | if delta:
56 | dists = np.linalg.norm(states - params, axis=-1)
57 | is_goal = (dists < 0.1)
58 | else:
59 | dists = np.linalg.norm(states[..., :2] - params[..., :2], axis=-1)
60 | is_goal = (dists < 2)
61 | return -1 + is_goal.astype(float) # (batch_size,)
62 | elif states.shape[-1] == 18: # Cheetah
63 | std = np.array([[0.4407440506721877, 10.070289916801876, 0.5172332956856273, 0.5601041145815341, 0.518947027289748, 0.3204431592542281, 0.5501848643154092, 0.3856393812067661, 1.9882502334402663, 1.6377168569884073, 4.308505013609855, 12.144181770553105, 13.537567521831702, 16.88983033626308, 7.715009572436841, 14.345667964212357, 10.6904255152284, 100]])
64 | assert len(states.shape) == len(params.shape), states.shape # (batch_size, obs_dim) OR (batch_size, num_pairs, obs_dim)
65 | # if len(states.shape) == 3:
66 | # breakpoint()
67 | dists_per_dim = states - params
68 | dists_per_dim = dists_per_dim / std
69 | dists = np.linalg.norm(dists_per_dim, axis=-1) / states.shape[-1]
70 | is_goal = (dists < 0.08)
71 | # print(dists_per_dim)
72 | # print(dists, is_goal)
73 | return -1 + is_goal.astype(float) # (batch_size,)
74 | elif states.shape[-1] == 27: # Walker
75 | std = np.array([[0.7212967364054736, 0.6775020895964047, 0.7638155887842976, 0.6395721376821286, 0.6849394775886244, 0.7078581708129903, 0.7113168519036742, 0.6753408522523937, 0.6818095329625652, 0.7133958718133511, 0.65227578338642, 0.757622576816855, 0.7311826446274479, 0.6745824928740024, 0.36822491550384456, 2.1134839667805805, 1.813353841099317, 10.594648894374815, 17.41041469033713, 17.836743227082106, 22.399097178637533, 16.1492222730888, 15.693574546557201, 18.539929326905067, 100, 100, 100]])
76 | assert len(states.shape) == len(params.shape), states.shape # (batch_size, obs_dim) OR (batch_size, num_pairs, obs_dim)
77 | dists_per_dim = states - params
78 | dists_per_dim = dists_per_dim / std
79 | dists = np.linalg.norm(dists_per_dim, axis=-1) / states.shape[-1]
80 | is_goal = (dists < 0.2)
81 |             return -1 + is_goal.astype(float) # (batch_size,)
82 | elif states.shape[-1] == 30: # Kitchen
83 | dists_per_dim = states - params
84 |             # No per-dimension normalization for the kitchen observations.
85 | dists = np.linalg.norm(dists_per_dim, axis=-1) / states.shape[-1]
86 | is_goal = (dists < 1e-6)
87 | return -1 + is_goal.astype(float)
88 | else:
89 | raise NotImplementedError
90 |
91 | def make_encoder_pairs_testing(self, params, random_states):
92 |         assert len(params.shape) == 2, params.shape # (batch_size, obs_dim)
93 | assert len(random_states.shape) == 3, random_states.shape # (batch_size, num_pairs, obs_dim)
94 |
95 | if random_states.shape[-1] == 29: # AntMaze
96 | random_states[:, 0, :2] = params[:, :2] # Make sure to include the goal.
97 | else:
98 | random_states[:, 0] = params
99 | reward_pair_rews = self.compute_reward(random_states, params[:, None, :])[..., None]
100 | reward_pairs = np.concatenate([random_states, reward_pair_rews], axis=-1)
101 | return reward_pairs # (batch_size, reward_pairs, obs_dim + 1)
102 |
103 | class LinearRewardFunction(RewardFunction):
104 | def __init__(self):
105 | pass
106 |
107 | # Randomly generate a linear weighting over state features.
108 | def generate_params_and_pairs(self, traj_states, random_states, random_states_decode):
109 | assert len(traj_states.shape) == 3, traj_states.shape # (batch_size, traj_len, obs_dim)
110 | batch_size = traj_states.shape[0]
111 | state_len = traj_states.shape[-1]
112 |
113 | params = np.random.uniform(-1, 1, size=(batch_size, state_len)) # Uniform weighting.
114 | random_mask = np.random.uniform(size=(batch_size,state_len)) < 0.9
115 | if state_len == 29:
116 | random_mask[:, :2] = True # Zero out the XY position for antmaze.
117 | random_mask_positive = np.random.randint(2, state_len, size=(batch_size))
118 |             random_mask[np.arange(batch_size), random_mask_positive] = False # Force at least one non-zero weight.
119 | params[random_mask] = 0 # Zero out some of the weights.
120 | # if state_len == 29:
121 | # params = params / np.linalg.norm(params, axis=-1, keepdims=True) # Normalize XY
122 |
123 |         # Remove auxiliary features during training.
124 | if state_len == 27:
125 | params[:, -3:] = 0
126 | if state_len == 18:
127 | params[:, -1:] = 0
128 |
129 | clip_bit = np.random.uniform(size=(batch_size,)) < 0.5
130 | params = np.concatenate([params, clip_bit[:, None]], axis=-1) # (batch_size, obs_dim + 1)
131 |
132 | encode_pairs = np.concatenate([traj_states, random_states], axis=1)
133 | encode_rewards = self.compute_reward(encode_pairs, params[:, None, :])[:, :, None]
134 | encode_pairs = np.concatenate([encode_pairs, encode_rewards], axis=-1)
135 |
136 | decode_pairs = random_states_decode
137 | decode_rewards = self.compute_reward(decode_pairs, params[:, None, :])[:, :, None]
138 | decode_pairs = np.concatenate([random_states_decode, decode_rewards], axis=-1)
139 |
140 | rewards = encode_pairs[:, 0, -1] # (batch_size,)
141 | masks = np.ones_like(rewards) # (batch_size,)
142 |
143 | return params, encode_pairs, decode_pairs, rewards, masks
144 |
145 | def compute_reward(self, states, params):
146 | params_raw = params[..., :-1]
147 | assert len(states.shape) == len(params_raw.shape), states.shape # (batch_size, obs_dim) OR (batch_size, num_pairs, obs_dim)
148 | r = np.sum(states * params_raw, axis=-1) # (batch_size,)
149 | r = np.where(params[..., -1] > 0, np.clip(r, 0, 1), np.clip(r, -1, 1))
150 | return r
151 |
152 | def make_encoder_pairs_testing(self, params, random_states):
153 |         assert len(params.shape) == 2, params.shape # (batch_size, obs_dim + 1)
154 | assert len(random_states.shape) == 3, random_states.shape # (batch_size, num_pairs, obs_dim)
155 |
156 | reward_pair_rews = self.compute_reward(random_states, params[:, None, :])[..., None]
157 | reward_pairs = np.concatenate([random_states, reward_pair_rews], axis=-1)
158 | return reward_pairs # (batch_size, reward_pairs, obs_dim + 1)
159 |
160 | class RandomRewardFunction(RewardFunction):
161 | def __init__(self, num_simplex, obs_len=29):
162 | # Pre-compute parameter matrices.
163 | print("Generating parameter matrices...")
164 | self.simplex_size = num_simplex
165 | np_random = np.random.RandomState(0)
166 | self.param_w1 = np_random.normal(size=(self.simplex_size, obs_len, 32)) * np.sqrt(1/32)
167 | self.param_b1 = np_random.normal(size=(self.simplex_size, 1, 32)) * np.sqrt(16)
168 | self.param_w2 = np_random.normal(size=(self.simplex_size, 32, 1)) * np.sqrt(1/16)
169 |
170 |         # Remove auxiliary features during training.
171 | if obs_len == 27:
172 | self.param_w1[:, -3:] = 0
173 | if obs_len == 18:
174 | self.param_w1[:, -1:] = 0
175 |
176 | # Random neural network.
177 | def generate_params_and_pairs(self, traj_states, random_states, random_states_decode):
178 | batch_size = traj_states.shape[0]
179 | params = np.random.randint(self.simplex_size, size=(batch_size, 1)) # (batch_size, 1)
180 |
181 | encode_pairs = np.concatenate([traj_states, random_states], axis=1)
182 | encode_rewards = self.compute_reward(encode_pairs, params[:, None, :])[:, :, None]
183 | encode_pairs = np.concatenate([encode_pairs, encode_rewards], axis=-1)
184 |
185 | decode_pairs = random_states_decode
186 | decode_rewards = self.compute_reward(decode_pairs, params[:, None, :])[:, :, None]
187 | decode_pairs = np.concatenate([random_states_decode, decode_rewards], axis=-1)
188 |
189 | rewards = encode_pairs[:, 0, -1] # (batch_size,)
190 | masks = np.ones_like(rewards) # (batch_size,)
191 |
192 | return params, encode_pairs, decode_pairs, rewards, masks
193 |
194 | def compute_reward(self, states, params):
195 | assert len(states.shape) == len(params.shape), states.shape # (batch_size, obs_dim) OR (batch_size, num_pairs, obs_dim)
196 |
197 | param_id = params[..., 0].astype(int)
198 | param1_w = self.param_w1[param_id]
199 | param1_b = self.param_b1[param_id]
200 | param2_w = self.param_w2[param_id]
201 |
202 | obs = states
203 | x = np.expand_dims(obs, -2) # [batch, (pairs), 1, features_in]
204 | x = np.matmul(x, param1_w) # [batch, (pairs), 1, features_out]
205 | x = x + param1_b
206 | x = np.tanh(x)
207 | x = np.matmul(x, param2_w) # [batch, (pairs), 1, 1]
208 | x = x.squeeze(-1).squeeze(-1) # [batch, (pairs)]
209 | x = np.clip(x, -1, 1)
210 | return x
211 |
212 | def make_encoder_pairs_testing(self, params, random_states):
213 | assert len(params.shape) == 2, params.shape # (batch_size, 2)
214 | assert len(random_states.shape) == 3, random_states.shape # (batch_size, num_pairs, obs_dim)
215 | batch_size = random_states.shape[0]
216 |
217 | reward_pair_rews = self.compute_reward(random_states, params[:, None, :])[..., None]
218 | reward_pairs = np.concatenate([random_states, reward_pair_rews], axis=-1)
219 | return reward_pairs # (batch_size, reward_pairs, obs_dim + 1)
220 |
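221 | # --------------------------------------------------------------------------
222 | # Hedged smoke-test sketch (illustrative only; not part of the original training
223 | # pipeline). It exercises the shape contract shared by the RewardFunction
224 | # subclasses: per-batch reward parameters, encoder/decoder state-reward pairs
225 | # with the reward appended as a final channel, and scalar rewards/masks.
226 | # The batch/trajectory/pair sizes and the 10-dim observation are arbitrary
227 | # placeholders chosen to avoid the AntMaze/Walker/Cheetah special cases.
228 | if __name__ == '__main__':
229 |     rf = LinearRewardFunction()
230 |     batch, traj_len, num_pairs, obs_dim = 4, 6, 8, 10
231 |     traj = np.random.randn(batch, traj_len, obs_dim)
232 |     rand_enc = np.random.randn(batch, num_pairs, obs_dim)
233 |     rand_dec = np.random.randn(batch, num_pairs, obs_dim)
234 |     params, enc, dec, rews, masks = rf.generate_params_and_pairs(traj, rand_enc, rand_dec)
235 |     print(params.shape, enc.shape, dec.shape, rews.shape, masks.shape)
236 |     # -> (4, 11) (4, 14, 11) (4, 8, 11) (4,) (4,)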
--------------------------------------------------------------------------------