├── .gitignore
├── agent
│   ├── ddpg.yaml
│   ├── becl.yaml
│   ├── becl.py
│   └── ddpg.py
├── dmc_benchmark.py
├── conda_env.yml
├── LICENSE
├── pretrain.yaml
├── custom_dmc_tasks
│   ├── __init__.py
│   ├── hopper.xml
│   ├── walker.xml
│   ├── cheetah.xml
│   ├── cheetah.py
│   ├── walker.py
│   ├── hopper.py
│   ├── jaco.py
│   ├── quadruped.xml
│   └── quadruped.py
├── finetune.yaml
├── README.md
├── video.py
├── logger.py
├── replay_buffer.py
├── pretrain.py
├── finetune.py
├── utils.py
└── dmc.py
/.gitignore:
--------------------------------------------------------------------------------
1 | /.DS_Store
--------------------------------------------------------------------------------
/agent/ddpg.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.ddpg.DDPGAgent
3 | name: ddpg
4 | reward_free: ${reward_free}
5 | obs_type: ??? # to be specified later
6 | obs_shape: ??? # to be specified later
7 | action_shape: ??? # to be specified later
8 | device: ${device}
9 | lr: 1e-4
10 | critic_target_tau: 0.01
11 | update_every_steps: 2
12 | use_tb: ${use_tb}
13 | use_wandb: ${use_wandb}
14 | num_expl_steps: ??? # to be specified later
15 | hidden_dim: 1024
16 | feature_dim: 50
17 | stddev_schedule: 0.2
18 | stddev_clip: 0.3
19 | nstep: 3
20 | batch_size: 1024 # 256 for pixels
21 | init_critic: true
22 | update_encoder: ${update_encoder}
--------------------------------------------------------------------------------
/dmc_benchmark.py:
--------------------------------------------------------------------------------
1 | DOMAINS = [
2 | 'walker',
3 | 'quadruped',
4 | 'jaco',
5 | ]
6 |
7 | WALKER_TASKS = [
8 | 'walker_stand',
9 | 'walker_walk',
10 | 'walker_run',
11 | 'walker_flip',
12 | ]
13 |
14 | QUADRUPED_TASKS = [
15 | 'quadruped_walk',
16 | 'quadruped_run',
17 | 'quadruped_stand',
18 | 'quadruped_jump',
19 | ]
20 |
21 | JACO_TASKS = [
22 | 'jaco_reach_top_left',
23 | 'jaco_reach_top_right',
24 | 'jaco_reach_bottom_left',
25 | 'jaco_reach_bottom_right',
26 | ]
27 |
28 | TASKS = WALKER_TASKS + QUADRUPED_TASKS + JACO_TASKS
29 |
30 | PRIMAL_TASKS = {
31 | 'walker': 'walker_stand',
32 | 'jaco': 'jaco_reach_top_left',
33 | 'quadruped': 'quadruped_walk'
34 | }
35 |
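# Usage sketch (assumption, not verified against pretrain.py): the pre-training
# script presumably resolves the exploration task for a domain via
#   task = PRIMAL_TASKS[domain]   # e.g. 'walker' -> 'walker_stand'
# while finetune.py expects one of the full names in TASKS (e.g. 'walker_flip').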
--------------------------------------------------------------------------------
/agent/becl.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.becl.BECLAgent
3 | name: becl
4 | reward_free: ${reward_free}
5 | obs_type: ??? # to be specified later
6 | obs_shape: ??? # to be specified later
7 | action_shape: ??? # to be specified later
8 | device: ${device}
9 | lr: 1e-4
10 | critic_target_tau: 0.01
11 | update_every_steps: 2
12 | use_tb: ${use_tb}
13 | use_wandb: ${use_wandb}
14 | num_expl_steps: ??? # to be specified later
15 | hidden_dim: 1024
16 | feature_dim: 50
17 | stddev_schedule: 0.2
18 | stddev_clip: 0.3
19 | skill_dim: 16
20 | update_skill_every_step: 50
21 | nstep: 3
22 | batch_size: 1024
23 | init_critic: true
24 | update_encoder: ${update_encoder}
25 |
26 | # extra hyperparameter
27 | contrastive_update_rate: 3
28 | temperature: 0.5
29 |
30 | # skill finetuning ablation
31 | skill: -1
32 |
--------------------------------------------------------------------------------
/conda_env.yml:
--------------------------------------------------------------------------------
1 | name: urlb
2 | channels:
3 | - defaults
4 | dependencies:
5 | - python=3.8
6 | - pip=21.1.3
7 | - numpy=1.19.2
8 | - absl-py=0.13.0
9 | - pyparsing=2.4.7
10 | - jupyterlab=3.0.14
11 | - scikit-image=0.18.1
12 | - nvidia::cudatoolkit=11.1
13 | - pytorch::pytorch
14 | - pytorch::torchvision
15 | - pytorch::torchaudio
16 | - pip:
17 | - termcolor==1.1.0
18 |     - git+https://github.com/deepmind/dm_control.git
19 | - tb-nightly
20 | - imageio==2.9.0
21 | - imageio-ffmpeg==0.4.4
22 | - hydra-core==1.1.0
23 | - hydra-submitit-launcher==1.1.5
24 | - pandas==1.3.0
25 | - ipdb==0.13.9
26 | - yapf==0.31.0
27 | - mujoco_py==2.0.2.13
28 |     - scikit-learn
29 | - matplotlib==3.4.2
30 | - opencv-python==4.5.3.56
31 | - wandb==0.11.1
32 | - moviepy==1.0.3
33 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Facebook, Inc. and its affiliates.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/pretrain.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - agent: ddpg
3 | - override hydra/launcher: submitit_local
4 |
5 | # mode
6 | reward_free: true
7 | # task settings
8 | domain: walker # primal task will be inferred at runtime
9 | obs_type: states # [states, pixels]
10 | frame_stack: 3 # only works if obs_type=pixels
11 | action_repeat: 1 # set to 2 for pixels
12 | discount: 0.99
13 | # train settings
14 | num_train_frames: 2000010
15 | num_seed_frames: 4000
16 | # eval
17 | eval_every_frames: 10000
18 | num_eval_episodes: 10
19 | # snapshot
20 | snapshots: [100000, 500000, 1000000, 2000000]
21 | snapshot_dir: ../../../models/${obs_type}/${domain}/${agent.name}/${seed}
22 | # replay buffer
23 | replay_buffer_size: 1000000
24 | replay_buffer_num_workers: 4
25 | batch_size: ${agent.batch_size}
26 | nstep: ${agent.nstep}
27 | update_encoder: true # should always be true for pre-training
28 | # misc
29 | seed: 1
30 | device: cuda
31 | save_video: true
32 | save_train_video: false
33 | use_tb: false
34 | use_wandb: false
35 | # experiment
36 | experiment: exp
37 |
38 |
39 | hydra:
40 | run:
41 | dir: ./exp_local/pretrain${now:%Y.%m.%d}/${now:%H%M%S}_${agent.name}
42 | sweep:
43 | dir: ./exp_local/pretrain${now:%Y.%m.%d}/${now:%H%M%S}_${agent.name}
44 | subdir: ${hydra.job.num}
45 | launcher:
46 | timeout_min: 4300
47 | cpus_per_task: 10
48 | gpus_per_node: 1
49 | tasks_per_node: 1
50 | mem_gb: 160
51 | nodes: 1
52 | submitit_folder: ./exp_sweep/${now:%Y.%m.%d}/${now:%H%M}_${agent.name}_${experiment}/.slurm
53 |
--------------------------------------------------------------------------------
/custom_dmc_tasks/__init__.py:
--------------------------------------------------------------------------------
1 | from custom_dmc_tasks import cheetah
2 | from custom_dmc_tasks import walker
3 | from custom_dmc_tasks import hopper
4 | from custom_dmc_tasks import quadruped
5 | from custom_dmc_tasks import jaco
6 |
7 |
8 | def make(domain, task,
9 | task_kwargs=None,
10 | environment_kwargs=None,
11 | visualize_reward=False):
12 |
13 | if domain == 'cheetah':
14 | return cheetah.make(task,
15 | task_kwargs=task_kwargs,
16 | environment_kwargs=environment_kwargs,
17 | visualize_reward=visualize_reward)
18 | elif domain == 'walker':
19 | return walker.make(task,
20 | task_kwargs=task_kwargs,
21 | environment_kwargs=environment_kwargs,
22 | visualize_reward=visualize_reward)
23 | elif domain == 'hopper':
24 | return hopper.make(task,
25 | task_kwargs=task_kwargs,
26 | environment_kwargs=environment_kwargs,
27 | visualize_reward=visualize_reward)
28 | elif domain == 'quadruped':
29 | return quadruped.make(task,
30 | task_kwargs=task_kwargs,
31 | environment_kwargs=environment_kwargs,
32 | visualize_reward=visualize_reward)
33 | else:
34 |         raise ValueError(f'{domain} not found')
35 | 
37 |
38 |
39 | def make_jaco(task, obs_type, seed):
40 | return jaco.make(task, obs_type, seed)
--------------------------------------------------------------------------------
/finetune.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - agent: ddpg
3 | - override hydra/launcher: submitit_local
4 |
5 | # mode
6 | reward_free: false
7 | # task settings
8 | task: walker_stand
9 | obs_type: states # [states, pixels]
10 | frame_stack: 3 # only works if obs_type=pixels
11 | action_repeat: 1 # set to 2 for pixels
12 | discount: 0.99
13 | # train settings
14 | num_train_frames: 100010
15 | num_seed_frames: 4000
16 | # eval
17 | eval_every_frames: 10000
18 | num_eval_episodes: 10
19 | # pretrained
20 | snapshot_ts: 2000000
21 | snapshot_base_dir: ./models
22 | # replay buffer
23 | replay_buffer_size: 1000000
24 | replay_buffer_num_workers: 4
25 | batch_size: ${agent.batch_size}
26 | nstep: ${agent.nstep}
27 | update_encoder: false # can be either true or false depending on whether we want to fine-tune the encoder
28 | # misc
29 | seed: 1
30 | device: cuda
31 | save_video: true
32 | save_train_video: false
33 | use_tb: false
34 | use_wandb: false
35 | # define specific path for experiment
36 | experiment: ${agent.name}_seed_${seed}
37 | domain: walker
38 | extra_path: .
39 |
40 | hydra:
41 | run:
42 |     dir: ./exp_local/finetune_${extra_path}/${domain}/${agent.name}/finetune_${experiment}_${task}_${now:%H%M%S}
43 | sweep:
44 |     dir: ./exp_local/finetune_${extra_path}/${domain}/${agent.name}/finetune_${experiment}_${task}_${now:%H%M%S}
45 | subdir: ${hydra.job.num}
46 | launcher:
47 | timeout_min: 4300
48 | cpus_per_task: 10
49 | gpus_per_node: 1
50 | tasks_per_node: 1
51 | mem_gb: 160
52 | nodes: 1
53 | submitit_folder: ./exp_sweep/${now:%Y.%m.%d}/${now:%H%M}_${agent.name}_${experiment}/.slurm
54 |
--------------------------------------------------------------------------------
/custom_dmc_tasks/hopper.xml:
--------------------------------------------------------------------------------
<!-- MuJoCo XML model for the custom Hopper tasks; the markup was not preserved in this text dump. -->
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Behavior Contrastive Learning (BeCL)
2 |
3 | This is the official codebase for the ICML 2023 paper [BeCL: Behavior Contrastive Learning for Unsupervised Skill Discovery](https://arxiv.org/abs/2305.04477), which uses contrastive learning as an intrinsic motivation for unsupervised skill discovery.
4 |
5 | If you find this paper useful for your research, please cite:
6 | ```
7 | @misc{yang2023behavior,
8 | title={Behavior Contrastive Learning for Unsupervised Skill Discovery},
9 | author={Rushuai Yang and Chenjia Bai and Hongyi Guo and Siyuan Li and Bin Zhao and Zhen Wang and Peng Liu and Xuelong Li},
10 | year={2023},
11 | eprint={2305.04477},
12 | archivePrefix={arXiv},
13 | primaryClass={cs.LG}
14 | }
15 | ```
16 |
17 | This codebase is built on top of the [Unsupervised Reinforcement Learning Benchmark (URLB) codebase](https://github.com/rll-research/url_benchmark). Our method `BeCL` is implemented in `agent/becl.py` and its config is specified in `agent/becl.yaml`.
18 |
19 | To pre-train BeCL, run the following command:
20 |
21 | ``` sh
22 | python pretrain.py agent=becl domain=walker seed=3
23 | ```
24 |
25 | This script will produce agent snapshots after training for `100k`, `500k`, `1M`, and `2M` frames. Snapshots are stored in `./models/states/<domain>/<agent>/<seed>/` (e.g. the command above writes to `./models/states/walker/becl/3/`).
26 |
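BeCL's own hyperparameters are defined in `agent/becl.yaml` (e.g. `skill_dim`, `temperature`, `contrastive_update_rate`) and can be overridden from the command line through Hydra. A minimal sketch; the values below are illustrative, not tuned recommendations:

```sh
python pretrain.py agent=becl domain=quadruped seed=3 agent.skill_dim=32 agent.temperature=0.1
```
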
27 | To finetune BeCL, run the following command:
28 |
29 | ```sh
30 | python finetune.py task=walker_stand obs_type=states agent=becl reward_free=false seed=3 domain=walker snapshot_ts=2000000
31 | ```
32 |
33 | This will load a snapshot stored in `./models/states/walker/becl/3/snapshot_2000000.pt`, initialize `DDPG` with it (both the actor and critic), and start training on `walker_stand` using the extrinsic reward of the task.
34 |
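During fine-tuning, BeCL selects a random skill unless one is pinned (`skill: -1` in `agent/becl.yaml` means random). A sketch of evaluating a specific skill index instead; the index `3` is arbitrary and must be smaller than `agent.skill_dim`:

```sh
python finetune.py task=walker_stand obs_type=states agent=becl reward_free=false seed=3 domain=walker snapshot_ts=2000000 agent.skill=3
```
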
35 | ## Requirements
36 |
37 | We assume you have access to a GPU that can run CUDA 11.1 and cuDNN 8 (the conda environment below pins `cudatoolkit=11.1`). Then, the simplest way to install all required dependencies is to create an anaconda environment by running
38 | ```sh
39 | conda env create -f conda_env.yml
40 | ```
41 | Once the installation finishes, you can activate the environment with
42 | ```sh
43 | conda activate urlb
44 | ```
45 |
46 | ## Available Domains
47 | We support the following domains and tasks. A full task name combines the domain and the task (see the example below the table).
48 | | Domain | Tasks |
49 | |---|---|
50 | | `walker` | `stand`, `walk`, `run`, `flip` |
51 | | `quadruped` | `walk`, `run`, `stand`, `jump` |
52 | | `jaco` | `reach_top_left`, `reach_top_right`, `reach_bottom_left`, `reach_bottom_right` |
53 |
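For instance, the `quadruped` domain and its `jump` task compose into the task name `quadruped_jump`. A sketch of fine-tuning on it, assuming a `quadruped` BeCL snapshot was pre-trained beforehand with the same seed:

```sh
python finetune.py task=quadruped_jump domain=quadruped obs_type=states agent=becl reward_free=false seed=3 snapshot_ts=2000000
```
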
54 | ### Monitoring
55 | Logs are stored in the `exp_local` folder. To launch tensorboard run:
56 | ```sh
57 | tensorboard --logdir exp_local
58 | ```
59 | The console output is also available in the following form:
60 | ```
61 | | train | F: 6000 | S: 3000 | E: 6 | L: 1000 | R: 5.5177 | FPS: 96.7586 | T: 0:00:42
62 | ```
63 | A training entry decodes as:
64 | ```
65 | F : total number of environment frames
66 | S : total number of agent steps
67 | E : total number of episodes
68 | L : episode length
69 | R : episode return
70 | FPS: training throughput (frames per second)
71 | T : total training time
71 | ```
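
The same metrics are also written to `train.csv` and `eval.csv` in each run's log directory (see `logger.py`), so results can be inspected without TensorBoard. A minimal sketch using pandas; the CSV path is hypothetical and should point to an actual run directory under `exp_local`:

```python
import pandas as pd
import matplotlib.pyplot as plt

# 'path/to/run' stands for a concrete experiment directory under exp_local.
df = pd.read_csv('path/to/run/eval.csv')
# 'frame' and 'episode_reward' follow the eval format used in logger.py.
plt.plot(df['frame'], df['episode_reward'])
plt.xlabel('frames')
plt.ylabel('episode reward')
plt.show()
```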
72 |
--------------------------------------------------------------------------------
/custom_dmc_tasks/walker.xml:
--------------------------------------------------------------------------------
<!-- MuJoCo XML model for the custom Walker tasks; the markup was not preserved in this text dump. -->
--------------------------------------------------------------------------------
/video.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import imageio
3 | import numpy as np
4 | import wandb
5 |
6 |
7 | class VideoRecorder:
8 | def __init__(self,
9 | root_dir,
10 | render_size=256,
11 | fps=20,
12 | camera_id=0,
13 | use_wandb=False):
14 | if root_dir is not None:
15 | self.save_dir = root_dir / 'eval_video'
16 | self.save_dir.mkdir(exist_ok=True)
17 | else:
18 | self.save_dir = None
19 |
20 | self.render_size = render_size
21 | self.fps = fps
22 | self.frames = []
23 | self.camera_id = camera_id
24 | self.use_wandb = use_wandb
25 |
26 | def init(self, env, enabled=True):
27 | self.frames = []
28 | self.enabled = self.save_dir is not None and enabled
29 | self.record(env)
30 |
31 | def record(self, env):
32 | if self.enabled:
33 | if hasattr(env, 'physics'):
34 | frame = env.physics.render(height=self.render_size,
35 | width=self.render_size,
36 | camera_id=self.camera_id)
37 | else:
38 | frame = env.render()
39 | self.frames.append(frame)
40 |
41 | def log_to_wandb(self):
42 | frames = np.transpose(np.array(self.frames), (0, 3, 1, 2))
43 | fps, skip = 6, 8
44 | wandb.log({
45 | 'eval/video':
46 | wandb.Video(frames[::skip, :, ::2, ::2], fps=fps, format="gif")
47 | })
48 |
49 | def save(self, file_name):
50 | if self.enabled:
51 | if self.use_wandb:
52 | self.log_to_wandb()
53 | path = self.save_dir / file_name
54 | imageio.mimsave(str(path), self.frames, fps=self.fps)
55 |
56 |
57 | class TrainVideoRecorder:
58 | def __init__(self,
59 | root_dir,
60 | render_size=256,
61 | fps=20,
62 | camera_id=0,
63 | use_wandb=False):
64 | if root_dir is not None:
65 | self.save_dir = root_dir / 'train_video'
66 | self.save_dir.mkdir(exist_ok=True)
67 | else:
68 | self.save_dir = None
69 |
70 | self.render_size = render_size
71 | self.fps = fps
72 | self.frames = []
73 | self.camera_id = camera_id
74 | self.use_wandb = use_wandb
75 |
76 | def init(self, obs, enabled=True):
77 | self.frames = []
78 | self.enabled = self.save_dir is not None and enabled
79 | self.record(obs)
80 |
81 | def record(self, obs):
82 | if self.enabled:
83 | frame = cv2.resize(obs[-3:].transpose(1, 2, 0),
84 | dsize=(self.render_size, self.render_size),
85 | interpolation=cv2.INTER_CUBIC)
86 | self.frames.append(frame)
87 |
88 | def log_to_wandb(self):
89 | frames = np.transpose(np.array(self.frames), (0, 3, 1, 2))
90 | fps, skip = 6, 8
91 | wandb.log({
92 | 'train/video':
93 | wandb.Video(frames[::skip, :, ::2, ::2], fps=fps, format="gif")
94 | })
95 |
96 | def save(self, file_name):
97 | if self.enabled:
98 | if self.use_wandb:
99 | self.log_to_wandb()
100 | path = self.save_dir / file_name
101 | imageio.mimsave(str(path), self.frames, fps=self.fps)
102 |
--------------------------------------------------------------------------------
/custom_dmc_tasks/cheetah.xml:
--------------------------------------------------------------------------------
<!-- MuJoCo XML model for the custom Cheetah tasks; the markup was not preserved in this text dump. -->
--------------------------------------------------------------------------------
/custom_dmc_tasks/cheetah.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Cheetah Domain."""
16 |
17 | import collections
18 | import os
19 |
20 | from dm_control import mujoco
21 | from dm_control.rl import control
22 | from dm_control.suite import base
23 | from dm_control.suite import common
24 | from dm_control.utils import containers
25 | from dm_control.utils import rewards
26 | from dm_control.utils import io as resources
27 |
28 | # How long the simulation will run, in seconds.
29 | _DEFAULT_TIME_LIMIT = 10
30 |
31 | # Running speed above which reward is 1.
32 | _RUN_SPEED = 10
33 | _SPIN_SPEED = 5
34 |
35 | SUITE = containers.TaggedTasks()
36 |
37 |
38 | def make(task,
39 | task_kwargs=None,
40 | environment_kwargs=None,
41 | visualize_reward=False):
42 | task_kwargs = task_kwargs or {}
43 | if environment_kwargs is not None:
44 | task_kwargs = task_kwargs.copy()
45 | task_kwargs['environment_kwargs'] = environment_kwargs
46 | env = SUITE[task](**task_kwargs)
47 | env.task.visualize_reward = visualize_reward
48 | return env
49 |
50 |
51 | def get_model_and_assets():
52 | """Returns a tuple containing the model XML string and a dict of assets."""
53 | root_dir = os.path.dirname(os.path.dirname(__file__))
54 | xml = resources.GetResource(
55 | os.path.join(root_dir, 'custom_dmc_tasks', 'cheetah.xml'))
56 | return xml, common.ASSETS
57 |
58 |
59 |
60 | @SUITE.add('benchmarking')
61 | def run_backward(time_limit=_DEFAULT_TIME_LIMIT,
62 | random=None,
63 | environment_kwargs=None):
64 |     """Returns the run backward task."""
65 | physics = Physics.from_xml_string(*get_model_and_assets())
66 | task = Cheetah(forward=False, flip=False, random=random)
67 | environment_kwargs = environment_kwargs or {}
68 | return control.Environment(physics,
69 | task,
70 | time_limit=time_limit,
71 | **environment_kwargs)
72 |
73 |
74 | @SUITE.add('benchmarking')
75 | def flip(time_limit=_DEFAULT_TIME_LIMIT,
76 | random=None,
77 | environment_kwargs=None):
78 |     """Returns the flip task."""
79 | physics = Physics.from_xml_string(*get_model_and_assets())
80 | task = Cheetah(forward=True, flip=True, random=random)
81 | environment_kwargs = environment_kwargs or {}
82 | return control.Environment(physics,
83 | task,
84 | time_limit=time_limit,
85 | **environment_kwargs)
86 |
87 |
88 | @SUITE.add('benchmarking')
89 | def flip_backward(time_limit=_DEFAULT_TIME_LIMIT,
90 | random=None,
91 | environment_kwargs=None):
92 |     """Returns the flip backward task."""
93 | physics = Physics.from_xml_string(*get_model_and_assets())
94 | task = Cheetah(forward=False, flip=True, random=random)
95 | environment_kwargs = environment_kwargs or {}
96 | return control.Environment(physics,
97 | task,
98 | time_limit=time_limit,
99 | **environment_kwargs)
100 |
101 |
102 | class Physics(mujoco.Physics):
103 | """Physics simulation with additional features for the Cheetah domain."""
104 | def speed(self):
105 | """Returns the horizontal speed of the Cheetah."""
106 | return self.named.data.sensordata['torso_subtreelinvel'][0]
107 |
108 | def angmomentum(self):
109 | """Returns the angular momentum of torso of the Cheetah about Y axis."""
110 | return self.named.data.subtree_angmom['torso'][1]
111 |
112 |
113 | class Cheetah(base.Task):
114 | """A `Task` to train a running Cheetah."""
115 | def __init__(self, forward=True, flip=False, random=None):
116 | self._forward = 1 if forward else -1
117 | self._flip = flip
118 | super(Cheetah, self).__init__(random=random)
119 |
120 | def initialize_episode(self, physics):
121 | """Sets the state of the environment at the start of each episode."""
122 | # The indexing below assumes that all joints have a single DOF.
123 | assert physics.model.nq == physics.model.njnt
124 | is_limited = physics.model.jnt_limited == 1
125 | lower, upper = physics.model.jnt_range[is_limited].T
126 | physics.data.qpos[is_limited] = self.random.uniform(lower, upper)
127 |
128 | # Stabilize the model before the actual simulation.
129 | for _ in range(200):
130 | physics.step()
131 |
132 | physics.data.time = 0
133 | self._timeout_progress = 0
134 | super().initialize_episode(physics)
135 |
136 | def get_observation(self, physics):
137 | """Returns an observation of the state, ignoring horizontal position."""
138 | obs = collections.OrderedDict()
139 | # Ignores horizontal position to maintain translational invariance.
140 | obs['position'] = physics.data.qpos[1:].copy()
141 | obs['velocity'] = physics.velocity()
142 | return obs
143 |
144 | def get_reward(self, physics):
145 | """Returns a reward to the agent."""
146 | if self._flip:
147 | reward = rewards.tolerance(self._forward * physics.angmomentum(),
148 | bounds=(_SPIN_SPEED, float('inf')),
149 | margin=_SPIN_SPEED,
150 | value_at_margin=0,
151 | sigmoid='linear')
152 |
153 | else:
154 | reward = rewards.tolerance(self._forward * physics.speed(),
155 | bounds=(_RUN_SPEED, float('inf')),
156 | margin=_RUN_SPEED,
157 | value_at_margin=0,
158 | sigmoid='linear')
159 | return reward
160 |
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import datetime
3 | from collections import defaultdict
4 |
5 | import numpy as np
6 | import torch
7 | import torchvision
8 | import wandb
9 | from termcolor import colored
10 | from torch.utils.tensorboard import SummaryWriter
11 |
12 | COMMON_TRAIN_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'),
13 | ('episode', 'E', 'int'), ('episode_length', 'L', 'int'),
14 | ('episode_reward', 'R', 'float'),
15 | ('fps', 'FPS', 'float'), ('total_time', 'T', 'time')]
16 |
17 | COMMON_EVAL_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'),
18 | ('episode', 'E', 'int'), ('episode_length', 'L', 'int'),
19 | ('episode_reward', 'R', 'float'),
20 | ('total_time', 'T', 'time')]
21 |
22 |
23 | class AverageMeter(object):
24 | def __init__(self):
25 | self._sum = 0
26 | self._count = 0
27 |
28 | def update(self, value, n=1):
29 | self._sum += value
30 | self._count += n
31 |
32 | def value(self):
33 | return self._sum / max(1, self._count)
34 |
35 |
36 | class MetersGroup(object):
37 | def __init__(self, csv_file_name, formating, use_wandb):
38 | self._csv_file_name = csv_file_name
39 | self._formating = formating
40 | self._meters = defaultdict(AverageMeter)
41 | self._csv_file = None
42 | self._csv_writer = None
43 | self.use_wandb = use_wandb
44 |
45 | def log(self, key, value, n=1):
46 | self._meters[key].update(value, n)
47 |
48 | def _prime_meters(self):
49 | data = dict()
50 | for key, meter in self._meters.items():
51 | if key.startswith('train'):
52 | key = key[len('train') + 1:]
53 | else:
54 | key = key[len('eval') + 1:]
55 | key = key.replace('/', '_')
56 | data[key] = meter.value()
57 | return data
58 |
59 | def _remove_old_entries(self, data):
60 | rows = []
61 | with self._csv_file_name.open('r') as f:
62 | reader = csv.DictReader(f)
63 | for row in reader:
64 | if float(row['episode']) >= data['episode']:
65 | break
66 | rows.append(row)
67 | with self._csv_file_name.open('w') as f:
68 | writer = csv.DictWriter(f,
69 | fieldnames=sorted(data.keys()),
70 | restval=0.0)
71 | writer.writeheader()
72 | for row in rows:
73 | writer.writerow(row)
74 |
75 | def _dump_to_csv(self, data):
76 | if self._csv_writer is None:
77 | should_write_header = True
78 | if self._csv_file_name.exists():
79 | self._remove_old_entries(data)
80 | should_write_header = False
81 |
82 | self._csv_file = self._csv_file_name.open('a')
83 | self._csv_writer = csv.DictWriter(self._csv_file,
84 | fieldnames=sorted(data.keys()),
85 | restval=0.0)
86 | if should_write_header:
87 | self._csv_writer.writeheader()
88 |
89 | self._csv_writer.writerow(data)
90 | self._csv_file.flush()
91 |
92 | def _format(self, key, value, ty):
93 | if ty == 'int':
94 | value = int(value)
95 | return f'{key}: {value}'
96 | elif ty == 'float':
97 | return f'{key}: {value:.04f}'
98 | elif ty == 'time':
99 | value = str(datetime.timedelta(seconds=int(value)))
100 | return f'{key}: {value}'
101 | else:
102 |             raise ValueError(f'invalid format type: {ty}')
103 |
104 | def _dump_to_console(self, data, prefix):
105 | prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green')
106 | pieces = [f'| {prefix: <14}']
107 | for key, disp_key, ty in self._formating:
108 | value = data.get(key, 0)
109 | pieces.append(self._format(disp_key, value, ty))
110 | print(' | '.join(pieces))
111 |
112 | def _dump_to_wandb(self, data):
113 | wandb.log(data)
114 |
115 | def dump(self, step, prefix):
116 | if len(self._meters) == 0:
117 | return
118 | data = self._prime_meters()
119 | data['frame'] = step
120 | if self.use_wandb:
121 | wandb_data = {prefix + '/' + key: val for key, val in data.items()}
122 | self._dump_to_wandb(data=wandb_data)
123 | self._dump_to_csv(data)
124 | self._dump_to_console(data, prefix)
125 | self._meters.clear()
126 |
127 |
128 | class Logger(object):
129 | def __init__(self, log_dir, use_tb, use_wandb):
130 | self._log_dir = log_dir
131 | self._train_mg = MetersGroup(log_dir / 'train.csv',
132 | formating=COMMON_TRAIN_FORMAT,
133 | use_wandb=use_wandb)
134 | self._eval_mg = MetersGroup(log_dir / 'eval.csv',
135 | formating=COMMON_EVAL_FORMAT,
136 | use_wandb=use_wandb)
137 | if use_tb:
138 | self._sw = SummaryWriter(str(log_dir / 'tb'))
139 | else:
140 | self._sw = None
141 | self.use_wandb = use_wandb
142 |
143 | def _try_sw_log(self, key, value, step):
144 | if self._sw is not None:
145 | self._sw.add_scalar(key, value, step)
146 |
147 | def log(self, key, value, step):
148 | assert key.startswith('train') or key.startswith('eval')
149 | if type(value) == torch.Tensor:
150 | value = value.item()
151 | self._try_sw_log(key, value, step)
152 | mg = self._train_mg if key.startswith('train') else self._eval_mg
153 | mg.log(key, value)
154 |
155 | def log_metrics(self, metrics, step, ty):
156 | for key, value in metrics.items():
157 | self.log(f'{ty}/{key}', value, step)
158 |
159 | def dump(self, step, ty=None):
160 | if ty is None or ty == 'eval':
161 | self._eval_mg.dump(step, 'eval')
162 | if ty is None or ty == 'train':
163 | self._train_mg.dump(step, 'train')
164 |
165 | def log_and_dump_ctx(self, step, ty):
166 | return LogAndDumpCtx(self, step, ty)
167 |
168 |
169 | class LogAndDumpCtx:
170 | def __init__(self, logger, step, ty):
171 | self._logger = logger
172 | self._step = step
173 | self._ty = ty
174 |
175 | def __enter__(self):
176 | return self
177 |
178 | def __call__(self, key, value):
179 | self._logger.log(f'{self._ty}/{key}', value, self._step)
180 |
181 | def __exit__(self, *args):
182 | self._logger.dump(self._step, self._ty)
183 |
--------------------------------------------------------------------------------
/custom_dmc_tasks/walker.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Planar Walker Domain."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import os
23 |
24 | from dm_control import mujoco
25 | from dm_control.rl import control
26 | from dm_control.suite import base
27 | from dm_control.suite import common
28 | from dm_control.suite.utils import randomizers
29 | from dm_control.utils import containers
30 | from dm_control.utils import rewards
31 | from dm_control.utils import io as resources
32 | from dm_control import suite
33 |
34 | _DEFAULT_TIME_LIMIT = 25
35 | _CONTROL_TIMESTEP = .025
36 |
37 | # Minimal height of torso over foot above which stand reward is 1.
38 | _STAND_HEIGHT = 1.2
39 |
40 | # Horizontal speeds (meters/second) above which move reward is 1.
41 | _WALK_SPEED = 1
42 | _RUN_SPEED = 8
43 | _SPIN_SPEED = 5
44 |
45 | SUITE = containers.TaggedTasks()
46 |
47 | def make(task,
48 | task_kwargs=None,
49 | environment_kwargs=None,
50 | visualize_reward=False):
51 | task_kwargs = task_kwargs or {}
52 | if environment_kwargs is not None:
53 | task_kwargs = task_kwargs.copy()
54 | task_kwargs['environment_kwargs'] = environment_kwargs
55 | env = SUITE[task](**task_kwargs)
56 | env.task.visualize_reward = visualize_reward
57 | return env
58 |
59 | def get_model_and_assets():
60 | """Returns a tuple containing the model XML string and a dict of assets."""
61 | root_dir = os.path.dirname(os.path.dirname(__file__))
62 | xml = resources.GetResource(os.path.join(root_dir, 'custom_dmc_tasks',
63 | 'walker.xml'))
64 | return xml, common.ASSETS
65 |
66 |
67 |
68 |
69 |
70 |
71 | @SUITE.add('benchmarking')
72 | def flip(time_limit=_DEFAULT_TIME_LIMIT,
73 | random=None,
74 | environment_kwargs=None):
75 |     """Returns the Flip task."""
76 | physics = Physics.from_xml_string(*get_model_and_assets())
77 | task = PlanarWalker(move_speed=_RUN_SPEED,
78 | forward=True,
79 | flip=True,
80 | random=random)
81 | environment_kwargs = environment_kwargs or {}
82 | return control.Environment(physics,
83 | task,
84 | time_limit=time_limit,
85 | control_timestep=_CONTROL_TIMESTEP,
86 | **environment_kwargs)
87 |
88 |
89 | class Physics(mujoco.Physics):
90 | """Physics simulation with additional features for the Walker domain."""
91 | def torso_upright(self):
92 | """Returns projection from z-axes of torso to the z-axes of world."""
93 | return self.named.data.xmat['torso', 'zz']
94 |
95 | def torso_height(self):
96 | """Returns the height of the torso."""
97 | return self.named.data.xpos['torso', 'z']
98 |
99 | def horizontal_velocity(self):
100 | """Returns the horizontal velocity of the center-of-mass."""
101 | return self.named.data.sensordata['torso_subtreelinvel'][0]
102 |
103 | def orientations(self):
104 | """Returns planar orientations of all bodies."""
105 | return self.named.data.xmat[1:, ['xx', 'xz']].ravel()
106 |
107 | def angmomentum(self):
108 | """Returns the angular momentum of torso of the Cheetah about Y axis."""
109 |         """Returns the angular momentum of the torso of the Walker about the Y axis."""
110 |
111 |
112 | class PlanarWalker(base.Task):
113 | """A planar walker task."""
114 | def __init__(self, move_speed, forward=True, flip=False, random=None):
115 | """Initializes an instance of `PlanarWalker`.
116 |
117 | Args:
118 | move_speed: A float. If this value is zero, reward is given simply for
119 | standing up. Otherwise this specifies a target horizontal velocity for
120 | the walking task.
121 | random: Optional, either a `numpy.random.RandomState` instance, an
122 | integer seed for creating a new `RandomState`, or None to select a seed
123 | automatically (default).
124 | """
125 | self._move_speed = move_speed
126 | self._forward = 1 if forward else -1
127 | self._flip = flip
128 | super(PlanarWalker, self).__init__(random=random)
129 |
130 | def initialize_episode(self, physics):
131 | """Sets the state of the environment at the start of each episode.
132 |
133 | In 'standing' mode, use initial orientation and small velocities.
134 | In 'random' mode, randomize joint angles and let fall to the floor.
135 |
136 | Args:
137 | physics: An instance of `Physics`.
138 |
139 | """
140 | randomizers.randomize_limited_and_rotational_joints(
141 | physics, self.random)
142 | super(PlanarWalker, self).initialize_episode(physics)
143 |
144 | def get_observation(self, physics):
145 | """Returns an observation of body orientations, height and velocites."""
146 | obs = collections.OrderedDict()
147 | obs['orientations'] = physics.orientations()
148 | obs['height'] = physics.torso_height()
149 | obs['velocity'] = physics.velocity()
150 | return obs
151 |
152 | def get_reward(self, physics):
153 | """Returns a reward to the agent."""
154 | standing = rewards.tolerance(physics.torso_height(),
155 | bounds=(_STAND_HEIGHT, float('inf')),
156 | margin=_STAND_HEIGHT / 2)
157 | upright = (1 + physics.torso_upright()) / 2
158 | stand_reward = (3 * standing + upright) / 4
159 |
160 | if self._flip:
161 | move_reward = rewards.tolerance(self._forward *
162 | physics.angmomentum(),
163 | bounds=(_SPIN_SPEED, float('inf')),
164 | margin=_SPIN_SPEED,
165 | value_at_margin=0,
166 | sigmoid='linear')
167 | else:
168 | move_reward = rewards.tolerance(
169 | self._forward * physics.horizontal_velocity(),
170 | bounds=(self._move_speed, float('inf')),
171 | margin=self._move_speed / 2,
172 | value_at_margin=0.5,
173 | sigmoid='linear')
174 |
175 | return stand_reward * (5 * move_reward + 1) / 6
176 |
--------------------------------------------------------------------------------
/replay_buffer.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import io
3 | import random
4 | import traceback
5 | from collections import defaultdict
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | from torch.utils.data import IterableDataset
11 |
12 |
13 | def episode_len(episode):
14 |     # subtract 1 because of the dummy first transition
15 | return next(iter(episode.values())).shape[0] - 1
16 |
17 |
18 | def save_episode(episode, fn):
19 | with io.BytesIO() as bs:
20 | np.savez_compressed(bs, **episode)
21 | bs.seek(0)
22 | with fn.open('wb') as f:
23 | f.write(bs.read())
24 |
25 |
26 | def load_episode(fn):
27 | with fn.open('rb') as f:
28 | episode = np.load(f)
29 | episode = {k: episode[k] for k in episode.keys()}
30 | return episode
31 |
32 |
33 | class ReplayBufferStorage:
34 | def __init__(self, data_specs, meta_specs, replay_dir):
35 | self._data_specs = data_specs
36 | self._meta_specs = meta_specs
37 | self._replay_dir = replay_dir
38 | replay_dir.mkdir(exist_ok=True)
39 | self._current_episode = defaultdict(list)
40 | self._preload()
41 |
42 | def __len__(self):
43 | return self._num_transitions
44 |
45 | def add(self, time_step, meta):
46 | for key, value in meta.items():
47 | self._current_episode[key].append(value)
48 | for spec in self._data_specs:
49 | value = time_step[spec.name]
50 | if np.isscalar(value):
51 | value = np.full(spec.shape, value, spec.dtype)
52 | assert spec.shape == value.shape and spec.dtype == value.dtype
53 | self._current_episode[spec.name].append(value)
54 | if time_step.last():
55 | episode = dict()
56 | for spec in self._data_specs:
57 | value = self._current_episode[spec.name]
58 | episode[spec.name] = np.array(value, spec.dtype)
59 | for spec in self._meta_specs:
60 | value = self._current_episode[spec.name]
61 | episode[spec.name] = np.array(value, spec.dtype)
62 | self._current_episode = defaultdict(list)
63 | self._store_episode(episode)
64 |
65 | def _preload(self):
66 | self._num_episodes = 0
67 | self._num_transitions = 0
68 | for fn in self._replay_dir.glob('*.npz'):
69 | _, _, eps_len = fn.stem.split('_')
70 | self._num_episodes += 1
71 | self._num_transitions += int(eps_len)
72 |
73 | def _store_episode(self, episode):
74 | eps_idx = self._num_episodes
75 | eps_len = episode_len(episode)
76 | self._num_episodes += 1
77 | self._num_transitions += eps_len
78 | ts = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
79 | eps_fn = f'{ts}_{eps_idx}_{eps_len}.npz'
80 | save_episode(episode, self._replay_dir / eps_fn)
81 |
82 |
83 | class ReplayBuffer(IterableDataset):
84 | def __init__(self, storage, max_size, num_workers, nstep, discount,
85 | fetch_every, save_snapshot):
86 | self._storage = storage
87 | self._size = 0
88 | self._max_size = max_size
89 | self._num_workers = max(1, num_workers)
90 | self._episode_fns = []
91 | self._episodes = dict()
92 | self._nstep = nstep
93 | self._discount = discount
94 | self._fetch_every = fetch_every
95 | self._samples_since_last_fetch = fetch_every
96 | self._save_snapshot = save_snapshot
97 |
98 | def _sample_episode(self):
99 | eps_fn = random.choice(self._episode_fns)
100 | return self._episodes[eps_fn]
101 |
102 | def _store_episode(self, eps_fn):
103 | try:
104 | episode = load_episode(eps_fn)
105 | except:
106 | return False
107 | eps_len = episode_len(episode)
108 | while eps_len + self._size > self._max_size:
109 | early_eps_fn = self._episode_fns.pop(0)
110 | early_eps = self._episodes.pop(early_eps_fn)
111 | self._size -= episode_len(early_eps)
112 | early_eps_fn.unlink(missing_ok=True)
113 | self._episode_fns.append(eps_fn)
114 | self._episode_fns.sort()
115 | self._episodes[eps_fn] = episode
116 | self._size += eps_len
117 |
118 | if not self._save_snapshot:
119 | eps_fn.unlink(missing_ok=True)
120 | return True
121 |
122 | def _try_fetch(self):
123 | if self._samples_since_last_fetch < self._fetch_every:
124 | return
125 | self._samples_since_last_fetch = 0
126 | try:
127 | worker_id = torch.utils.data.get_worker_info().id
128 | except:
129 | worker_id = 0
130 | eps_fns = sorted(self._storage._replay_dir.glob('*.npz'), reverse=True)
131 | fetched_size = 0
132 | for eps_fn in eps_fns:
133 | eps_idx, eps_len = [int(x) for x in eps_fn.stem.split('_')[1:]]
134 | if eps_idx % self._num_workers != worker_id:
135 | continue
136 | if eps_fn in self._episodes.keys():
137 | break
138 | if fetched_size + eps_len > self._max_size:
139 | break
140 | fetched_size += eps_len
141 | if not self._store_episode(eps_fn):
142 | break
143 |
144 | def _sample(self):
145 | try:
146 | self._try_fetch()
147 | except:
148 | traceback.print_exc()
149 | self._samples_since_last_fetch += 1
150 | episode = self._sample_episode()
151 | # add +1 for the first dummy transition
152 | idx = np.random.randint(0, episode_len(episode) - self._nstep + 1) + 1
153 | meta = []
154 | for spec in self._storage._meta_specs:
155 | meta.append(episode[spec.name][idx - 1])
156 | obs = episode['observation'][idx - 1]
157 | action = episode['action'][idx]
158 | next_obs = episode['observation'][idx + self._nstep - 1]
159 | reward = np.zeros_like(episode['reward'][idx])
160 | discount = np.ones_like(episode['discount'][idx])
161 | for i in range(self._nstep):
162 | step_reward = episode['reward'][idx + i]
163 | reward += discount * step_reward
164 | discount *= episode['discount'][idx + i] * self._discount
165 | return (obs, action, reward, discount, next_obs, *meta)
166 |
167 | def __iter__(self):
168 | while True:
169 | yield self._sample()
170 |
171 |
172 | def _worker_init_fn(worker_id):
173 | seed = np.random.get_state()[1][0] + worker_id
174 | np.random.seed(seed)
175 | random.seed(seed)
176 |
177 |
178 | def make_replay_loader(storage, max_size, batch_size, num_workers,
179 | save_snapshot, nstep, discount):
180 | max_size_per_worker = max_size // max(1, num_workers)
181 |
182 | iterable = ReplayBuffer(storage,
183 | max_size_per_worker,
184 | num_workers,
185 | nstep,
186 | discount,
187 | fetch_every=1000,
188 | save_snapshot=save_snapshot)
189 |
190 | loader = torch.utils.data.DataLoader(iterable,
191 | batch_size=batch_size,
192 | num_workers=num_workers,
193 | pin_memory=True,
194 | worker_init_fn=_worker_init_fn)
195 | return loader
196 |
--------------------------------------------------------------------------------
/custom_dmc_tasks/hopper.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Hopper domain."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import os
23 |
24 | from dm_control import mujoco
25 | from dm_control.rl import control
26 | from dm_control.suite import base
27 | from dm_control.suite import common
28 | from dm_control.suite.utils import randomizers
29 | from dm_control.utils import containers
30 | from dm_control.utils import rewards
31 | from dm_control.utils import io as resources
32 | import numpy as np
33 |
34 | SUITE = containers.TaggedTasks()
35 |
36 | _CONTROL_TIMESTEP = .02 # (Seconds)
37 |
38 | # Default duration of an episode, in seconds.
39 | _DEFAULT_TIME_LIMIT = 20
40 |
41 | # Minimal height of torso over foot above which stand reward is 1.
42 | _STAND_HEIGHT = 0.6
43 |
44 | # Hopping speed above which hop reward is 1.
45 | _HOP_SPEED = 2
46 | _SPIN_SPEED = 5
47 |
48 |
49 | def make(task,
50 | task_kwargs=None,
51 | environment_kwargs=None,
52 | visualize_reward=False):
53 | task_kwargs = task_kwargs or {}
54 | if environment_kwargs is not None:
55 | task_kwargs = task_kwargs.copy()
56 | task_kwargs['environment_kwargs'] = environment_kwargs
57 | env = SUITE[task](**task_kwargs)
58 | env.task.visualize_reward = visualize_reward
59 | return env
60 |
61 | def get_model_and_assets():
62 | """Returns a tuple containing the model XML string and a dict of assets."""
63 | root_dir = os.path.dirname(os.path.dirname(__file__))
64 | xml = resources.GetResource(
65 | os.path.join(root_dir, 'custom_dmc_tasks', 'hopper.xml'))
66 | return xml, common.ASSETS
67 |
68 |
69 |
70 | @SUITE.add('benchmarking')
71 | def hop_backward(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
72 |     """Returns a Hopper that strives to hop backward."""
73 | physics = Physics.from_xml_string(*get_model_and_assets())
74 | task = Hopper(hopping=True, forward=False, flip=False, random=random)
75 | environment_kwargs = environment_kwargs or {}
76 | return control.Environment(physics,
77 | task,
78 | time_limit=time_limit,
79 | control_timestep=_CONTROL_TIMESTEP,
80 | **environment_kwargs)
81 |
82 | @SUITE.add('benchmarking')
83 | def flip(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
84 |     """Returns a Hopper that strives to flip forward."""
85 | physics = Physics.from_xml_string(*get_model_and_assets())
86 | task = Hopper(hopping=True, forward=True, flip=True, random=random)
87 | environment_kwargs = environment_kwargs or {}
88 | return control.Environment(physics,
89 | task,
90 | time_limit=time_limit,
91 | control_timestep=_CONTROL_TIMESTEP,
92 | **environment_kwargs)
93 |
94 | @SUITE.add('benchmarking')
95 | def flip_backward(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
96 |     """Returns a Hopper that strives to flip backward."""
97 | physics = Physics.from_xml_string(*get_model_and_assets())
98 | task = Hopper(hopping=True, forward=False, flip=True, random=random)
99 | environment_kwargs = environment_kwargs or {}
100 | return control.Environment(physics,
101 | task,
102 | time_limit=time_limit,
103 | control_timestep=_CONTROL_TIMESTEP,
104 | **environment_kwargs)
105 |
106 |
107 | class Physics(mujoco.Physics):
108 | """Physics simulation with additional features for the Hopper domain."""
109 | def height(self):
110 | """Returns height of torso with respect to foot."""
111 | return (self.named.data.xipos['torso', 'z'] -
112 | self.named.data.xipos['foot', 'z'])
113 |
114 | def speed(self):
115 | """Returns horizontal speed of the Hopper."""
116 | return self.named.data.sensordata['torso_subtreelinvel'][0]
117 |
118 | def touch(self):
119 | """Returns the signals from two foot touch sensors."""
120 | return np.log1p(self.named.data.sensordata[['touch_toe',
121 | 'touch_heel']])
122 |
123 | def angmomentum(self):
124 | """Returns the angular momentum of torso of the Cheetah about Y axis."""
125 |         """Returns the angular momentum of the torso of the Hopper about the Y axis."""
126 |
127 |
128 |
129 | class Hopper(base.Task):
130 | """A Hopper's `Task` to train a standing and a jumping Hopper."""
131 | def __init__(self, hopping, forward=True, flip=False, random=None):
132 | """Initialize an instance of `Hopper`.
133 |
134 | Args:
135 | hopping: Boolean, if True the task is to hop forwards, otherwise it is to
136 | balance upright.
137 | random: Optional, either a `numpy.random.RandomState` instance, an
138 | integer seed for creating a new `RandomState`, or None to select a seed
139 | automatically (default).
140 | """
141 | self._hopping = hopping
142 | self._forward = 1 if forward else -1
143 | self._flip = flip
144 | super(Hopper, self).__init__(random=random)
145 |
146 | def initialize_episode(self, physics):
147 | """Sets the state of the environment at the start of each episode."""
148 | randomizers.randomize_limited_and_rotational_joints(
149 | physics, self.random)
150 | self._timeout_progress = 0
151 | super(Hopper, self).initialize_episode(physics)
152 |
153 | def get_observation(self, physics):
154 | """Returns an observation of positions, velocities and touch sensors."""
155 | obs = collections.OrderedDict()
156 | # Ignores horizontal position to maintain translational invariance:
157 | obs['position'] = physics.data.qpos[1:].copy()
158 | obs['velocity'] = physics.velocity()
159 | obs['touch'] = physics.touch()
160 | return obs
161 |
162 | def get_reward(self, physics):
163 | """Returns a reward applicable to the performed task."""
164 | standing = rewards.tolerance(physics.height(), (_STAND_HEIGHT, 2))
165 | assert self._hopping
166 | if self._flip:
167 | hopping = rewards.tolerance(self._forward * physics.angmomentum(),
168 | bounds=(_SPIN_SPEED, float('inf')),
169 | margin=_SPIN_SPEED,
170 | value_at_margin=0,
171 | sigmoid='linear')
172 | else:
173 | hopping = rewards.tolerance(self._forward * physics.speed(),
174 | bounds=(_HOP_SPEED, float('inf')),
175 | margin=_HOP_SPEED / 2,
176 | value_at_margin=0.5,
177 | sigmoid='linear')
178 | return standing * hopping
--------------------------------------------------------------------------------
/agent/becl.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import math
3 | from collections import OrderedDict
4 |
5 | import hydra
6 | import random
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from dm_env import specs
12 |
13 | import utils
14 | from agent.ddpg import DDPGAgent
15 |
16 |
17 | class BECL(nn.Module):
18 | def __init__(self, tau_dim, feature_dim, hidden_dim):
19 | super().__init__()
20 |
21 | self.embed = nn.Sequential(nn.Linear(tau_dim, hidden_dim),
22 | nn.ReLU(),
23 | nn.Linear(hidden_dim, hidden_dim),
24 | nn.ReLU(),
25 | nn.Linear(hidden_dim, feature_dim))
26 |
27 | self.project_head = nn.Sequential(nn.Linear(feature_dim, hidden_dim),
28 | nn.ReLU(),
29 | nn.Linear(hidden_dim, feature_dim))
30 | self.apply(utils.weight_init)
31 |
32 | def forward(self, tau):
33 | features = self.embed(tau)
34 | features = self.project_head(features)
35 | return features
36 |
37 |
38 | class BECLAgent(DDPGAgent):
39 | def __init__(self, update_skill_every_step, skill_dim,
40 | update_encoder, contrastive_update_rate, temperature, skill, **kwargs):
41 | self.skill_dim = skill_dim
42 | self.update_skill_every_step = update_skill_every_step
43 | self.update_encoder = update_encoder
44 | self.contrastive_update_rate = contrastive_update_rate
45 | self.temperature = temperature
46 | # specify skill in fine-tuning stage if needed
47 | self.skill = int(skill) if skill >= 0 else np.random.choice(self.skill_dim)
48 | # increase obs shape to include skill dim
49 | kwargs["meta_dim"] = self.skill_dim
50 | self.batch_size = kwargs['batch_size']
51 | # create actor and critic
52 | super().__init__(**kwargs)
53 |
54 | # net
55 | self.becl = BECL(self.obs_dim - self.skill_dim,
56 | self.skill_dim,
57 | kwargs['hidden_dim']).to(kwargs['device'])
58 |
59 | # optimizers
60 | self.becl_opt = torch.optim.Adam(self.becl.parameters(), lr=self.lr)
61 |
62 | self.becl.train()
63 |
64 |
65 | def get_meta_specs(self):
66 | return specs.Array((self.skill_dim,), np.float32, 'skill'),
67 |
68 | def init_meta(self):
69 | skill = np.zeros(self.skill_dim).astype(np.float32)
70 | if not self.reward_free:
71 | skill[self.skill] = 1.0
72 | else:
73 | skill[np.random.choice(self.skill_dim)] = 1.0
74 | meta = OrderedDict()
75 | meta['skill'] = skill
76 | return meta
77 |
78 | def update_meta(self, meta, global_step, time_step, finetune=False):
79 | if global_step % self.update_skill_every_step == 0:
80 | return self.init_meta()
81 | return meta
82 |
83 | def update_contrastive(self, state, skills):
84 | metrics = dict()
85 | features = self.becl(state)
86 | logits = self.compute_info_nce_loss(features, skills)
87 | loss = logits.mean()
88 |
89 | self.becl_opt.zero_grad()
90 | if self.encoder_opt is not None:
91 | self.encoder_opt.zero_grad(set_to_none=True)
92 | loss.backward()
93 | self.becl_opt.step()
94 | if self.encoder_opt is not None:
95 | self.encoder_opt.step()
96 |
97 | if self.use_tb or self.use_wandb:
98 | metrics['contrastive_loss'] = loss.item()
99 |
100 | return metrics
101 |
102 | def compute_intr_reward(self, skills, state, metrics):
103 |
104 | # compute contrastive reward
105 | features = self.becl(state)
106 | contrastive_reward = torch.exp(-self.compute_info_nce_loss(features, skills))
107 |
108 | intr_reward = contrastive_reward
109 | if self.use_tb or self.use_wandb:
110 | metrics['contrastive_reward'] = contrastive_reward.mean().item()
111 |
112 | return intr_reward
113 |
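    # InfoNCE-style contrastive objective: samples sharing the same skill are
    # treated as positives, all other samples in the batch as negatives, and each
    # sample's loss is -log(positive similarity / sum of similarities). During
    # reward-free pre-training, compute_intr_reward uses exp(-loss) as a bounded
    # intrinsic reward that is large when states are discriminable by skill.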
114 | def compute_info_nce_loss(self, features, skills):
115 | # features: (b,c), skills :(b, skill_dim)
116 | # label positives samples
117 | labels = torch.argmax(skills, dim=-1) #(b, 1)
118 | labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).long() #(b,b)
119 | labels = labels.to(self.device)
120 |
121 | features = F.normalize(features, dim=1) #(b,c)
122 | similarity_matrix = torch.matmul(features, features.T) #(b,b)
123 |
124 | # discard the main diagonal from both: labels and similarities matrix
125 | mask = torch.eye(labels.shape[0], dtype=torch.bool).to(self.device)
126 | labels = labels[~mask].view(labels.shape[0], -1) #(b,b-1)
127 | similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1) #(b,b-1)
128 |
129 | similarity_matrix = similarity_matrix / self.temperature
130 | similarity_matrix -= torch.max(similarity_matrix, 1)[0][:, None]
131 | similarity_matrix = torch.exp(similarity_matrix)
132 |
133 | pick_one_positive_sample_idx = torch.argmax(labels, dim=-1, keepdim=True)
134 | pick_one_positive_sample_idx = torch.zeros_like(labels).scatter_(-1, pick_one_positive_sample_idx, 1)
135 |
136 | positives = torch.sum(similarity_matrix * pick_one_positive_sample_idx, dim=-1, keepdim=True) #(b,1)
137 | negatives = torch.sum(similarity_matrix, dim=-1, keepdim=True) #(b,1)
138 | eps = torch.as_tensor(1e-6)
139 | loss = -torch.log(positives / (negatives + eps) + eps) #(b,1)
140 |
141 | return loss
142 |
143 |
144 | def update(self, replay_iter, step):
145 | metrics = dict()
146 |
147 | if step % self.update_every_steps != 0:
148 | return metrics
149 |
150 | if self.reward_free:
151 |
152 | batch = next(replay_iter)
153 | obs, action, reward, discount, next_obs, skill = utils.to_torch(batch, self.device)
154 | obs = self.aug_and_encode(obs)
155 | next_obs = self.aug_and_encode(next_obs)
156 |
157 | metrics.update(self.update_contrastive(next_obs, skill))
158 |
159 | for _ in range(self.contrastive_update_rate - 1):
160 | batch = next(replay_iter)
161 | obs, action, reward, discount, next_obs, skill = utils.to_torch(batch, self.device)
162 | obs = self.aug_and_encode(obs)
163 | next_obs = self.aug_and_encode(next_obs)
164 |
165 | metrics.update(self.update_contrastive(next_obs, skill))
166 |
167 | with torch.no_grad():
168 | intr_reward = self.compute_intr_reward(skill, next_obs, metrics)
169 |
170 | if self.use_tb or self.use_wandb:
171 | metrics['intr_reward'] = intr_reward.mean().item()
172 |
173 | reward = intr_reward
174 | else:
175 | batch = next(replay_iter)
176 |
177 | obs, action, extr_reward, discount, next_obs, skill = utils.to_torch(
178 | batch, self.device)
179 | obs = self.aug_and_encode(obs)
180 | next_obs = self.aug_and_encode(next_obs)
181 | reward = extr_reward
182 |
183 | if self.use_tb or self.use_wandb:
184 | metrics['batch_reward'] = reward.mean().item()
185 |
186 | if not self.update_encoder:
187 | obs = obs.detach()
188 | next_obs = next_obs.detach()
189 |
190 | # extend observations with skill
191 | obs = torch.cat([obs, skill], dim=1)
192 | next_obs = torch.cat([next_obs, skill], dim=1)
193 |
194 | # update critic
195 | metrics.update(
196 | self.update_critic(obs.detach(), action, reward, discount,
197 | next_obs.detach(), step))
198 |
199 | # update actor
200 | metrics.update(self.update_actor(obs.detach(), step))
201 |
202 | # update critic target
203 | utils.soft_update_params(self.critic, self.critic_target,
204 | self.critic_target_tau)
205 |
206 | return metrics
207 |
--------------------------------------------------------------------------------
/custom_dmc_tasks/jaco.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """A task where the goal is to move the hand close to a target prop or site."""
17 |
18 | import collections
19 |
20 | from dm_control import composer
21 | from dm_control.composer import initializers
22 | from dm_control.composer.observation import observable
23 | from dm_control.composer.variation import distributions
24 | from dm_control.entities import props
25 | from dm_control.manipulation.shared import arenas
26 | from dm_control.manipulation.shared import cameras
27 | from dm_control.manipulation.shared import constants
28 | from dm_control.manipulation.shared import observations
29 | from dm_control.manipulation.shared import registry
30 | from dm_control.manipulation.shared import robots
31 | from dm_control.manipulation.shared import tags
32 | from dm_control.manipulation.shared import workspaces
33 | from dm_control.utils import rewards
34 | import numpy as np
35 |
36 |
37 | _ReachWorkspace = collections.namedtuple(
38 | '_ReachWorkspace', ['target_bbox', 'tcp_bbox', 'arm_offset'])
39 |
40 | # Ensures that the props are not touching the table before settling.
41 | _PROP_Z_OFFSET = 0.001
42 |
43 | _DUPLO_WORKSPACE = _ReachWorkspace(
44 | target_bbox=workspaces.BoundingBox(
45 | lower=(-0.1, -0.1, _PROP_Z_OFFSET),
46 | upper=(0.1, 0.1, _PROP_Z_OFFSET)),
47 | tcp_bbox=workspaces.BoundingBox(
48 | lower=(-0.1, -0.1, 0.2),
49 | upper=(0.1, 0.1, 0.4)),
50 | arm_offset=robots.ARM_OFFSET)
51 |
52 | _SITE_WORKSPACE = _ReachWorkspace(
53 | target_bbox=workspaces.BoundingBox(
54 | lower=(-0.2, -0.2, 0.02),
55 | upper=(0.2, 0.2, 0.4)),
56 | tcp_bbox=workspaces.BoundingBox(
57 | lower=(-0.2, -0.2, 0.02),
58 | upper=(0.2, 0.2, 0.4)),
59 | arm_offset=robots.ARM_OFFSET)
60 |
61 | _TARGET_RADIUS = 0.05
62 | _TIME_LIMIT = 10.
63 |
64 | TASKS = {
65 | 'reach_top_left': workspaces.BoundingBox(
66 | lower=(-0.09, 0.09, _PROP_Z_OFFSET),
67 | upper=(-0.09, 0.09, _PROP_Z_OFFSET)),
68 | 'reach_top_right': workspaces.BoundingBox(
69 | lower=(0.09, 0.09, _PROP_Z_OFFSET),
70 | upper=(0.09, 0.09, _PROP_Z_OFFSET)),
71 | 'reach_bottom_left': workspaces.BoundingBox(
72 | lower=(-0.09, -0.09, _PROP_Z_OFFSET),
73 | upper=(-0.09, -0.09, _PROP_Z_OFFSET)),
74 | 'reach_bottom_right': workspaces.BoundingBox(
75 | lower=(0.09, -0.09, _PROP_Z_OFFSET),
76 | upper=(0.09, -0.09, _PROP_Z_OFFSET)),
77 | }
78 |
79 |
80 | def make(task_id, obs_type, seed):
81 | obs_settings = observations.VISION if obs_type == 'pixels' else observations.PERFECT_FEATURES
82 | task = _reach(task_id, obs_settings=obs_settings, use_site=False)
83 | return composer.Environment(task, time_limit=_TIME_LIMIT, random_state=seed)
84 |
85 |
86 |
87 | class MTReach(composer.Task):
88 | """Bring the hand close to a target prop or site."""
89 |
90 | def __init__(
91 | self, task_id, arena, arm, hand, prop, obs_settings, workspace, control_timestep):
92 | """Initializes a new `Reach` task.
93 |
94 | Args:
95 | arena: `composer.Entity` instance.
96 | arm: `robot_base.RobotArm` instance.
97 | hand: `robot_base.RobotHand` instance.
98 | prop: `composer.Entity` instance specifying the prop to reach to, or None
99 | in which case the target is a fixed site whose position is specified by
100 | the workspace.
101 | obs_settings: `observations.ObservationSettings` instance.
102 | workspace: `_ReachWorkspace` specifying the placement of the prop and TCP.
103 | control_timestep: Float specifying the control timestep in seconds.
104 | """
105 | self._arena = arena
106 | self._arm = arm
107 | self._hand = hand
108 | self._arm.attach(self._hand)
109 | self._arena.attach_offset(self._arm, offset=workspace.arm_offset)
110 | self.control_timestep = control_timestep
111 | self._tcp_initializer = initializers.ToolCenterPointInitializer(
112 | self._hand, self._arm,
113 | position=distributions.Uniform(*workspace.tcp_bbox),
114 | quaternion=workspaces.DOWN_QUATERNION)
115 |
116 | # Add custom camera observable.
117 | self._task_observables = cameras.add_camera_observables(
118 | arena, obs_settings, cameras.FRONT_CLOSE)
119 |
120 | target_pos_distribution = distributions.Uniform(*TASKS[task_id])
121 | self._prop = prop
122 | if prop:
123 | # The prop itself is used to visualize the target location.
124 | self._make_target_site(parent_entity=prop, visible=False)
125 | self._target = self._arena.add_free_entity(prop)
126 | self._prop_placer = initializers.PropPlacer(
127 | props=[prop],
128 | position=target_pos_distribution,
129 | quaternion=workspaces.uniform_z_rotation,
130 | settle_physics=True)
131 | else:
132 | self._target = self._make_target_site(parent_entity=arena, visible=True)
133 | self._target_placer = target_pos_distribution
134 |
135 | obs = observable.MJCFFeature('pos', self._target)
136 | obs.configure(**obs_settings.prop_pose._asdict())
137 | self._task_observables['target_position'] = obs
138 |
139 | # Add sites for visualizing the prop and target bounding boxes.
140 | workspaces.add_bbox_site(
141 | body=self.root_entity.mjcf_model.worldbody,
142 | lower=workspace.tcp_bbox.lower, upper=workspace.tcp_bbox.upper,
143 | rgba=constants.GREEN, name='tcp_spawn_area')
144 | workspaces.add_bbox_site(
145 | body=self.root_entity.mjcf_model.worldbody,
146 | lower=workspace.target_bbox.lower, upper=workspace.target_bbox.upper,
147 | rgba=constants.BLUE, name='target_spawn_area')
148 |
149 | def _make_target_site(self, parent_entity, visible):
150 | return workspaces.add_target_site(
151 | body=parent_entity.mjcf_model.worldbody,
152 | radius=_TARGET_RADIUS, visible=visible,
153 | rgba=constants.RED, name='target_site')
154 |
155 | @property
156 | def root_entity(self):
157 | return self._arena
158 |
159 | @property
160 | def arm(self):
161 | return self._arm
162 |
163 | @property
164 | def hand(self):
165 | return self._hand
166 |
167 | @property
168 | def task_observables(self):
169 | return self._task_observables
170 |
171 | def get_reward(self, physics):
172 | hand_pos = physics.bind(self._hand.tool_center_point).xpos
173 | target_pos = physics.bind(self._target).xpos
174 | distance = np.linalg.norm(hand_pos - target_pos)
175 | return rewards.tolerance(
176 | distance, bounds=(0, _TARGET_RADIUS), margin=_TARGET_RADIUS)
177 |
178 | def initialize_episode(self, physics, random_state):
179 | self._hand.set_grasp(physics, close_factors=random_state.uniform())
180 | self._tcp_initializer(physics, random_state)
181 | if self._prop:
182 | self._prop_placer(physics, random_state)
183 | else:
184 | physics.bind(self._target).pos = (
185 | self._target_placer(random_state=random_state))
186 |
187 |
188 | def _reach(task_id, obs_settings, use_site):
189 | """Configure and instantiate a `Reach` task.
190 |
191 | Args:
192 | obs_settings: An `observations.ObservationSettings` instance.
193 | use_site: Boolean, if True then the target will be a fixed site, otherwise
194 | it will be a moveable Duplo brick.
195 |
196 | Returns:
197 |     An instance of `MTReach`.
198 | """
199 | arena = arenas.Standard()
200 | arm = robots.make_arm(obs_settings=obs_settings)
201 | hand = robots.make_hand(obs_settings=obs_settings)
202 | if use_site:
203 | workspace = _SITE_WORKSPACE
204 | prop = None
205 | else:
206 | workspace = _DUPLO_WORKSPACE
207 | prop = props.Duplo(observable_options=observations.make_options(
208 | obs_settings, observations.FREEPROP_OBSERVABLES))
209 | task = MTReach(task_id, arena=arena, arm=arm, hand=hand, prop=prop,
210 | obs_settings=obs_settings,
211 | workspace=workspace,
212 | control_timestep=constants.CONTROL_TIMESTEP)
213 | return task
--------------------------------------------------------------------------------
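A brief illustration of instantiating the reach task defined above directly (the task id and seed are example values; dm_control and MuJoCo must be installed for this to run):

from custom_dmc_tasks import jaco

env = jaco.make('reach_top_left', obs_type='states', seed=0)
time_step = env.reset()
print(sorted(time_step.observation.keys()))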
/pretrain.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | warnings.filterwarnings('ignore', category=DeprecationWarning)
4 |
5 | import os
6 |
7 | os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
8 | os.environ['MUJOCO_GL'] = 'egl'
9 |
10 | from pathlib import Path
11 |
12 | import hydra
13 | import numpy as np
14 | import torch
15 | import wandb
16 | from dm_env import specs
17 |
18 | import dmc
19 | import utils
20 | from logger import Logger
21 | from replay_buffer import ReplayBufferStorage, make_replay_loader
22 | from video import TrainVideoRecorder, VideoRecorder
23 |
24 | torch.backends.cudnn.benchmark = True
25 |
26 | from dmc_benchmark import PRIMAL_TASKS
27 |
28 |
29 | def make_agent(obs_type, obs_spec, action_spec, num_expl_steps, cfg):
30 | cfg.obs_type = obs_type
31 | cfg.obs_shape = obs_spec.shape
32 | cfg.action_shape = action_spec.shape
33 | cfg.num_expl_steps = num_expl_steps
34 | return hydra.utils.instantiate(cfg)
35 |
36 |
37 | class Workspace:
38 | def __init__(self, cfg):
39 | self.work_dir = Path.cwd()
40 | print(f'workspace: {self.work_dir}')
41 |
42 | self.cfg = cfg
43 | utils.set_seed_everywhere(cfg.seed)
44 | self.device = torch.device(cfg.device)
45 |
46 | # create logger
47 | if cfg.use_wandb:
48 | exp_name = '_'.join([
49 | cfg.experiment, cfg.agent.name, cfg.domain, cfg.obs_type,
50 | str(cfg.seed)
51 | ])
52 | wandb.init(project="urlb", group=cfg.agent.name, name=exp_name)
53 |
54 | self.logger = Logger(self.work_dir,
55 | use_tb=cfg.use_tb,
56 | use_wandb=cfg.use_wandb)
57 | # create envs
58 | task = PRIMAL_TASKS[self.cfg.domain]
59 | self.train_env = dmc.make(task, cfg.obs_type, cfg.frame_stack,
60 | cfg.action_repeat, cfg.seed)
61 | self.eval_env = dmc.make(task, cfg.obs_type, cfg.frame_stack,
62 | cfg.action_repeat, cfg.seed)
63 |
64 | # create agent
65 | self.agent = make_agent(cfg.obs_type,
66 | self.train_env.observation_spec(),
67 | self.train_env.action_spec(),
68 | cfg.num_seed_frames // cfg.action_repeat,
69 | cfg.agent)
70 |
71 | # get meta specs
72 | meta_specs = self.agent.get_meta_specs()
73 | # create replay buffer
74 | data_specs = (self.train_env.observation_spec(),
75 | self.train_env.action_spec(),
76 | specs.Array((1,), np.float32, 'reward'),
77 | specs.Array((1,), np.float32, 'discount'))
78 |
79 | # create data storage
80 | self.replay_storage = ReplayBufferStorage(data_specs, meta_specs,
81 | self.work_dir / 'buffer')
82 |
83 | # create replay buffer
84 | self.replay_loader = make_replay_loader(self.replay_storage,
85 | cfg.replay_buffer_size,
86 | cfg.batch_size,
87 | cfg.replay_buffer_num_workers,
88 | False, cfg.nstep, cfg.discount)
89 | self._replay_iter = None
90 |
91 | # create video recorders
92 | self.video_recorder = VideoRecorder(
93 | self.work_dir if cfg.save_video else None,
94 | camera_id=0 if 'quadruped' not in self.cfg.domain else 2,
95 | use_wandb=self.cfg.use_wandb)
96 | self.train_video_recorder = TrainVideoRecorder(
97 | self.work_dir if cfg.save_train_video else None,
98 | camera_id=0 if 'quadruped' not in self.cfg.domain else 2,
99 | use_wandb=self.cfg.use_wandb)
100 |
101 | self.timer = utils.Timer()
102 | self._global_step = 0
103 | self._global_episode = 0
104 |
105 | @property
106 | def global_step(self):
107 | return self._global_step
108 |
109 | @property
110 | def global_episode(self):
111 | return self._global_episode
112 |
113 | @property
114 | def global_frame(self):
115 | return self.global_step * self.cfg.action_repeat
116 |
117 | @property
118 | def replay_iter(self):
119 | if self._replay_iter is None:
120 | self._replay_iter = iter(self.replay_loader)
121 | return self._replay_iter
122 |
123 | def eval(self):
124 | step, episode, total_reward = 0, 0, 0
125 | eval_until_episode = utils.Until(self.cfg.num_eval_episodes)
126 | meta = self.agent.init_meta()
127 | while eval_until_episode(episode):
128 | time_step = self.eval_env.reset()
129 | self.video_recorder.init(self.eval_env, enabled=(episode == 0))
130 | while not time_step.last():
131 | with torch.no_grad(), utils.eval_mode(self.agent):
132 | action = self.agent.act(time_step.observation,
133 | meta,
134 | self.global_step,
135 | eval_mode=True)
136 | time_step = self.eval_env.step(action)
137 | self.video_recorder.record(self.eval_env)
138 | total_reward += time_step.reward
139 | step += 1
140 |
141 | episode += 1
142 | self.video_recorder.save(f'{self.global_frame}.mp4')
143 |
144 | with self.logger.log_and_dump_ctx(self.global_frame, ty='eval') as log:
145 | log('episode_reward', total_reward / episode)
146 | log('episode_length', step * self.cfg.action_repeat / episode)
147 | log('episode', self.global_episode)
148 | log('step', self.global_step)
149 |
150 | def train(self):
151 | # predicates
152 | train_until_step = utils.Until(self.cfg.num_train_frames,
153 | self.cfg.action_repeat)
154 | seed_until_step = utils.Until(self.cfg.num_seed_frames,
155 | self.cfg.action_repeat)
156 | eval_every_step = utils.Every(self.cfg.eval_every_frames,
157 | self.cfg.action_repeat)
158 |
159 | episode_step, episode_reward = 0, 0
160 | time_step = self.train_env.reset()
161 | meta = self.agent.init_meta()
162 | self.replay_storage.add(time_step, meta)
163 | self.train_video_recorder.init(time_step.observation)
164 | metrics = None
165 | while train_until_step(self.global_step):
166 | if time_step.last():
167 | self._global_episode += 1
168 | self.train_video_recorder.save(f'{self.global_frame}.mp4')
169 | # wait until all the metrics schema is populated
170 | if metrics is not None:
171 | # log stats
172 | elapsed_time, total_time = self.timer.reset()
173 | episode_frame = episode_step * self.cfg.action_repeat
174 | with self.logger.log_and_dump_ctx(self.global_frame,
175 | ty='train') as log:
176 | log('fps', episode_frame / elapsed_time)
177 | log('total_time', total_time)
178 | log('episode_reward', episode_reward)
179 | log('episode_length', episode_frame)
180 | log('episode', self.global_episode)
181 | log('buffer_size', len(self.replay_storage))
182 | log('step', self.global_step)
183 |
184 | # reset env
185 | time_step = self.train_env.reset()
186 | meta = self.agent.init_meta()
187 | self.replay_storage.add(time_step, meta)
188 | self.train_video_recorder.init(time_step.observation)
189 | # try to save snapshot
190 | episode_step = 0
191 | episode_reward = 0
192 | if self.global_frame in self.cfg.snapshots:
193 | self.save_snapshot()
194 | # try to evaluate
195 | if eval_every_step(self.global_step):
196 | self.logger.log('eval_total_time', self.timer.total_time(),
197 | self.global_frame)
198 | self.eval()
199 |
200 | meta = self.agent.update_meta(meta, self.global_step, time_step)
201 | # sample action
202 | with torch.no_grad(), utils.eval_mode(self.agent):
203 | action = self.agent.act(time_step.observation,
204 | meta,
205 | self.global_step,
206 | eval_mode=False)
207 |
208 | # try to update the agent
209 | if not seed_until_step(self.global_step):
210 | metrics = self.agent.update(self.replay_iter, self.global_step)
211 | self.logger.log_metrics(metrics, self.global_frame, ty='train')
212 |
213 | # take env step
214 | time_step = self.train_env.step(action)
215 | episode_reward += time_step.reward
216 | self.replay_storage.add(time_step, meta)
217 | self.train_video_recorder.record(time_step.observation)
218 | episode_step += 1
219 | self._global_step += 1
220 |
221 | def save_snapshot(self):
222 | snapshot_dir = self.work_dir / Path(self.cfg.snapshot_dir)
223 | snapshot_dir.mkdir(exist_ok=True, parents=True)
224 | snapshot = snapshot_dir / f'snapshot_{self.global_frame}.pt'
225 | keys_to_save = ['agent', '_global_step', '_global_episode']
226 | payload = {k: self.__dict__[k] for k in keys_to_save}
227 | with snapshot.open('wb') as f:
228 | torch.save(payload, f)
229 |
230 |
231 | @hydra.main(config_path='.', config_name='pretrain')
232 | def main(cfg):
233 | from pretrain import Workspace as W
234 | root_dir = Path.cwd()
235 | workspace = W(cfg)
236 | snapshot = root_dir / 'snapshot.pt'
237 | if snapshot.exists():
238 | print(f'resuming: {snapshot}')
239 | workspace.load_snapshot()
240 | workspace.train()
241 |
242 |
243 | if __name__ == '__main__':
244 | main()
245 |
--------------------------------------------------------------------------------
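main() above resumes from snapshot.pt by calling workspace.load_snapshot(), but this Workspace only defines save_snapshot. A minimal counterpart mirroring the keys that save_snapshot stores could look like the sketch below (hypothetical code, not part of the original file):

    def load_snapshot(self):
        # Hypothetical inverse of save_snapshot: restore the saved attributes in place.
        snapshot = self.work_dir / 'snapshot.pt'
        with snapshot.open('rb') as f:
            payload = torch.load(f)
        for k, v in payload.items():
            self.__dict__[k] = v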
/finetune.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | warnings.filterwarnings('ignore', category=DeprecationWarning)
4 |
5 | import os
6 |
7 | os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
8 | os.environ['MUJOCO_GL'] = 'egl'
9 |
10 | from pathlib import Path
11 |
12 | import hydra
13 | import numpy as np
14 | import torch
15 | from dm_env import specs
16 |
17 | import dmc
18 | import utils
19 | from logger import Logger
20 | from replay_buffer import ReplayBufferStorage, make_replay_loader
21 | from video import TrainVideoRecorder, VideoRecorder
22 |
23 | torch.backends.cudnn.benchmark = True
24 |
25 | import logging
26 |
27 | def make_agent(obs_type, obs_spec, action_spec, num_expl_steps, cfg):
28 | cfg.obs_type = obs_type
29 | cfg.obs_shape = obs_spec.shape
30 | cfg.action_shape = action_spec.shape
31 | cfg.num_expl_steps = num_expl_steps
32 | return hydra.utils.instantiate(cfg)
33 |
34 |
35 | class Workspace:
36 | def __init__(self, cfg):
37 | self.work_dir = Path.cwd()
38 | print(f'workspace: {self.work_dir}')
39 |
40 | self.cfg = cfg
41 | utils.set_seed_everywhere(cfg.seed)
42 | self.device = torch.device(cfg.device)
43 |
44 | # create logger
45 | self.logger = Logger(self.work_dir,
46 | use_tb=cfg.use_tb,
47 | use_wandb=cfg.use_wandb)
48 | # create envs
49 | self.train_env = dmc.make(cfg.task, cfg.obs_type, cfg.frame_stack,
50 | cfg.action_repeat, cfg.seed)
51 | self.eval_env = dmc.make(cfg.task, cfg.obs_type, cfg.frame_stack,
52 | cfg.action_repeat, cfg.seed)
53 |
54 | # create agent
55 | self.agent = make_agent(cfg.obs_type,
56 | self.train_env.observation_spec(),
57 | self.train_env.action_spec(),
58 | cfg.num_seed_frames // cfg.action_repeat,
59 | cfg.agent)
60 |
61 | # initialize from pretrained
62 | if cfg.snapshot_ts > 0:
63 | pretrained_agent = self.load_snapshot()['agent']
64 | self.agent.init_from(pretrained_agent)
65 |
66 | # get meta specs
67 | meta_specs = self.agent.get_meta_specs()
68 | # create replay buffer
69 | data_specs = (self.train_env.observation_spec(),
70 | self.train_env.action_spec(),
71 | specs.Array((1,), np.float32, 'reward'),
72 | specs.Array((1,), np.float32, 'discount'))
73 |
74 | # create data storage
75 | self.replay_storage = ReplayBufferStorage(data_specs, meta_specs,
76 | self.work_dir / 'buffer')
77 |
78 | # create replay buffer
79 | self.replay_loader = make_replay_loader(self.replay_storage,
80 | cfg.replay_buffer_size,
81 | cfg.batch_size,
82 | cfg.replay_buffer_num_workers,
83 | False, cfg.nstep, cfg.discount)
84 | self._replay_iter = None
85 |
86 | # create video recorders
87 | self.video_recorder = VideoRecorder(
88 | self.work_dir if cfg.save_video else None)
89 | self.train_video_recorder = TrainVideoRecorder(
90 | self.work_dir if cfg.save_train_video else None)
91 |
92 | self.timer = utils.Timer()
93 | self._global_step = 0
94 | self._global_episode = 0
95 |
96 | @property
97 | def global_step(self):
98 | return self._global_step
99 |
100 | @property
101 | def global_episode(self):
102 | return self._global_episode
103 |
104 | @property
105 | def global_frame(self):
106 | return self.global_step * self.cfg.action_repeat
107 |
108 | @property
109 | def replay_iter(self):
110 | if self._replay_iter is None:
111 | self._replay_iter = iter(self.replay_loader)
112 | return self._replay_iter
113 |
114 | def eval(self):
115 | step, episode, total_reward = 0, 0, 0
116 | eval_until_episode = utils.Until(self.cfg.num_eval_episodes)
117 | meta = self.agent.init_meta()
118 | while eval_until_episode(episode):
119 | time_step = self.eval_env.reset()
120 | self.video_recorder.init(self.eval_env, enabled=(episode == 0))
121 | while not time_step.last():
122 | with torch.no_grad(), utils.eval_mode(self.agent):
123 | action = self.agent.act(time_step.observation,
124 | meta,
125 | self.global_step,
126 | eval_mode=True)
127 | time_step = self.eval_env.step(action)
128 | self.video_recorder.record(self.eval_env)
129 | total_reward += time_step.reward
130 | step += 1
131 |
132 | episode += 1
133 | self.video_recorder.save(f'{self.global_frame}.mp4')
134 |
135 | with self.logger.log_and_dump_ctx(self.global_frame, ty='eval') as log:
136 | log('episode_reward', total_reward / episode)
137 | log('episode_length', step * self.cfg.action_repeat / episode)
138 | log('episode', self.global_episode)
139 | log('step', self.global_step)
140 | if 'skill' in meta.keys():
141 | log('skill', meta['skill'].argmax())
142 |
143 | def train(self):
144 | # predicates
145 | train_until_step = utils.Until(self.cfg.num_train_frames,
146 | self.cfg.action_repeat)
147 | seed_until_step = utils.Until(self.cfg.num_seed_frames,
148 | self.cfg.action_repeat)
149 | eval_every_step = utils.Every(self.cfg.eval_every_frames,
150 | self.cfg.action_repeat)
151 |
152 | episode_step, episode_reward = 0, 0
153 | time_step = self.train_env.reset()
154 | meta = self.agent.init_meta()
155 | self.replay_storage.add(time_step, meta)
156 | self.train_video_recorder.init(time_step.observation)
157 | metrics = None
158 | while train_until_step(self.global_step):
159 | if time_step.last():
160 | self._global_episode += 1
161 | self.train_video_recorder.save(f'{self.global_frame}.mp4')
162 | # wait until all the metrics schema is populated
163 | if metrics is not None:
164 | # log stats
165 | elapsed_time, total_time = self.timer.reset()
166 | episode_frame = episode_step * self.cfg.action_repeat
167 | with self.logger.log_and_dump_ctx(self.global_frame,
168 | ty='train') as log:
169 | log('fps', episode_frame / elapsed_time)
170 | log('total_time', total_time)
171 | log('episode_reward', episode_reward)
172 | log('episode_length', episode_frame)
173 | log('episode', self.global_episode)
174 | log('buffer_size', len(self.replay_storage))
175 | log('step', self.global_step)
176 |
177 | # reset env
178 | time_step = self.train_env.reset()
179 | meta = self.agent.init_meta()
180 | self.replay_storage.add(time_step, meta)
181 | self.train_video_recorder.init(time_step.observation)
182 |
183 | episode_step = 0
184 | episode_reward = 0
185 |
186 | # try to evaluate
187 | if eval_every_step(self.global_step):
188 | self.logger.log('eval_total_time', self.timer.total_time(),
189 | self.global_frame)
190 | self.eval()
191 |
192 | meta = self.agent.update_meta(meta, self.global_step, time_step)
193 |
194 | if hasattr(self.agent, "regress_meta"):
195 | repeat = self.cfg.action_repeat
196 | every = self.agent.update_task_every_step // repeat
197 | init_step = self.agent.num_init_steps
198 | if self.global_step > (
199 | init_step // repeat) and self.global_step % every == 0:
200 | meta = self.agent.regress_meta(self.replay_iter,
201 | self.global_step)
202 |
203 | # sample action
204 | with torch.no_grad(), utils.eval_mode(self.agent):
205 | action = self.agent.act(time_step.observation,
206 | meta,
207 | self.global_step,
208 | eval_mode=False)
209 |
210 | # try to update the agent
211 | if not seed_until_step(self.global_step):
212 | metrics = self.agent.update(self.replay_iter, self.global_step)
213 | self.logger.log_metrics(metrics, self.global_frame, ty='train')
214 |
215 | # take env step
216 | time_step = self.train_env.step(action)
217 | episode_reward += time_step.reward
218 | self.replay_storage.add(time_step, meta)
219 | self.train_video_recorder.record(time_step.observation)
220 | episode_step += 1
221 | self._global_step += 1
222 |
223 | def load_snapshot(self):
224 | snapshot_base_dir = Path(self.cfg.snapshot_base_dir)
225 | domain, _ = self.cfg.task.split('_', 1)
226 | snapshot_dir = snapshot_base_dir / self.cfg.obs_type / domain / self.cfg.agent.name
227 |
228 | def try_load(seed):
229 | snapshot = '../../../../../' / snapshot_dir / str(
230 | seed) / f'snapshot_{self.cfg.snapshot_ts}.pt'
231 |             logging.info("loading model: {}, cwd is {}".format(str(snapshot), str(Path.cwd())))
232 | if not snapshot.exists():
233 |                 logging.error("pretrained model not found: {}".format(str(snapshot)))
234 | return None
235 | with snapshot.open('rb') as f:
236 | payload = torch.load(f, map_location='cuda:0')
237 | return payload
238 |
239 | # try to load current seed
240 | payload = try_load(self.cfg.seed)
241 | assert payload is not None
242 |
243 | return payload
244 |
245 |
246 | @hydra.main(config_path='.', config_name='finetune')
247 | def main(cfg):
248 | from finetune import Workspace as W
249 | root_dir = Path.cwd()
250 |     logging.basicConfig(level=logging.DEBUG)  # the 'encoding' kwarg requires Python >= 3.9
251 | workspace = W(cfg)
252 | snapshot = root_dir / 'snapshot.pt'
253 | if snapshot.exists():
254 | print(f'resuming: {snapshot}')
255 | workspace.load_snapshot()
256 | workspace.train()
257 |
258 |
259 | if __name__ == '__main__':
260 | main()
261 |
--------------------------------------------------------------------------------
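load_snapshot above resolves the pretrained checkpoint as <snapshot_base_dir>/<obs_type>/<domain>/<agent_name>/<seed>/snapshot_<snapshot_ts>.pt, relative to the Hydra run directory. A quick sketch of the expected layout; every concrete value here is an example, not a repository default:

from pathlib import Path

snapshot_base_dir = Path('pretrained_models')        # example location
obs_type, domain, agent_name = 'states', 'walker', 'becl'
seed, snapshot_ts = 1, 2000000

snapshot = (snapshot_base_dir / obs_type / domain / agent_name /
            str(seed) / f'snapshot_{snapshot_ts}.pt')
print(snapshot)  # pretrained_models/states/walker/becl/1/snapshot_2000000.pt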
/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | import re
4 | import time
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from omegaconf import OmegaConf
11 | from torch import distributions as pyd
12 | from torch.distributions.utils import _standard_normal
13 |
14 |
15 | class eval_mode:
16 | def __init__(self, *models):
17 | self.models = models
18 |
19 | def __enter__(self):
20 | self.prev_states = []
21 | for model in self.models:
22 | self.prev_states.append(model.training)
23 | model.train(False)
24 |
25 | def __exit__(self, *args):
26 | for model, state in zip(self.models, self.prev_states):
27 | model.train(state)
28 | return False
29 |
30 |
31 | def set_seed_everywhere(seed):
32 | torch.manual_seed(seed)
33 | if torch.cuda.is_available():
34 | torch.cuda.manual_seed_all(seed)
35 | np.random.seed(seed)
36 | random.seed(seed)
37 |
38 |
39 | def chain(*iterables):
40 | for it in iterables:
41 | yield from it
42 |
43 |
44 | def soft_update_params(net, target_net, tau):
45 | for param, target_param in zip(net.parameters(), target_net.parameters()):
46 | target_param.data.copy_(tau * param.data +
47 | (1 - tau) * target_param.data)
48 |
49 |
50 | def hard_update_params(net, target_net):
51 | for param, target_param in zip(net.parameters(), target_net.parameters()):
52 | target_param.data.copy_(param.data)
53 |
54 |
55 | def to_torch(xs, device):
56 | return tuple(torch.as_tensor(x, device=device) for x in xs)
57 |
58 |
59 | def weight_init(m):
60 | """Custom weight init for Conv2D and Linear layers."""
61 | if isinstance(m, nn.Linear):
62 | nn.init.orthogonal_(m.weight.data)
63 | if hasattr(m.bias, 'data'):
64 | m.bias.data.fill_(0.0)
65 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
66 | gain = nn.init.calculate_gain('relu')
67 | nn.init.orthogonal_(m.weight.data, gain)
68 | if hasattr(m.bias, 'data'):
69 | m.bias.data.fill_(0.0)
70 |
71 |
72 | def grad_norm(params, norm_type=2.0):
73 | params = [p for p in params if p.grad is not None]
74 | total_norm = torch.norm(
75 | torch.stack([torch.norm(p.grad.detach(), norm_type) for p in params]),
76 | norm_type)
77 | return total_norm.item()
78 |
79 |
80 | def param_norm(params, norm_type=2.0):
81 | total_norm = torch.norm(
82 | torch.stack([torch.norm(p.detach(), norm_type) for p in params]),
83 | norm_type)
84 | return total_norm.item()
85 |
86 |
87 | class Until:
88 | def __init__(self, until, action_repeat=1):
89 | self._until = until
90 | self._action_repeat = action_repeat
91 |
92 | def __call__(self, step):
93 | if self._until is None:
94 | return True
95 | until = self._until // self._action_repeat
96 | return step < until
97 |
98 |
99 | class Every:
100 | def __init__(self, every, action_repeat=1):
101 | self._every = every
102 | self._action_repeat = action_repeat
103 |
104 | def __call__(self, step):
105 | if self._every is None:
106 | return False
107 | every = self._every // self._action_repeat
108 | if step % every == 0:
109 | return True
110 | return False
111 |
112 |
113 | class Timer:
114 | def __init__(self):
115 | self._start_time = time.time()
116 | self._last_time = time.time()
117 |
118 | def reset(self):
119 | elapsed_time = time.time() - self._last_time
120 | self._last_time = time.time()
121 | total_time = time.time() - self._start_time
122 | return elapsed_time, total_time
123 |
124 | def total_time(self):
125 | return time.time() - self._start_time
126 |
127 |
128 | class TruncatedNormal(pyd.Normal):
129 | def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
130 | super().__init__(loc, scale, validate_args=False)
131 | self.low = low
132 | self.high = high
133 | self.eps = eps
134 |
135 | def _clamp(self, x):
136 | clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps)
137 | x = x - x.detach() + clamped_x.detach()
138 | return x
139 |
140 | def sample(self, clip=None, sample_shape=torch.Size()):
141 | shape = self._extended_shape(sample_shape)
142 | eps = _standard_normal(shape,
143 | dtype=self.loc.dtype,
144 | device=self.loc.device)
145 | eps *= self.scale
146 | if clip is not None:
147 | eps = torch.clamp(eps, -clip, clip)
148 | x = self.loc + eps
149 | return self._clamp(x)
150 |
151 |
152 | class TanhTransform(pyd.transforms.Transform):
153 | domain = pyd.constraints.real
154 | codomain = pyd.constraints.interval(-1.0, 1.0)
155 | bijective = True
156 | sign = +1
157 |
158 | def __init__(self, cache_size=1):
159 | super().__init__(cache_size=cache_size)
160 |
161 | @staticmethod
162 | def atanh(x):
163 | return 0.5 * (x.log1p() - (-x).log1p())
164 |
165 | def __eq__(self, other):
166 | return isinstance(other, TanhTransform)
167 |
168 | def _call(self, x):
169 | return x.tanh()
170 |
171 | def _inverse(self, y):
172 | # We do not clamp to the boundary here as it may degrade the performance of certain algorithms.
173 | # one should use `cache_size=1` instead
174 | return self.atanh(y)
175 |
176 | def log_abs_det_jacobian(self, x, y):
177 | # We use a formula that is more numerically stable, see details in the following link
178 | # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7
179 | return 2. * (math.log(2.) - x - F.softplus(-2. * x))
180 |
181 |
182 | class SquashedNormal(pyd.transformed_distribution.TransformedDistribution):
183 | def __init__(self, loc, scale):
184 | self.loc = loc
185 | self.scale = scale
186 |
187 | self.base_dist = pyd.Normal(loc, scale)
188 | transforms = [TanhTransform()]
189 | super().__init__(self.base_dist, transforms)
190 |
191 | @property
192 | def mean(self):
193 | mu = self.loc
194 | for tr in self.transforms:
195 | mu = tr(mu)
196 | return mu
197 |
198 |
199 | def schedule(schdl, step):
200 | try:
201 | return float(schdl)
202 | except ValueError:
203 | match = re.match(r'linear\((.+),(.+),(.+)\)', schdl)
204 | if match:
205 | init, final, duration = [float(g) for g in match.groups()]
206 | mix = np.clip(step / duration, 0.0, 1.0)
207 | return (1.0 - mix) * init + mix * final
208 | match = re.match(r'step_linear\((.+),(.+),(.+),(.+),(.+)\)', schdl)
209 | if match:
210 | init, final1, duration1, final2, duration2 = [
211 | float(g) for g in match.groups()
212 | ]
213 | if step <= duration1:
214 | mix = np.clip(step / duration1, 0.0, 1.0)
215 | return (1.0 - mix) * init + mix * final1
216 | else:
217 | mix = np.clip((step - duration1) / duration2, 0.0, 1.0)
218 | return (1.0 - mix) * final1 + mix * final2
219 | raise NotImplementedError(schdl)
220 |
221 |
222 | class RandomShiftsAug(nn.Module):
223 | def __init__(self, pad):
224 | super().__init__()
225 | self.pad = pad
226 |
227 | def forward(self, x):
228 | x = x.float()
229 | n, c, h, w = x.size()
230 | assert h == w
231 | padding = tuple([self.pad] * 4)
232 | x = F.pad(x, padding, 'replicate')
233 | eps = 1.0 / (h + 2 * self.pad)
234 | arange = torch.linspace(-1.0 + eps,
235 | 1.0 - eps,
236 | h + 2 * self.pad,
237 | device=x.device,
238 | dtype=x.dtype)[:h]
239 | arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
240 | base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
241 | base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)
242 |
243 | shift = torch.randint(0,
244 | 2 * self.pad + 1,
245 | size=(n, 1, 1, 2),
246 | device=x.device,
247 | dtype=x.dtype)
248 | shift *= 2.0 / (h + 2 * self.pad)
249 |
250 | grid = base_grid + shift
251 | return F.grid_sample(x,
252 | grid,
253 | padding_mode='zeros',
254 | align_corners=False)
255 |
256 |
257 | class RMS(object):
258 | """running mean and std """
259 | def __init__(self, device, epsilon=1e-4, shape=(1,)):
260 | self.M = torch.zeros(shape).to(device)
261 | self.S = torch.ones(shape).to(device)
262 | self.n = epsilon
263 |
264 | def __call__(self, x):
265 | bs = x.size(0)
266 | delta = torch.mean(x, dim=0) - self.M
267 | new_M = self.M + delta * bs / (self.n + bs)
268 | new_S = (self.S * self.n + torch.var(x, dim=0) * bs +
269 | torch.square(delta) * self.n * bs /
270 | (self.n + bs)) / (self.n + bs)
271 |
272 | self.M = new_M
273 | self.S = new_S
274 | self.n += bs
275 |
276 | return self.M, self.S
277 |
278 |
279 | class PBE(object):
280 | """particle-based entropy based on knn normalized by running mean """
281 | def __init__(self, rms, knn_clip, knn_k, knn_avg, knn_rms, device):
282 | self.rms = rms
283 | self.knn_rms = knn_rms
284 | self.knn_k = knn_k
285 | self.knn_avg = knn_avg
286 | self.knn_clip = knn_clip
287 | self.device = device
288 |
289 | def __call__(self, rep):
290 | source = target = rep
291 | b1, b2 = source.size(0), target.size(0)
292 | # (b1, 1, c) - (1, b2, c) -> (b1, 1, c) - (1, b2, c) -> (b1, b2, c) -> (b1, b2)
293 | sim_matrix = torch.norm(source[:, None, :].view(b1, 1, -1) -
294 | target[None, :, :].view(1, b2, -1),
295 | dim=-1,
296 | p=2)
297 | reward, _ = sim_matrix.topk(self.knn_k,
298 | dim=1,
299 | largest=False,
300 | sorted=True) # (b1, k)
301 | if not self.knn_avg: # only keep k-th nearest neighbor
302 | reward = reward[:, -1]
303 | reward = reward.reshape(-1, 1) # (b1, 1)
304 | reward /= self.rms(reward)[0] if self.knn_rms else 1.0
305 | reward = torch.maximum(
306 | reward - self.knn_clip,
307 | torch.zeros_like(reward).to(self.device)
308 | ) if self.knn_clip >= 0.0 else reward # (b1, 1)
309 | else: # average over all k nearest neighbors
310 | reward = reward.reshape(-1, 1) # (b1 * k, 1)
311 | reward /= self.rms(reward)[0] if self.knn_rms else 1.0
312 | reward = torch.maximum(
313 | reward - self.knn_clip,
314 | torch.zeros_like(reward).to(
315 | self.device)) if self.knn_clip >= 0.0 else reward
316 | reward = reward.reshape((b1, self.knn_k)) # (b1, k)
317 | reward = reward.mean(dim=1, keepdim=True) # (b1, 1)
318 | reward = torch.log(reward + 1.0)
319 | return reward
320 |
--------------------------------------------------------------------------------
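The schedule helper in utils.py accepts either a constant or a linear(init, final, duration) specification (plus the two-segment step_linear variant). A toy check, assuming the module is imported from the repository root:

from utils import schedule

print(schedule('0.2', 100))                  # constant schedule -> 0.2
print(schedule('linear(1.0,0.1,500)', 250))  # halfway through the decay -> 0.55
print(schedule('linear(1.0,0.1,500)', 900))  # past the end, clipped -> 0.1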
/agent/ddpg.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import hydra
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | import utils
10 |
11 |
12 | class Encoder(nn.Module):
13 | def __init__(self, obs_shape):
14 | super().__init__()
15 |
16 | assert len(obs_shape) == 3
17 | self.repr_dim = 32 * 35 * 35
18 |
19 | self.convnet = nn.Sequential(nn.Conv2d(obs_shape[0], 32, 3, stride=2),
20 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
21 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
22 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
23 | nn.ReLU())
24 |
25 | self.apply(utils.weight_init)
26 |
27 | def forward(self, obs):
28 | obs = obs / 255.0 - 0.5
29 | h = self.convnet(obs)
30 | h = h.view(h.shape[0], -1)
31 | return h
32 |
33 |
34 | class Actor(nn.Module):
35 | def __init__(self, obs_type, obs_dim, action_dim, feature_dim, hidden_dim):
36 | super().__init__()
37 |
38 | feature_dim = feature_dim if obs_type == 'pixels' else hidden_dim
39 |
40 | self.trunk = nn.Sequential(nn.Linear(obs_dim, feature_dim),
41 | nn.LayerNorm(feature_dim), nn.Tanh())
42 |
43 | policy_layers = []
44 | policy_layers += [
45 | nn.Linear(feature_dim, hidden_dim),
46 | nn.ReLU(inplace=True)
47 | ]
48 | # add additional hidden layer for pixels
49 | if obs_type == 'pixels':
50 | policy_layers += [
51 | nn.Linear(hidden_dim, hidden_dim),
52 | nn.ReLU(inplace=True)
53 | ]
54 | policy_layers += [nn.Linear(hidden_dim, action_dim)]
55 |
56 | self.policy = nn.Sequential(*policy_layers)
57 |
58 | self.apply(utils.weight_init)
59 |
60 | def forward(self, obs, std):
61 | h = self.trunk(obs)
62 |
63 | mu = self.policy(h)
64 | mu = torch.tanh(mu)
65 | std = torch.ones_like(mu) * std
66 |
67 | dist = utils.TruncatedNormal(mu, std)
68 | return dist
69 |
70 |
71 | class Critic(nn.Module):
72 | def __init__(self, obs_type, obs_dim, action_dim, feature_dim, hidden_dim):
73 | super().__init__()
74 |
75 | self.obs_type = obs_type
76 |
77 | if obs_type == 'pixels':
78 | # for pixels actions will be added after trunk
79 | self.trunk = nn.Sequential(nn.Linear(obs_dim, feature_dim),
80 | nn.LayerNorm(feature_dim), nn.Tanh())
81 | trunk_dim = feature_dim + action_dim
82 | else:
83 | # for states actions come in the beginning
84 | self.trunk = nn.Sequential(
85 | nn.Linear(obs_dim + action_dim, hidden_dim),
86 | nn.LayerNorm(hidden_dim), nn.Tanh())
87 | trunk_dim = hidden_dim
88 |
89 | def make_q():
90 | q_layers = []
91 | q_layers += [
92 | nn.Linear(trunk_dim, hidden_dim),
93 | nn.ReLU(inplace=True)
94 | ]
95 | if obs_type == 'pixels':
96 | q_layers += [
97 | nn.Linear(hidden_dim, hidden_dim),
98 | nn.ReLU(inplace=True)
99 | ]
100 | q_layers += [nn.Linear(hidden_dim, 1)]
101 | return nn.Sequential(*q_layers)
102 |
103 | self.Q1 = make_q()
104 | self.Q2 = make_q()
105 |
106 | self.apply(utils.weight_init)
107 |
108 | def forward(self, obs, action):
109 | inpt = obs if self.obs_type == 'pixels' else torch.cat([obs, action],
110 | dim=-1)
111 | h = self.trunk(inpt)
112 | h = torch.cat([h, action], dim=-1) if self.obs_type == 'pixels' else h
113 |
114 | q1 = self.Q1(h)
115 | q2 = self.Q2(h)
116 |
117 | return q1, q2
118 |
119 |
120 | class DDPGAgent:
121 | def __init__(self,
122 | name,
123 | reward_free,
124 | obs_type,
125 | obs_shape,
126 | action_shape,
127 | device,
128 | lr,
129 | feature_dim,
130 | hidden_dim,
131 | critic_target_tau,
132 | num_expl_steps,
133 | update_every_steps,
134 | stddev_schedule,
135 | nstep,
136 | batch_size,
137 | stddev_clip,
138 | init_critic,
139 | use_tb,
140 | use_wandb,
141 | meta_dim=0):
142 | self.reward_free = reward_free
143 | self.obs_type = obs_type
144 | self.obs_shape = obs_shape
145 | self.action_dim = action_shape[0]
146 | self.hidden_dim = hidden_dim
147 | self.lr = lr
148 | self.device = device
149 | self.critic_target_tau = critic_target_tau
150 | self.update_every_steps = update_every_steps
151 | self.use_tb = use_tb
152 | self.use_wandb = use_wandb
153 | self.num_expl_steps = num_expl_steps
154 | self.stddev_schedule = stddev_schedule
155 | self.stddev_clip = stddev_clip
156 | self.init_critic = init_critic
157 | self.feature_dim = feature_dim
158 | self.solved_meta = None
159 |
160 | # models
161 | if obs_type == 'pixels':
162 | self.aug = utils.RandomShiftsAug(pad=4)
163 | self.encoder = Encoder(obs_shape).to(device)
164 | self.obs_dim = self.encoder.repr_dim + meta_dim
165 | else:
166 | self.aug = nn.Identity()
167 | self.encoder = nn.Identity()
168 | self.obs_dim = obs_shape[0] + meta_dim
169 |
170 | self.actor = Actor(obs_type, self.obs_dim, self.action_dim,
171 | feature_dim, hidden_dim).to(device)
172 |
173 | self.critic = Critic(obs_type, self.obs_dim, self.action_dim,
174 | feature_dim, hidden_dim).to(device)
175 | self.critic_target = Critic(obs_type, self.obs_dim, self.action_dim,
176 | feature_dim, hidden_dim).to(device)
177 | self.critic_target.load_state_dict(self.critic.state_dict())
178 |
179 | # optimizers
180 |
181 | if obs_type == 'pixels':
182 | self.encoder_opt = torch.optim.Adam(self.encoder.parameters(),
183 | lr=lr)
184 | else:
185 | self.encoder_opt = None
186 | self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr)
187 | self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=lr)
188 |
189 | self.train()
190 | self.critic_target.train()
191 |
192 | def train(self, training=True):
193 | self.training = training
194 | self.encoder.train(training)
195 | self.actor.train(training)
196 | self.critic.train(training)
197 |
198 | def init_from(self, other):
199 | # copy parameters over
200 | utils.hard_update_params(other.encoder, self.encoder)
201 | utils.hard_update_params(other.actor, self.actor)
202 | if self.init_critic:
203 | utils.hard_update_params(other.critic.trunk, self.critic.trunk)
204 |
205 | def get_meta_specs(self):
206 | return tuple()
207 |
208 | def init_meta(self):
209 | return OrderedDict()
210 |
211 | def update_meta(self, meta, global_step, time_step, finetune=False):
212 | return meta
213 |
214 | def act(self, obs, meta, step, eval_mode):
215 | obs = torch.as_tensor(obs, device=self.device).unsqueeze(0)
216 | h = self.encoder(obs)
217 | inputs = [h]
218 | for value in meta.values():
219 | value = torch.as_tensor(value, device=self.device).unsqueeze(0)
220 | inputs.append(value)
221 | inpt = torch.cat(inputs, dim=-1)
222 | #assert obs.shape[-1] == self.obs_shape[-1]
223 | stddev = utils.schedule(self.stddev_schedule, step)
224 | dist = self.actor(inpt, stddev)
225 | if eval_mode:
226 | action = dist.mean
227 | else:
228 | action = dist.sample(clip=None)
229 | if step < self.num_expl_steps:
230 | action.uniform_(-1.0, 1.0)
231 | return action.cpu().numpy()[0]
232 |
233 | def update_critic(self, obs, action, reward, discount, next_obs, step):
234 | metrics = dict()
235 |
236 | with torch.no_grad():
237 | stddev = utils.schedule(self.stddev_schedule, step)
238 | dist = self.actor(next_obs, stddev)
239 | next_action = dist.sample(clip=self.stddev_clip)
240 | target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
241 | target_V = torch.min(target_Q1, target_Q2)
242 | target_Q = reward + (discount * target_V)
243 |
244 | Q1, Q2 = self.critic(obs, action)
245 | critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q)
246 |
247 | if self.use_tb or self.use_wandb:
248 | metrics['critic_target_q'] = target_Q.mean().item()
249 | metrics['critic_q1'] = Q1.mean().item()
250 | metrics['critic_q2'] = Q2.mean().item()
251 | metrics['critic_loss'] = critic_loss.item()
252 |
253 | # optimize critic
254 | if self.encoder_opt is not None:
255 | self.encoder_opt.zero_grad(set_to_none=True)
256 | self.critic_opt.zero_grad(set_to_none=True)
257 | critic_loss.backward()
258 | self.critic_opt.step()
259 | if self.encoder_opt is not None:
260 | self.encoder_opt.step()
261 | return metrics
262 |
263 | def update_actor(self, obs, step):
264 | metrics = dict()
265 |
266 | stddev = utils.schedule(self.stddev_schedule, step)
267 | dist = self.actor(obs, stddev)
268 | action = dist.sample(clip=self.stddev_clip)
269 | log_prob = dist.log_prob(action).sum(-1, keepdim=True)
270 | Q1, Q2 = self.critic(obs, action)
271 | Q = torch.min(Q1, Q2)
272 |
273 | actor_loss = -Q.mean()
274 |
275 | # optimize actor
276 | self.actor_opt.zero_grad(set_to_none=True)
277 | actor_loss.backward()
278 | self.actor_opt.step()
279 |
280 | if self.use_tb or self.use_wandb:
281 | metrics['actor_loss'] = actor_loss.item()
282 | metrics['actor_logprob'] = log_prob.mean().item()
283 | metrics['actor_ent'] = dist.entropy().sum(dim=-1).mean().item()
284 |
285 | return metrics
286 |
287 | def aug_and_encode(self, obs):
288 | obs = self.aug(obs)
289 | return self.encoder(obs)
290 |
291 | def update(self, replay_iter, step):
292 | metrics = dict()
293 | #import ipdb; ipdb.set_trace()
294 |
295 | if step % self.update_every_steps != 0:
296 | return metrics
297 |
298 | batch = next(replay_iter)
299 | obs, action, reward, discount, next_obs = utils.to_torch(
300 | batch, self.device)
301 |
302 | # augment and encode
303 | obs = self.aug_and_encode(obs)
304 | with torch.no_grad():
305 | next_obs = self.aug_and_encode(next_obs)
306 |
307 | if self.use_tb or self.use_wandb:
308 | metrics['batch_reward'] = reward.mean().item()
309 |
310 | # update critic
311 | metrics.update(
312 | self.update_critic(obs, action, reward, discount, next_obs, step))
313 |
314 | # update actor
315 | metrics.update(self.update_actor(obs.detach(), step))
316 |
317 | # update critic target
318 | utils.soft_update_params(self.critic, self.critic_target,
319 | self.critic_target_tau)
320 |
321 | return metrics
322 |
--------------------------------------------------------------------------------
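For intuition, update_critic above builds the clipped double-Q target r + discount * min(Q1', Q2') from the target critic before regressing both Q-heads onto it. A toy numeric illustration with made-up values:

import torch

reward = torch.tensor([[1.0]])
discount = torch.tensor([[0.99]])
target_q1 = torch.tensor([[10.0]])
target_q2 = torch.tensor([[12.0]])

target_v = torch.min(target_q1, target_q2)   # pessimistic value estimate
target_q = reward + discount * target_v
print(target_q)                              # tensor([[10.9000]])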
/dmc.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict, deque
2 | from typing import Any, NamedTuple
3 |
4 | import dm_env
5 | import numpy as np
6 | from dm_control import manipulation, suite
7 | from dm_control.suite.wrappers import action_scale, pixels
8 | from dm_env import StepType, specs
9 |
10 | import custom_dmc_tasks as cdmc
11 |
12 |
13 | class ExtendedTimeStep(NamedTuple):
14 | step_type: Any
15 | reward: Any
16 | discount: Any
17 | observation: Any
18 | action: Any
19 |
20 | def first(self):
21 | return self.step_type == StepType.FIRST
22 |
23 | def mid(self):
24 | return self.step_type == StepType.MID
25 |
26 | def last(self):
27 | return self.step_type == StepType.LAST
28 |
29 | def __getitem__(self, attr):
30 | return getattr(self, attr)
31 |
32 |
33 | class FlattenJacoObservationWrapper(dm_env.Environment):
34 | def __init__(self, env):
35 | self._env = env
36 | self._obs_spec = OrderedDict()
37 | wrapped_obs_spec = env.observation_spec().copy()
38 | if 'front_close' in wrapped_obs_spec:
39 | spec = wrapped_obs_spec['front_close']
40 | # drop batch dim
41 | self._obs_spec['pixels'] = specs.BoundedArray(shape=spec.shape[1:],
42 | dtype=spec.dtype,
43 | minimum=spec.minimum,
44 | maximum=spec.maximum,
45 | name='pixels')
46 | wrapped_obs_spec.pop('front_close')
47 |
48 | for key, spec in wrapped_obs_spec.items():
49 | assert spec.dtype == np.float64
50 | assert type(spec) == specs.Array
51 | dim = np.sum(
52 |                 np.fromiter((int(np.prod(spec.shape))
53 | for spec in wrapped_obs_spec.values()), np.int32))
54 |
55 | self._obs_spec['observations'] = specs.Array(shape=(dim,),
56 | dtype=np.float32,
57 | name='observations')
58 |
59 | def _transform_observation(self, time_step):
60 | obs = OrderedDict()
61 |
62 | if 'front_close' in time_step.observation:
63 | pixels = time_step.observation['front_close']
64 | time_step.observation.pop('front_close')
65 | pixels = np.squeeze(pixels)
66 | obs['pixels'] = pixels
67 |
68 | features = []
69 | for feature in time_step.observation.values():
70 | features.append(feature.ravel())
71 | obs['observations'] = np.concatenate(features, axis=0)
72 | return time_step._replace(observation=obs)
73 |
74 | def reset(self):
75 | time_step = self._env.reset()
76 | return self._transform_observation(time_step)
77 |
78 | def step(self, action):
79 | time_step = self._env.step(action)
80 | return self._transform_observation(time_step)
81 |
82 | def observation_spec(self):
83 | return self._obs_spec
84 |
85 | def action_spec(self):
86 | return self._env.action_spec()
87 |
88 | def __getattr__(self, name):
89 | return getattr(self._env, name)
90 |
91 |
92 | class ActionRepeatWrapper(dm_env.Environment):
93 | def __init__(self, env, num_repeats):
94 | self._env = env
95 | self._num_repeats = num_repeats
96 |
97 | def step(self, action):
98 | reward = 0.0
99 | discount = 1.0
100 | for i in range(self._num_repeats):
101 | time_step = self._env.step(action)
102 | reward += (time_step.reward or 0.0) * discount
103 | discount *= time_step.discount
104 | if time_step.last():
105 | break
106 |
107 | return time_step._replace(reward=reward, discount=discount)
108 |
109 | def observation_spec(self):
110 | return self._env.observation_spec()
111 |
112 | def action_spec(self):
113 | return self._env.action_spec()
114 |
115 | def reset(self):
116 | return self._env.reset()
117 |
118 | def __getattr__(self, name):
119 | return getattr(self._env, name)
120 |
121 |
122 | class FrameStackWrapper(dm_env.Environment):
123 | def __init__(self, env, num_frames, pixels_key='pixels'):
124 | self._env = env
125 | self._num_frames = num_frames
126 | self._frames = deque([], maxlen=num_frames)
127 | self._pixels_key = pixels_key
128 |
129 | wrapped_obs_spec = env.observation_spec()
130 | assert pixels_key in wrapped_obs_spec
131 |
132 | pixels_shape = wrapped_obs_spec[pixels_key].shape
133 | # remove batch dim
134 | if len(pixels_shape) == 4:
135 | pixels_shape = pixels_shape[1:]
136 | self._obs_spec = specs.BoundedArray(shape=np.concatenate(
137 | [[pixels_shape[2] * num_frames], pixels_shape[:2]], axis=0),
138 | dtype=np.uint8,
139 | minimum=0,
140 | maximum=255,
141 | name='observation')
142 |
143 | def _transform_observation(self, time_step):
144 | assert len(self._frames) == self._num_frames
145 | obs = np.concatenate(list(self._frames), axis=0)
146 | return time_step._replace(observation=obs)
147 |
148 | def _extract_pixels(self, time_step):
149 | pixels = time_step.observation[self._pixels_key]
150 | # remove batch dim
151 | if len(pixels.shape) == 4:
152 | pixels = pixels[0]
153 | return pixels.transpose(2, 0, 1).copy()
154 |
155 | def reset(self):
156 | time_step = self._env.reset()
157 | pixels = self._extract_pixels(time_step)
158 | for _ in range(self._num_frames):
159 | self._frames.append(pixels)
160 | return self._transform_observation(time_step)
161 |
162 | def step(self, action):
163 | time_step = self._env.step(action)
164 | pixels = self._extract_pixels(time_step)
165 | self._frames.append(pixels)
166 | return self._transform_observation(time_step)
167 |
168 | def observation_spec(self):
169 | return self._obs_spec
170 |
171 | def action_spec(self):
172 | return self._env.action_spec()
173 |
174 | def __getattr__(self, name):
175 | return getattr(self._env, name)
176 |
177 |
178 | class ActionDTypeWrapper(dm_env.Environment):
179 | def __init__(self, env, dtype):
180 | self._env = env
181 | wrapped_action_spec = env.action_spec()
182 | self._action_spec = specs.BoundedArray(wrapped_action_spec.shape,
183 | dtype,
184 | wrapped_action_spec.minimum,
185 | wrapped_action_spec.maximum,
186 | 'action')
187 |
188 | def step(self, action):
189 | action = action.astype(self._env.action_spec().dtype)
190 | return self._env.step(action)
191 |
192 | def observation_spec(self):
193 | return self._env.observation_spec()
194 |
195 | def action_spec(self):
196 | return self._action_spec
197 |
198 | def reset(self):
199 | return self._env.reset()
200 |
201 | def __getattr__(self, name):
202 | return getattr(self._env, name)
203 |
204 |
205 | class ObservationDTypeWrapper(dm_env.Environment):
206 | def __init__(self, env, dtype):
207 | self._env = env
208 | self._dtype = dtype
209 | wrapped_obs_spec = env.observation_spec()['observations']
210 | self._obs_spec = specs.Array(wrapped_obs_spec.shape, dtype,
211 | 'observation')
212 |
213 | def _transform_observation(self, time_step):
214 | obs = time_step.observation['observations'].astype(self._dtype)
215 | return time_step._replace(observation=obs)
216 |
217 | def reset(self):
218 | time_step = self._env.reset()
219 | return self._transform_observation(time_step)
220 |
221 | def step(self, action):
222 | time_step = self._env.step(action)
223 | return self._transform_observation(time_step)
224 |
225 | def observation_spec(self):
226 | return self._obs_spec
227 |
228 | def action_spec(self):
229 | return self._env.action_spec()
230 |
231 | def __getattr__(self, name):
232 | return getattr(self._env, name)
233 |
234 |
235 | class ExtendedTimeStepWrapper(dm_env.Environment):
236 | def __init__(self, env):
237 | self._env = env
238 |
239 | def reset(self):
240 | time_step = self._env.reset()
241 | return self._augment_time_step(time_step)
242 |
243 | def step(self, action):
244 | time_step = self._env.step(action)
245 | return self._augment_time_step(time_step, action)
246 |
247 | def _augment_time_step(self, time_step, action=None):
248 | if action is None:
249 | action_spec = self.action_spec()
250 | action = np.zeros(action_spec.shape, dtype=action_spec.dtype)
251 | return ExtendedTimeStep(observation=time_step.observation,
252 | step_type=time_step.step_type,
253 | action=action,
254 | reward=time_step.reward or 0.0,
255 | discount=time_step.discount or 1.0)
256 |
257 | def observation_spec(self):
258 | return self._env.observation_spec()
259 |
260 | def action_spec(self):
261 | return self._env.action_spec()
262 |
263 | def __getattr__(self, name):
264 | return getattr(self._env, name)
265 |
266 |
267 | def _make_jaco(obs_type, domain, task, frame_stack, action_repeat, seed):
268 | env = cdmc.make_jaco(task, obs_type, seed)
269 | env = ActionDTypeWrapper(env, np.float32)
270 | env = ActionRepeatWrapper(env, action_repeat)
271 | env = FlattenJacoObservationWrapper(env)
272 | return env
273 |
274 |
275 | def _make_dmc(obs_type, domain, task, frame_stack, action_repeat, seed):
276 | visualize_reward = False
277 | if (domain, task) in suite.ALL_TASKS:
278 | env = suite.load(domain,
279 | task,
280 | task_kwargs=dict(random=seed),
281 | environment_kwargs=dict(flat_observation=True),
282 | visualize_reward=visualize_reward)
283 | else:
284 | env = cdmc.make(domain,
285 | task,
286 | task_kwargs=dict(random=seed),
287 | environment_kwargs=dict(flat_observation=True),
288 | visualize_reward=visualize_reward)
289 |
290 | env = ActionDTypeWrapper(env, np.float32)
291 | env = ActionRepeatWrapper(env, action_repeat)
292 | if obs_type == 'pixels':
293 | # zoom in camera for quadruped
294 | camera_id = dict(quadruped=2).get(domain, 0)
295 | render_kwargs = dict(height=84, width=84, camera_id=camera_id)
296 | env = pixels.Wrapper(env,
297 | pixels_only=True,
298 | render_kwargs=render_kwargs)
299 | return env
300 |
301 |
302 | def make(name, obs_type, frame_stack, action_repeat, seed):
303 | assert obs_type in ['states', 'pixels']
304 | domain, task = name.split('_', 1)
305 | domain = dict(cup='ball_in_cup').get(domain, domain)
306 |
307 | make_fn = _make_jaco if domain == 'jaco' else _make_dmc
308 | env = make_fn(obs_type, domain, task, frame_stack, action_repeat, seed)
309 |
310 | if obs_type == 'pixels':
311 | env = FrameStackWrapper(env, frame_stack)
312 | else:
313 | env = ObservationDTypeWrapper(env, np.float32)
314 |
315 | env = action_scale.Wrapper(env, minimum=-1.0, maximum=+1.0)
316 | env = ExtendedTimeStepWrapper(env)
317 | return env
318 |
--------------------------------------------------------------------------------
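dmc.make above stacks the dtype, action-repeat, frame-stack/observation, action-scale, and extended-time-step wrappers into a single dm_env environment. A minimal usage sketch (task name and seed are examples; dm_control and MuJoCo must be available):

import dmc

env = dmc.make('walker_stand', obs_type='states', frame_stack=1, action_repeat=1, seed=0)
time_step = env.reset()
print(env.observation_spec())  # flattened float32 state vector
print(env.action_spec())       # actions rescaled to [-1, 1]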
/custom_dmc_tasks/quadruped.xml:
--------------------------------------------------------------------------------
<!-- MuJoCo model for the quadruped domain; the XML markup is not preserved in this extract. -->
--------------------------------------------------------------------------------
/custom_dmc_tasks/quadruped.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Quadruped Domain."""
17 |
18 | import collections
19 |
20 | from dm_control import mujoco
21 | from dm_control.mujoco.wrapper import mjbindings
22 | from dm_control.rl import control
23 | from dm_control.suite import base
24 | from dm_control.suite import common
25 | from dm_control.utils import containers
26 | from dm_control.utils import rewards
27 | from dm_control.utils import xml_tools
28 | from lxml import etree
29 | import numpy as np
30 | from scipy import ndimage
31 | import os
32 |
33 | enums = mjbindings.enums
34 | mjlib = mjbindings.mjlib
35 |
36 |
37 | _DEFAULT_TIME_LIMIT = 20
38 | _CONTROL_TIMESTEP = .02
39 |
40 | # Horizontal speeds above which the move reward is 1.
41 | _RUN_SPEED = 5
42 | _WALK_SPEED = 0.5
43 |
44 | _JUMP_HEIGHT = 1.0
45 |
46 | # Constants related to terrain generation.
47 | _HEIGHTFIELD_ID = 0
48 | _TERRAIN_SMOOTHNESS = 0.15 # 0.0: maximally bumpy; 1.0: completely smooth.
49 | _TERRAIN_BUMP_SCALE = 2 # Spatial scale of terrain bumps (in meters).
50 |
51 | # Named model elements.
52 | _TOES = ['toe_front_left', 'toe_back_left', 'toe_back_right', 'toe_front_right']
53 | _WALLS = ['wall_px', 'wall_py', 'wall_nx', 'wall_ny']
54 |
55 | SUITE = containers.TaggedTasks()
56 |
57 | def make(task,
58 | task_kwargs=None,
59 | environment_kwargs=None,
60 | visualize_reward=False):
61 | task_kwargs = task_kwargs or {}
62 | if environment_kwargs is not None:
63 | task_kwargs = task_kwargs.copy()
64 | task_kwargs['environment_kwargs'] = environment_kwargs
65 | env = SUITE[task](**task_kwargs)
66 | env.task.visualize_reward = visualize_reward
67 | return env
68 |
69 | def get_model_and_assets():
70 | """Returns a tuple containing the model XML string and a dict of assets."""
71 | root_dir = os.path.dirname(os.path.dirname(__file__))
72 |     xml = common.read_model(
73 |         os.path.join(root_dir, 'custom_dmc_tasks', 'quadruped.xml'))
74 | return xml, common.ASSETS
75 |
76 |
77 | def make_model(floor_size=None, terrain=False, rangefinders=False,
78 | walls_and_ball=False):
79 | """Returns the model XML string."""
80 | root_dir = os.path.dirname(os.path.dirname(__file__))
81 | xml_string = common.read_model(os.path.join(root_dir, 'custom_dmc_tasks', 'quadruped.xml'))
82 | parser = etree.XMLParser(remove_blank_text=True)
83 | mjcf = etree.XML(xml_string, parser)
84 |
85 | # Set floor size.
86 | if floor_size is not None:
87 | floor_geom = mjcf.find('.//geom[@name=\'floor\']')
88 | floor_geom.attrib['size'] = f'{floor_size} {floor_size} .5'
89 |
90 | # Remove walls, ball and target.
91 | if not walls_and_ball:
92 | for wall in _WALLS:
93 | wall_geom = xml_tools.find_element(mjcf, 'geom', wall)
94 | wall_geom.getparent().remove(wall_geom)
95 |
96 | # Remove ball.
97 | ball_body = xml_tools.find_element(mjcf, 'body', 'ball')
98 | ball_body.getparent().remove(ball_body)
99 |
100 | # Remove target.
101 | target_site = xml_tools.find_element(mjcf, 'site', 'target')
102 | target_site.getparent().remove(target_site)
103 |
104 | # Remove terrain.
105 | if not terrain:
106 | terrain_geom = xml_tools.find_element(mjcf, 'geom', 'terrain')
107 | terrain_geom.getparent().remove(terrain_geom)
108 |
109 | # Remove rangefinders if they're not used, as range computations can be
110 | # expensive, especially in a scene with heightfields.
111 | if not rangefinders:
112 | rangefinder_sensors = mjcf.findall('.//rangefinder')
113 | for rf in rangefinder_sensors:
114 | rf.getparent().remove(rf)
115 |
116 | return etree.tostring(mjcf, pretty_print=True)
117 |
118 |
119 | @SUITE.add()
120 | def stand(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
121 |     """Returns the Stand task."""
122 | xml_string = make_model(floor_size=_DEFAULT_TIME_LIMIT * _WALK_SPEED)
123 | physics = Physics.from_xml_string(xml_string, common.ASSETS)
124 | task = Stand(random=random)
125 | environment_kwargs = environment_kwargs or {}
126 | return control.Environment(physics, task, time_limit=time_limit,
127 | control_timestep=_CONTROL_TIMESTEP,
128 | **environment_kwargs)
129 |
130 | @SUITE.add()
131 | def jump(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
132 |     """Returns the Jump task."""
133 | xml_string = make_model(floor_size=_DEFAULT_TIME_LIMIT * _WALK_SPEED)
134 | physics = Physics.from_xml_string(xml_string, common.ASSETS)
135 | task = Jump(desired_height=_JUMP_HEIGHT, random=random)
136 | environment_kwargs = environment_kwargs or {}
137 | return control.Environment(physics, task, time_limit=time_limit,
138 | control_timestep=_CONTROL_TIMESTEP,
139 | **environment_kwargs)
140 |
141 | @SUITE.add()
142 | def roll(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
143 |     """Returns the Roll task."""
144 | xml_string = make_model(floor_size=_DEFAULT_TIME_LIMIT * _WALK_SPEED)
145 | physics = Physics.from_xml_string(xml_string, common.ASSETS)
146 | task = Roll(desired_speed=_WALK_SPEED, random=random)
147 | environment_kwargs = environment_kwargs or {}
148 | return control.Environment(physics, task, time_limit=time_limit,
149 | control_timestep=_CONTROL_TIMESTEP,
150 | **environment_kwargs)
151 |
152 | @SUITE.add()
153 | def roll_fast(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
154 |     """Returns the Roll task with the running-speed target."""
155 | xml_string = make_model(floor_size=_DEFAULT_TIME_LIMIT * _WALK_SPEED)
156 | physics = Physics.from_xml_string(xml_string, common.ASSETS)
157 | task = Roll(desired_speed=_RUN_SPEED, random=random)
158 | environment_kwargs = environment_kwargs or {}
159 | return control.Environment(physics, task, time_limit=time_limit,
160 | control_timestep=_CONTROL_TIMESTEP,
161 | **environment_kwargs)
162 |
163 | @SUITE.add()
164 | def escape(time_limit=_DEFAULT_TIME_LIMIT, random=None,
165 | environment_kwargs=None):
166 | """Returns the Escape task."""
167 | xml_string = make_model(floor_size=40, terrain=True, rangefinders=True)
168 | physics = Physics.from_xml_string(xml_string, common.ASSETS)
169 | task = Escape(random=random)
170 | environment_kwargs = environment_kwargs or {}
171 | return control.Environment(physics, task, time_limit=time_limit,
172 | control_timestep=_CONTROL_TIMESTEP,
173 | **environment_kwargs)
174 |
175 |
176 | @SUITE.add()
177 | def fetch(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
178 | """Returns the Fetch task."""
179 | xml_string = make_model(walls_and_ball=True)
180 | physics = Physics.from_xml_string(xml_string, common.ASSETS)
181 | task = Fetch(random=random)
182 | environment_kwargs = environment_kwargs or {}
183 | return control.Environment(physics, task, time_limit=time_limit,
184 | control_timestep=_CONTROL_TIMESTEP,
185 | **environment_kwargs)
186 |
187 |
188 | class Physics(mujoco.Physics):
189 | """Physics simulation with additional features for the Quadruped domain."""
190 |
191 | def _reload_from_data(self, data):
192 | super()._reload_from_data(data)
193 | # Clear cached sensor names when the physics is reloaded.
194 | self._sensor_types_to_names = {}
195 | self._hinge_names = []
196 |
197 | def _get_sensor_names(self, *sensor_types):
198 | try:
199 | sensor_names = self._sensor_types_to_names[sensor_types]
200 | except KeyError:
201 | [sensor_ids] = np.where(np.in1d(self.model.sensor_type, sensor_types))
202 | sensor_names = [self.model.id2name(s_id, 'sensor') for s_id in sensor_ids]
203 | self._sensor_types_to_names[sensor_types] = sensor_names
204 | return sensor_names
205 |
206 | def torso_upright(self):
207 | """Returns the dot-product of the torso z-axis and the global z-axis."""
208 | return np.asarray(self.named.data.xmat['torso', 'zz'])
209 |
210 | def torso_velocity(self):
211 | """Returns the velocity of the torso, in the local frame."""
212 | return self.named.data.sensordata['velocimeter'].copy()
213 |
214 | def com_height(self):
215 | return self.named.data.sensordata['center_of_mass'].copy()[2]
216 |
217 | def egocentric_state(self):
218 | """Returns the state without global orientation or position."""
219 | if not self._hinge_names:
220 | [hinge_ids] = np.nonzero(self.model.jnt_type ==
221 | enums.mjtJoint.mjJNT_HINGE)
222 | self._hinge_names = [self.model.id2name(j_id, 'joint')
223 | for j_id in hinge_ids]
224 | return np.hstack((self.named.data.qpos[self._hinge_names],
225 | self.named.data.qvel[self._hinge_names],
226 | self.data.act))
227 |
228 | def toe_positions(self):
229 | """Returns toe positions in egocentric frame."""
230 | torso_frame = self.named.data.xmat['torso'].reshape(3, 3)
231 | torso_pos = self.named.data.xpos['torso']
232 | torso_to_toe = self.named.data.xpos[_TOES] - torso_pos
233 | return torso_to_toe.dot(torso_frame)
234 |
235 | def force_torque(self):
236 | """Returns scaled force/torque sensor readings at the toes."""
237 | force_torque_sensors = self._get_sensor_names(enums.mjtSensor.mjSENS_FORCE,
238 | enums.mjtSensor.mjSENS_TORQUE)
239 | return np.arcsinh(self.named.data.sensordata[force_torque_sensors])
240 |
241 | def imu(self):
242 | """Returns IMU-like sensor readings."""
243 | imu_sensors = self._get_sensor_names(enums.mjtSensor.mjSENS_GYRO,
244 | enums.mjtSensor.mjSENS_ACCELEROMETER)
245 | return self.named.data.sensordata[imu_sensors]
246 |
247 | def rangefinder(self):
248 | """Returns scaled rangefinder sensor readings."""
249 | rf_sensors = self._get_sensor_names(enums.mjtSensor.mjSENS_RANGEFINDER)
250 | rf_readings = self.named.data.sensordata[rf_sensors]
251 | no_intersection = -1.0
252 | return np.where(rf_readings == no_intersection, 1.0, np.tanh(rf_readings))
253 |
254 | def origin_distance(self):
255 | """Returns the distance from the origin to the workspace."""
256 | return np.asarray(np.linalg.norm(self.named.data.site_xpos['workspace']))
257 |
258 | def origin(self):
259 | """Returns origin position in the torso frame."""
260 | torso_frame = self.named.data.xmat['torso'].reshape(3, 3)
261 | torso_pos = self.named.data.xpos['torso']
262 | return -torso_pos.dot(torso_frame)
263 |
264 | def ball_state(self):
265 | """Returns ball position and velocity relative to the torso frame."""
266 | data = self.named.data
267 | torso_frame = data.xmat['torso'].reshape(3, 3)
268 | ball_rel_pos = data.xpos['ball'] - data.xpos['torso']
269 | ball_rel_vel = data.qvel['ball_root'][:3] - data.qvel['root'][:3]
270 | ball_rot_vel = data.qvel['ball_root'][3:]
271 | ball_state = np.vstack((ball_rel_pos, ball_rel_vel, ball_rot_vel))
272 | return ball_state.dot(torso_frame).ravel()
273 |
274 | def target_position(self):
275 | """Returns target position in torso frame."""
276 | torso_frame = self.named.data.xmat['torso'].reshape(3, 3)
277 | torso_pos = self.named.data.xpos['torso']
278 | torso_to_target = self.named.data.site_xpos['target'] - torso_pos
279 | return torso_to_target.dot(torso_frame)
280 |
281 | def ball_to_target_distance(self):
282 | """Returns horizontal distance from the ball to the target."""
283 | ball_to_target = (self.named.data.site_xpos['target'] -
284 | self.named.data.xpos['ball'])
285 | return np.linalg.norm(ball_to_target[:2])
286 |
287 | def self_to_ball_distance(self):
288 | """Returns horizontal distance from the quadruped workspace to the ball."""
289 | self_to_ball = (self.named.data.site_xpos['workspace']
290 | -self.named.data.xpos['ball'])
291 | return np.linalg.norm(self_to_ball[:2])
292 |
293 |
294 | def _find_non_contacting_height(physics, orientation, x_pos=0.0, y_pos=0.0):
295 | """Find a height with no contacts given a body orientation.
296 | Args:
297 | physics: An instance of `Physics`.
298 | orientation: A quaternion.
299 | x_pos: A float. Position along global x-axis.
300 | y_pos: A float. Position along global y-axis.
301 | Raises:
302 | RuntimeError: If a non-contacting configuration has not been found after
303 | 10,000 attempts.
304 | """
305 | z_pos = 0.0 # Start embedded in the floor.
306 | num_contacts = 1
307 | num_attempts = 0
308 | # Move up in 1cm increments until no contacts.
309 | while num_contacts > 0:
310 | try:
311 | with physics.reset_context():
312 | physics.named.data.qpos['root'][:3] = x_pos, y_pos, z_pos
313 | physics.named.data.qpos['root'][3:] = orientation
314 | except control.PhysicsError:
315 | # We may encounter a PhysicsError here due to filling the contact
316 | # buffer, in which case we simply increment the height and continue.
317 | pass
318 | num_contacts = physics.data.ncon
319 | z_pos += 0.01
320 | num_attempts += 1
321 | if num_attempts > 10000:
322 | raise RuntimeError('Failed to find a non-contacting configuration.')
323 |
324 |
325 | def _common_observations(physics):
326 | """Returns the observations common to all tasks."""
327 | obs = collections.OrderedDict()
328 | obs['egocentric_state'] = physics.egocentric_state()
329 | obs['torso_velocity'] = physics.torso_velocity()
330 | obs['torso_upright'] = physics.torso_upright()
331 | obs['imu'] = physics.imu()
332 | obs['force_torque'] = physics.force_torque()
333 | return obs
334 |
335 |
336 | def _upright_reward(physics, deviation_angle=0):
337 | """Returns a reward proportional to how upright the torso is.
338 | Args:
339 | physics: an instance of `Physics`.
340 | deviation_angle: A float, in degrees. The reward is 0 when the torso is
341 | exactly upside-down and 1 when the torso's z-axis is less than
342 | `deviation_angle` away from the global z-axis.
343 | """
344 | deviation = np.cos(np.deg2rad(deviation_angle))
345 | return rewards.tolerance(
346 | physics.torso_upright(),
347 | bounds=(deviation, float('inf')),
348 | sigmoid='linear',
349 | margin=1 + deviation,
350 | value_at_margin=0)
351 |
352 |
353 | class Move(base.Task):
354 | """A quadruped task solved by moving forward at a designated speed."""
355 |
356 | def __init__(self, desired_speed, random=None):
357 | """Initializes an instance of `Move`.
358 | Args:
359 | desired_speed: A float. If this value is zero, reward is given simply
360 | for standing upright. Otherwise this specifies the horizontal velocity
361 | at which the velocity-dependent reward component is maximized.
362 | random: Optional, either a `numpy.random.RandomState` instance, an
363 | integer seed for creating a new `RandomState`, or None to select a seed
364 | automatically (default).
365 | """
366 | self._desired_speed = desired_speed
367 | super().__init__(random=random)
368 |
369 | def initialize_episode(self, physics):
370 | """Sets the state of the environment at the start of each episode.
371 | Args:
372 | physics: An instance of `Physics`.
373 | """
374 | # Initial configuration.
375 | orientation = self.random.randn(4)
376 | orientation /= np.linalg.norm(orientation)
377 | _find_non_contacting_height(physics, orientation)
378 | super().initialize_episode(physics)
379 |
380 | def get_observation(self, physics):
381 | """Returns an observation to the agent."""
382 | return _common_observations(physics)
383 |
384 | def get_reward(self, physics):
385 | """Returns a reward to the agent."""
386 |
387 | # Move reward term.
388 | move_reward = rewards.tolerance(
389 | physics.torso_velocity()[0],
390 | bounds=(self._desired_speed, float('inf')),
391 | margin=self._desired_speed,
392 | value_at_margin=0.5,
393 | sigmoid='linear')
394 |
395 | return _upright_reward(physics) * move_reward
396 |
397 |
398 | class Stand(base.Task):
399 |     """A quadruped task solved by standing upright."""
400 |
401 | def __init__(self, random=None):
402 |         """Initializes an instance of `Stand`.
403 | 
404 |         Reward is given simply for keeping the torso upright; there is no
405 |         velocity-dependent reward component and hence no desired speed.
406 |         Args:
407 |           random: Optional, either a `numpy.random.RandomState` instance, an
408 |             integer seed for creating a new `RandomState`, or None to select a
409 |             seed automatically (default).
410 |         """
411 | super().__init__(random=random)
412 |
413 | def initialize_episode(self, physics):
414 | """Sets the state of the environment at the start of each episode.
415 | Args:
416 | physics: An instance of `Physics`.
417 | """
418 | # Initial configuration.
419 | orientation = self.random.randn(4)
420 | orientation /= np.linalg.norm(orientation)
421 | _find_non_contacting_height(physics, orientation)
422 | super().initialize_episode(physics)
423 |
424 | def get_observation(self, physics):
425 | """Returns an observation to the agent."""
426 | return _common_observations(physics)
427 |
428 | def get_reward(self, physics):
429 | """Returns a reward to the agent."""
430 |
431 | return _upright_reward(physics)
432 |
433 | class Jump(base.Task):
434 |     """A quadruped task solved by jumping above a desired height."""
435 |
436 | def __init__(self, desired_height, random=None):
437 |         """Initializes an instance of `Jump`.
438 |         Args:
439 |           desired_height: A float. The center-of-mass height at which the
440 |             height-dependent reward component saturates at its maximum; the
441 |             reward falls off linearly below this height.
442 | random: Optional, either a `numpy.random.RandomState` instance, an
443 | integer seed for creating a new `RandomState`, or None to select a seed
444 | automatically (default).
445 | """
446 | self._desired_height = desired_height
447 | super().__init__(random=random)
448 |
449 | def initialize_episode(self, physics):
450 | """Sets the state of the environment at the start of each episode.
451 | Args:
452 | physics: An instance of `Physics`.
453 | """
454 | # Initial configuration.
455 | orientation = self.random.randn(4)
456 | orientation /= np.linalg.norm(orientation)
457 | _find_non_contacting_height(physics, orientation)
458 | super().initialize_episode(physics)
459 |
460 | def get_observation(self, physics):
461 | """Returns an observation to the agent."""
462 | return _common_observations(physics)
463 |
464 | def get_reward(self, physics):
465 | """Returns a reward to the agent."""
466 |
467 |         # Jump reward term.
468 | jump_up = rewards.tolerance(
469 | physics.com_height(),
470 | bounds=(self._desired_height, float('inf')),
471 | margin=self._desired_height,
472 | value_at_margin=0.5,
473 | sigmoid='linear')
474 |
475 | return _upright_reward(physics) * jump_up
476 |
477 |
478 | class Roll(base.Task):
479 |     """A quadruped task solved by moving at a designated speed in any direction."""
480 |
481 | def __init__(self, desired_speed, random=None):
482 |         """Initializes an instance of `Roll`.
483 |         Args:
484 |           desired_speed: A float. If this value is zero, reward is given simply
485 |             for standing upright. Otherwise this specifies the speed (the norm of
486 |             the torso velocity) at which the velocity-dependent reward component is maximized.
487 | random: Optional, either a `numpy.random.RandomState` instance, an
488 | integer seed for creating a new `RandomState`, or None to select a seed
489 | automatically (default).
490 | """
491 | self._desired_speed = desired_speed
492 | super().__init__(random=random)
493 |
494 | def initialize_episode(self, physics):
495 | """Sets the state of the environment at the start of each episode.
496 | Args:
497 | physics: An instance of `Physics`.
498 | """
499 | # Initial configuration.
500 | orientation = self.random.randn(4)
501 | orientation /= np.linalg.norm(orientation)
502 | _find_non_contacting_height(physics, orientation)
503 | super().initialize_episode(physics)
504 |
505 | def get_observation(self, physics):
506 | """Returns an observation to the agent."""
507 | return _common_observations(physics)
508 |
509 | def get_reward(self, physics):
510 | """Returns a reward to the agent."""
511 | # Move reward term.
512 | move_reward = rewards.tolerance(
513 | np.linalg.norm(physics.torso_velocity()),
514 | bounds=(self._desired_speed, float('inf')),
515 | margin=self._desired_speed,
516 | value_at_margin=0.5,
517 | sigmoid='linear')
518 |
519 | return _upright_reward(physics) * move_reward
520 |
521 |
522 | class Escape(base.Task):
523 | """A quadruped task solved by escaping a bowl-shaped terrain."""
524 |
525 | def initialize_episode(self, physics):
526 | """Sets the state of the environment at the start of each episode.
527 | Args:
528 | physics: An instance of `Physics`.
529 | """
530 | # Get heightfield resolution, assert that it is square.
531 | res = physics.model.hfield_nrow[_HEIGHTFIELD_ID]
532 | assert res == physics.model.hfield_ncol[_HEIGHTFIELD_ID]
533 | # Sinusoidal bowl shape.
534 | row_grid, col_grid = np.ogrid[-1:1:res*1j, -1:1:res*1j]
535 | radius = np.clip(np.sqrt(col_grid**2 + row_grid**2), .04, 1)
536 | bowl_shape = .5 - np.cos(2*np.pi*radius)/2
537 | # Random smooth bumps.
538 | terrain_size = 2 * physics.model.hfield_size[_HEIGHTFIELD_ID, 0]
539 | bump_res = int(terrain_size / _TERRAIN_BUMP_SCALE)
540 | bumps = self.random.uniform(_TERRAIN_SMOOTHNESS, 1, (bump_res, bump_res))
541 | smooth_bumps = ndimage.zoom(bumps, res / float(bump_res))
542 | # Terrain is elementwise product.
543 | terrain = bowl_shape * smooth_bumps
544 | start_idx = physics.model.hfield_adr[_HEIGHTFIELD_ID]
545 | physics.model.hfield_data[start_idx:start_idx+res**2] = terrain.ravel()
546 | super().initialize_episode(physics)
547 |
548 | # If we have a rendering context, we need to re-upload the modified
549 | # heightfield data.
550 | if physics.contexts:
551 | with physics.contexts.gl.make_current() as ctx:
552 | ctx.call(mjlib.mjr_uploadHField,
553 | physics.model.ptr,
554 | physics.contexts.mujoco.ptr,
555 | _HEIGHTFIELD_ID)
556 |
557 | # Initial configuration.
558 | orientation = self.random.randn(4)
559 | orientation /= np.linalg.norm(orientation)
560 | _find_non_contacting_height(physics, orientation)
561 |
562 | def get_observation(self, physics):
563 | """Returns an observation to the agent."""
564 | obs = _common_observations(physics)
565 | obs['origin'] = physics.origin()
566 | obs['rangefinder'] = physics.rangefinder()
567 | return obs
568 |
569 | def get_reward(self, physics):
570 | """Returns a reward to the agent."""
571 |
572 | # Escape reward term.
573 | terrain_size = physics.model.hfield_size[_HEIGHTFIELD_ID, 0]
574 | escape_reward = rewards.tolerance(
575 | physics.origin_distance(),
576 | bounds=(terrain_size, float('inf')),
577 | margin=terrain_size,
578 | value_at_margin=0,
579 | sigmoid='linear')
580 |
581 | return _upright_reward(physics, deviation_angle=20) * escape_reward
582 |
583 |
584 | class Fetch(base.Task):
585 | """A quadruped task solved by bringing a ball to the origin."""
586 |
587 | def initialize_episode(self, physics):
588 | """Sets the state of the environment at the start of each episode.
589 | Args:
590 | physics: An instance of `Physics`.
591 | """
592 | # Initial configuration, random azimuth and horizontal position.
593 | azimuth = self.random.uniform(0, 2*np.pi)
594 | orientation = np.array((np.cos(azimuth/2), 0, 0, np.sin(azimuth/2)))
595 | spawn_radius = 0.9 * physics.named.model.geom_size['floor', 0]
596 | x_pos, y_pos = self.random.uniform(-spawn_radius, spawn_radius, size=(2,))
597 | _find_non_contacting_height(physics, orientation, x_pos, y_pos)
598 |
599 | # Initial ball state.
600 | physics.named.data.qpos['ball_root'][:2] = self.random.uniform(
601 | -spawn_radius, spawn_radius, size=(2,))
602 | physics.named.data.qpos['ball_root'][2] = 2
603 | physics.named.data.qvel['ball_root'][:2] = 5*self.random.randn(2)
604 | super().initialize_episode(physics)
605 |
606 | def get_observation(self, physics):
607 | """Returns an observation to the agent."""
608 | obs = _common_observations(physics)
609 | obs['ball_state'] = physics.ball_state()
610 | obs['target_position'] = physics.target_position()
611 | return obs
612 |
613 | def get_reward(self, physics):
614 | """Returns a reward to the agent."""
615 |
616 | # Reward for moving close to the ball.
617 | arena_radius = physics.named.model.geom_size['floor', 0] * np.sqrt(2)
618 | workspace_radius = physics.named.model.site_size['workspace', 0]
619 | ball_radius = physics.named.model.geom_size['ball', 0]
620 | reach_reward = rewards.tolerance(
621 | physics.self_to_ball_distance(),
622 | bounds=(0, workspace_radius+ball_radius),
623 | sigmoid='linear',
624 | margin=arena_radius, value_at_margin=0)
625 |
626 | # Reward for bringing the ball to the target.
627 | target_radius = physics.named.model.site_size['target', 0]
628 | fetch_reward = rewards.tolerance(
629 | physics.ball_to_target_distance(),
630 | bounds=(0, target_radius),
631 | sigmoid='linear',
632 | margin=arena_radius, value_at_margin=0)
633 |
634 | reach_then_fetch = reach_reward * (0.5 + 0.5*fetch_reward)
635 |
636 | return _upright_reward(physics) * reach_then_fetch
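637 | 
638 | 
639 | # Usage sketch: tasks registered in SUITE above can be built through the
640 | # module-level make() helper. The task name 'jump' and the seed below are
641 | # illustrative values for this example, not requirements.
642 | #
643 | #   env = make('jump', task_kwargs=dict(random=1))
644 | #   time_step = env.reset()
645 | #   obs = time_step.observation  # OrderedDict from _common_observations()
646 | #   print(obs['egocentric_state'].shape)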
--------------------------------------------------------------------------------