├── .gitignore
├── agent
│   ├── ddpg.yaml
│   ├── becl.yaml
│   ├── becl.py
│   └── ddpg.py
├── dmc_benchmark.py
├── conda_env.yml
├── LICENSE
├── pretrain.yaml
├── custom_dmc_tasks
│   ├── __init__.py
│   ├── hopper.xml
│   ├── walker.xml
│   ├── cheetah.xml
│   ├── cheetah.py
│   ├── walker.py
│   ├── hopper.py
│   ├── jaco.py
│   ├── quadruped.xml
│   └── quadruped.py
├── finetune.yaml
├── README.md
├── video.py
├── logger.py
├── replay_buffer.py
├── pretrain.py
├── finetune.py
├── utils.py
└── dmc.py

/.gitignore:
--------------------------------------------------------------------------------
1 | /.DS_Store
--------------------------------------------------------------------------------
/agent/ddpg.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.ddpg.DDPGAgent
3 | name: ddpg
4 | reward_free: ${reward_free}
5 | obs_type: ??? # to be specified later
6 | obs_shape: ??? # to be specified later
7 | action_shape: ??? # to be specified later
8 | device: ${device}
9 | lr: 1e-4
10 | critic_target_tau: 0.01
11 | update_every_steps: 2
12 | use_tb: ${use_tb}
13 | use_wandb: ${use_wandb}
14 | num_expl_steps: ??? # to be specified later
15 | hidden_dim: 1024
16 | feature_dim: 50
17 | stddev_schedule: 0.2
18 | stddev_clip: 0.3
19 | nstep: 3
20 | batch_size: 1024 # 256 for pixels
21 | init_critic: true
22 | update_encoder: ${update_encoder}
--------------------------------------------------------------------------------
/dmc_benchmark.py:
--------------------------------------------------------------------------------
1 | DOMAINS = [
2 |     'walker',
3 |     'quadruped',
4 |     'jaco',
5 | ]
6 | 
7 | WALKER_TASKS = [
8 |     'walker_stand',
9 |     'walker_walk',
10 |     'walker_run',
11 |     'walker_flip',
12 | ]
13 | 
14 | QUADRUPED_TASKS = [
15 |     'quadruped_walk',
16 |     'quadruped_run',
17 |     'quadruped_stand',
18 |     'quadruped_jump',
19 | ]
20 | 
21 | JACO_TASKS = [
22 |     'jaco_reach_top_left',
23 |     'jaco_reach_top_right',
24 |     'jaco_reach_bottom_left',
25 |     'jaco_reach_bottom_right',
26 | ]
27 | 
28 | TASKS = WALKER_TASKS + QUADRUPED_TASKS + JACO_TASKS
29 | 
30 | PRIMAL_TASKS = {
31 |     'walker': 'walker_stand',
32 |     'jaco': 'jaco_reach_top_left',
33 |     'quadruped': 'quadruped_walk'
34 | }
35 | 
--------------------------------------------------------------------------------
/agent/becl.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.becl.BECLAgent
3 | name: becl
4 | reward_free: ${reward_free}
5 | obs_type: ??? # to be specified later
6 | obs_shape: ??? # to be specified later
7 | action_shape: ??? # to be specified later
8 | device: ${device}
9 | lr: 1e-4
10 | critic_target_tau: 0.01
11 | update_every_steps: 2
12 | use_tb: ${use_tb}
13 | use_wandb: ${use_wandb}
14 | num_expl_steps: ???
# to be specified later 15 | hidden_dim: 1024 16 | feature_dim: 50 17 | stddev_schedule: 0.2 18 | stddev_clip: 0.3 19 | skill_dim: 16 20 | update_skill_every_step: 50 21 | nstep: 3 22 | batch_size: 1024 23 | init_critic: true 24 | update_encoder: ${update_encoder} 25 | 26 | # extra hyperparameter 27 | contrastive_update_rate: 3 28 | temperature: 0.5 29 | 30 | # skill finetuning ablation 31 | skill: -1 32 | -------------------------------------------------------------------------------- /conda_env.yml: -------------------------------------------------------------------------------- 1 | name: urlb 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python=3.8 6 | - pip=21.1.3 7 | - numpy=1.19.2 8 | - absl-py=0.13.0 9 | - pyparsing=2.4.7 10 | - jupyterlab=3.0.14 11 | - scikit-image=0.18.1 12 | - nvidia::cudatoolkit=11.1 13 | - pytorch::pytorch 14 | - pytorch::torchvision 15 | - pytorch::torchaudio 16 | - pip: 17 | - termcolor==1.1.0 18 | - git+git://github.com/deepmind/dm_control.git 19 | - tb-nightly 20 | - imageio==2.9.0 21 | - imageio-ffmpeg==0.4.4 22 | - hydra-core==1.1.0 23 | - hydra-submitit-launcher==1.1.5 24 | - pandas==1.3.0 25 | - ipdb==0.13.9 26 | - yapf==0.31.0 27 | - mujoco_py==2.0.2.13 28 | - sklearn==0.0 29 | - matplotlib==3.4.2 30 | - opencv-python==4.5.3.56 31 | - wandb==0.11.1 32 | - moviepy==1.0.3 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Facebook, Inc. and its affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/pretrain.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 |   - agent: ddpg
3 |   - override hydra/launcher: submitit_local
4 | 
5 | # mode
6 | reward_free: true
7 | # task settings
8 | domain: walker # primal task will be inferred at runtime
9 | obs_type: states # [states, pixels]
10 | frame_stack: 3 # only works if obs_type=pixels
11 | action_repeat: 1 # set to 2 for pixels
12 | discount: 0.99
13 | # train settings
14 | num_train_frames: 2000010
15 | num_seed_frames: 4000
16 | # eval
17 | eval_every_frames: 10000
18 | num_eval_episodes: 10
19 | # snapshot
20 | snapshots: [100000, 500000, 1000000, 2000000]
21 | snapshot_dir: ../../../models/${obs_type}/${domain}/${agent.name}/${seed}
22 | # replay buffer
23 | replay_buffer_size: 1000000
24 | replay_buffer_num_workers: 4
25 | batch_size: ${agent.batch_size}
26 | nstep: ${agent.nstep}
27 | update_encoder: true # should always be true for pre-training
28 | # misc
29 | seed: 1
30 | device: cuda
31 | save_video: true
32 | save_train_video: false
33 | use_tb: false
34 | use_wandb: false
35 | # experiment
36 | experiment: exp
37 | 
38 | 
39 | hydra:
40 |   run:
41 |     dir: ./exp_local/pretrain${now:%Y.%m.%d}/${now:%H%M%S}_${agent.name}
42 |   sweep:
43 |     dir: ./exp_local/pretrain${now:%Y.%m.%d}/${now:%H%M%S}_${agent.name}
44 |     subdir: ${hydra.job.num}
45 |   launcher:
46 |     timeout_min: 4300
47 |     cpus_per_task: 10
48 |     gpus_per_node: 1
49 |     tasks_per_node: 1
50 |     mem_gb: 160
51 |     nodes: 1
52 |     submitit_folder: ./exp_sweep/${now:%Y.%m.%d}/${now:%H%M}_${agent.name}_${experiment}/.slurm
53 | 
--------------------------------------------------------------------------------
/custom_dmc_tasks/__init__.py:
--------------------------------------------------------------------------------
1 | from custom_dmc_tasks import cheetah
2 | from custom_dmc_tasks import walker
3 | from custom_dmc_tasks import hopper
4 | from custom_dmc_tasks import quadruped
5 | from custom_dmc_tasks import jaco
6 | 
7 | 
8 | def make(domain, task,
9 |          task_kwargs=None,
10 |          environment_kwargs=None,
11 |          visualize_reward=False):
12 | 
13 |     if domain == 'cheetah':
14 |         return cheetah.make(task,
15 |                             task_kwargs=task_kwargs,
16 |                             environment_kwargs=environment_kwargs,
17 |                             visualize_reward=visualize_reward)
18 |     elif domain == 'walker':
19 |         return walker.make(task,
20 |                            task_kwargs=task_kwargs,
21 |                            environment_kwargs=environment_kwargs,
22 |                            visualize_reward=visualize_reward)
23 |     elif domain == 'hopper':
24 |         return hopper.make(task,
25 |                            task_kwargs=task_kwargs,
26 |                            environment_kwargs=environment_kwargs,
27 |                            visualize_reward=visualize_reward)
28 |     elif domain == 'quadruped':
29 |         return quadruped.make(task,
30 |                               task_kwargs=task_kwargs,
31 |                               environment_kwargs=environment_kwargs,
32 |                               visualize_reward=visualize_reward)
33 |     else:
34 |         raise ValueError(f'domain {domain} not found')
35 | 
36 | 
37 | 
38 | 
39 | def make_jaco(task, obs_type, seed):
40 |     return jaco.make(task, obs_type, seed)
--------------------------------------------------------------------------------
/finetune.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 |   - agent: ddpg
3 |   - override hydra/launcher: submitit_local
4 | 
5 | # mode
6 | reward_free: false
7 | # task settings
8 | task: walker_stand
9 | obs_type: states # [states, pixels]
10 | frame_stack: 3 # only works if obs_type=pixels
11 | action_repeat: 1 # set to 2 for pixels
12 | discount: 0.99
13 | # train settings
14 | num_train_frames: 100010
15 | num_seed_frames: 4000
16 | # eval
17 | eval_every_frames: 10000
18 | num_eval_episodes: 10
19 | # pretrained
20 | snapshot_ts: 2000000
21 | snapshot_base_dir: ./models
22 | # replay buffer
23 | replay_buffer_size: 1000000
24 | replay_buffer_num_workers: 4
25 | batch_size: ${agent.batch_size}
26 | nstep: ${agent.nstep}
27 | update_encoder: false # can be either true or false depending on whether we want to fine-tune the encoder
28 | # misc
29 | seed: 1
30 | device: cuda
31 | save_video: true
32 | save_train_video: false
33 | use_tb: false
34 | use_wandb: false
35 | # define specific path for experiment
36 | experiment: ${agent.name}_seed_${seed}
37 | domain: walker
38 | extra_path: .
39 | 
40 | hydra:
41 |   run:
42 |     dir: ./exp_local/finetune_${extra_path}/${domain}/${agent.name}/fintune_${experiment}_${task}_${now:%H%M%S}
43 |   sweep:
44 |     dir: ./exp_local/finetune_${extra_path}/${domain}/${agent.name}/fintune_${experiment}_${task}_${now:%H%M%S}
45 |     subdir: ${hydra.job.num}
46 |   launcher:
47 |     timeout_min: 4300
48 |     cpus_per_task: 10
49 |     gpus_per_node: 1
50 |     tasks_per_node: 1
51 |     mem_gb: 160
52 |     nodes: 1
53 |     submitit_folder: ./exp_sweep/${now:%Y.%m.%d}/${now:%H%M}_${agent.name}_${experiment}/.slurm
54 | 
--------------------------------------------------------------------------------
/custom_dmc_tasks/hopper.xml:
--------------------------------------------------------------------------------
1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Behavior Contrastive Learning (BeCL)
2 | 
3 | This is the official codebase for the ICML 2023 paper [BeCL: Behavior Contrastive Learning for Unsupervised Skill Discovery](https://arxiv.org/abs/2305.04477), which utilizes contrastive learning as intrinsic motivation for unsupervised skill discovery.
4 | 
5 | If you find this paper useful for your research, please cite:
6 | ```
7 | @misc{yang2023behavior,
8 |       title={Behavior Contrastive Learning for Unsupervised Skill Discovery},
9 |       author={Rushuai Yang and Chenjia Bai and Hongyi Guo and Siyuan Li and Bin Zhao and Zhen Wang and Peng Liu and Xuelong Li},
10 |       year={2023},
11 |       eprint={2305.04477},
12 |       archivePrefix={arXiv},
13 |       primaryClass={cs.LG}
14 | }
15 | ```
16 | 
17 | This codebase is built on top of the [Unsupervised Reinforcement Learning Benchmark (URLB) codebase](https://github.com/rll-research/url_benchmark). Our method `BeCL` is implemented in `agent/becl.py` and its config is specified in `agent/becl.yaml`.
18 | 
19 | To pre-train BeCL, run the following command:
20 | 
21 | ``` sh
22 | python pretrain.py agent=becl domain=walker seed=3
23 | ```
24 | 
25 | This script produces agent snapshots after training for `100k`, `500k`, `1M`, and `2M` frames; the snapshots are stored in `./models/states/<domain>/<agent>/<seed>/` (for the command above, `./models/states/walker/becl/3/`).
26 | 
27 | To finetune BeCL, run the following command:
28 | 
29 | ```sh
30 | python finetune.py task=walker_stand obs_type=states agent=becl reward_free=false seed=3 domain=walker snapshot_ts=2000000
31 | ```
32 | 
33 | This will load the snapshot stored in `./models/states/walker/becl/3/snapshot_2000000.pt`, initialize `DDPG` with it (both the actor and the critic), and start training on `walker_stand` using the extrinsic reward of the task.
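For reference, the intrinsic reward that drives the reward-free pre-training stage is an InfoNCE-style contrastive objective over states visited under the same skill (see `agent/becl.py` later in this listing, and `temperature`/`skill_dim` in `agent/becl.yaml`). The sketch below mirrors that computation in a self-contained form; the function name and the standalone packaging are illustrative, not part of the codebase:

```python
import torch
import torch.nn.functional as F


def becl_intrinsic_reward(features, skills, temperature=0.5, eps=1e-6):
    """Minimal sketch of the InfoNCE reward in agent/becl.py (illustrative helper)."""
    b = features.size(0)
    # states generated by the same (one-hot) skill are positives for each other
    same_skill = (skills.argmax(-1).unsqueeze(0) == skills.argmax(-1).unsqueeze(1)).long()  # (b, b)
    features = F.normalize(features, dim=1)
    sim = torch.matmul(features, features.T) / temperature                                  # (b, b)
    off_diag = ~torch.eye(b, dtype=torch.bool, device=features.device)
    sim = sim[off_diag].view(b, -1)                            # drop self-similarities -> (b, b-1)
    same_skill = same_skill[off_diag].view(b, -1)
    sim = torch.exp(sim - sim.max(dim=1, keepdim=True).values)  # numerically stable exp
    one_positive = torch.zeros_like(sim).scatter_(-1, same_skill.argmax(-1, keepdim=True), 1.0)
    positives = (sim * one_positive).sum(-1, keepdim=True)      # one same-skill sample per anchor
    negatives = sim.sum(-1, keepdim=True)                       # the rest of the batch
    loss = -torch.log(positives / (negatives + eps) + eps)      # InfoNCE loss, shape (b, 1)
    return torch.exp(-loss)                                     # intrinsic reward r = exp(-loss)
```

During pre-training the DDPG backbone maximizes this reward (states from the same skill are pulled together, different skills pushed apart); during finetuning it is replaced by the task's extrinsic reward, as in the command above.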
34 | 35 | ## Requirements 36 | 37 | We assume you have access to a GPU that can run CUDA 10.2 and CUDNN 8. Then, the simplest way to install all required dependencies is to create an anaconda environment by running 38 | ```sh 39 | conda env create -f conda_env.yml 40 | ``` 41 | After the installation ends you can activate your environment with 42 | ```sh 43 | conda activate urlb 44 | ``` 45 | 46 | ## Available Domains 47 | We support the following domains. 48 | | Domain | Tasks | 49 | |---|---| 50 | | `walker` | `stand`, `walk`, `run`, `flip` | 51 | | `quadruped` | `walk`, `run`, `stand`, `jump` | 52 | | `jaco` | `reach_top_left`, `reach_top_right`, `reach_bottom_left`, `reach_bottom_right` | 53 | 54 | ### Monitoring 55 | Logs are stored in the `exp_local` folder. To launch tensorboard run: 56 | ```sh 57 | tensorboard --logdir exp_local 58 | ``` 59 | The console output is also available in the form: 60 | ``` 61 | | train | F: 6000 | S: 3000 | E: 6 | L: 1000 | R: 5.5177 | FPS: 96.7586 | T: 0:00:42 62 | ``` 63 | a training entry decodes as 64 | ``` 65 | F : total number of environment frames 66 | S : total number of agent steps 67 | E : total number of episodes 68 | R : episode return 69 | FPS: training throughput (frames per second) 70 | T : total training time 71 | ``` 72 | -------------------------------------------------------------------------------- /custom_dmc_tasks/walker.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /video.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import imageio 3 | import numpy as np 4 | import wandb 5 | 6 | 7 | class VideoRecorder: 8 | def __init__(self, 9 | root_dir, 10 | render_size=256, 11 | fps=20, 12 | camera_id=0, 13 | use_wandb=False): 14 | if root_dir is not None: 15 | self.save_dir = root_dir / 'eval_video' 16 | self.save_dir.mkdir(exist_ok=True) 17 | else: 18 | self.save_dir = None 19 | 20 | self.render_size = render_size 21 | self.fps = fps 22 | self.frames = [] 23 | self.camera_id = camera_id 24 | self.use_wandb = use_wandb 25 | 26 | def init(self, env, enabled=True): 27 | self.frames = [] 28 | self.enabled = self.save_dir is not None and enabled 29 | self.record(env) 30 | 31 | def record(self, env): 32 | if self.enabled: 33 | if hasattr(env, 'physics'): 34 | frame = env.physics.render(height=self.render_size, 35 | width=self.render_size, 36 | camera_id=self.camera_id) 37 | else: 38 | frame = env.render() 39 | self.frames.append(frame) 40 | 41 | def log_to_wandb(self): 42 | frames = np.transpose(np.array(self.frames), (0, 3, 1, 2)) 43 | fps, skip = 6, 8 44 | wandb.log({ 45 | 'eval/video': 46 | wandb.Video(frames[::skip, :, ::2, ::2], fps=fps, format="gif") 47 | }) 48 | 49 | def save(self, file_name): 50 | if self.enabled: 51 | if self.use_wandb: 52 | self.log_to_wandb() 53 | path = self.save_dir / file_name 54 | imageio.mimsave(str(path), self.frames, fps=self.fps) 55 | 56 | 57 | class TrainVideoRecorder: 58 | def __init__(self, 59 | root_dir, 60 | render_size=256, 61 | fps=20, 62 | camera_id=0, 63 | use_wandb=False): 64 | if root_dir is not None: 65 | self.save_dir = root_dir / 'train_video' 66 | self.save_dir.mkdir(exist_ok=True) 67 | else: 68 | self.save_dir = None 69 | 70 | self.render_size = render_size 71 | self.fps = fps 72 | self.frames = [] 73 | self.camera_id = camera_id 74 | self.use_wandb = use_wandb 75 | 76 | def init(self, obs, 
enabled=True): 77 | self.frames = [] 78 | self.enabled = self.save_dir is not None and enabled 79 | self.record(obs) 80 | 81 | def record(self, obs): 82 | if self.enabled: 83 | frame = cv2.resize(obs[-3:].transpose(1, 2, 0), 84 | dsize=(self.render_size, self.render_size), 85 | interpolation=cv2.INTER_CUBIC) 86 | self.frames.append(frame) 87 | 88 | def log_to_wandb(self): 89 | frames = np.transpose(np.array(self.frames), (0, 3, 1, 2)) 90 | fps, skip = 6, 8 91 | wandb.log({ 92 | 'train/video': 93 | wandb.Video(frames[::skip, :, ::2, ::2], fps=fps, format="gif") 94 | }) 95 | 96 | def save(self, file_name): 97 | if self.enabled: 98 | if self.use_wandb: 99 | self.log_to_wandb() 100 | path = self.save_dir / file_name 101 | imageio.mimsave(str(path), self.frames, fps=self.fps) 102 | -------------------------------------------------------------------------------- /custom_dmc_tasks/cheetah.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /custom_dmc_tasks/cheetah.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Cheetah Domain.""" 16 | 17 | import collections 18 | import os 19 | 20 | from dm_control import mujoco 21 | from dm_control.rl import control 22 | from dm_control.suite import base 23 | from dm_control.suite import common 24 | from dm_control.utils import containers 25 | from dm_control.utils import rewards 26 | from dm_control.utils import io as resources 27 | 28 | # How long the simulation will run, in seconds. 29 | _DEFAULT_TIME_LIMIT = 10 30 | 31 | # Running speed above which reward is 1. 
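# rewards.tolerance (used in Cheetah.get_reward further down) returns 1 while the
# measured speed or spin lies inside its `bounds` and decays to `value_at_margin`
# over a distance of `margin` outside them; the two constants below therefore set
# the speeds at which the run and flip rewards saturate at 1.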
32 | _RUN_SPEED = 10 33 | _SPIN_SPEED = 5 34 | 35 | SUITE = containers.TaggedTasks() 36 | 37 | 38 | def make(task, 39 | task_kwargs=None, 40 | environment_kwargs=None, 41 | visualize_reward=False): 42 | task_kwargs = task_kwargs or {} 43 | if environment_kwargs is not None: 44 | task_kwargs = task_kwargs.copy() 45 | task_kwargs['environment_kwargs'] = environment_kwargs 46 | env = SUITE[task](**task_kwargs) 47 | env.task.visualize_reward = visualize_reward 48 | return env 49 | 50 | 51 | def get_model_and_assets(): 52 | """Returns a tuple containing the model XML string and a dict of assets.""" 53 | root_dir = os.path.dirname(os.path.dirname(__file__)) 54 | xml = resources.GetResource( 55 | os.path.join(root_dir, 'custom_dmc_tasks', 'cheetah.xml')) 56 | return xml, common.ASSETS 57 | 58 | 59 | 60 | @SUITE.add('benchmarking') 61 | def run_backward(time_limit=_DEFAULT_TIME_LIMIT, 62 | random=None, 63 | environment_kwargs=None): 64 | """Returns the run task.""" 65 | physics = Physics.from_xml_string(*get_model_and_assets()) 66 | task = Cheetah(forward=False, flip=False, random=random) 67 | environment_kwargs = environment_kwargs or {} 68 | return control.Environment(physics, 69 | task, 70 | time_limit=time_limit, 71 | **environment_kwargs) 72 | 73 | 74 | @SUITE.add('benchmarking') 75 | def flip(time_limit=_DEFAULT_TIME_LIMIT, 76 | random=None, 77 | environment_kwargs=None): 78 | """Returns the run task.""" 79 | physics = Physics.from_xml_string(*get_model_and_assets()) 80 | task = Cheetah(forward=True, flip=True, random=random) 81 | environment_kwargs = environment_kwargs or {} 82 | return control.Environment(physics, 83 | task, 84 | time_limit=time_limit, 85 | **environment_kwargs) 86 | 87 | 88 | @SUITE.add('benchmarking') 89 | def flip_backward(time_limit=_DEFAULT_TIME_LIMIT, 90 | random=None, 91 | environment_kwargs=None): 92 | """Returns the run task.""" 93 | physics = Physics.from_xml_string(*get_model_and_assets()) 94 | task = Cheetah(forward=False, flip=True, random=random) 95 | environment_kwargs = environment_kwargs or {} 96 | return control.Environment(physics, 97 | task, 98 | time_limit=time_limit, 99 | **environment_kwargs) 100 | 101 | 102 | class Physics(mujoco.Physics): 103 | """Physics simulation with additional features for the Cheetah domain.""" 104 | def speed(self): 105 | """Returns the horizontal speed of the Cheetah.""" 106 | return self.named.data.sensordata['torso_subtreelinvel'][0] 107 | 108 | def angmomentum(self): 109 | """Returns the angular momentum of torso of the Cheetah about Y axis.""" 110 | return self.named.data.subtree_angmom['torso'][1] 111 | 112 | 113 | class Cheetah(base.Task): 114 | """A `Task` to train a running Cheetah.""" 115 | def __init__(self, forward=True, flip=False, random=None): 116 | self._forward = 1 if forward else -1 117 | self._flip = flip 118 | super(Cheetah, self).__init__(random=random) 119 | 120 | def initialize_episode(self, physics): 121 | """Sets the state of the environment at the start of each episode.""" 122 | # The indexing below assumes that all joints have a single DOF. 123 | assert physics.model.nq == physics.model.njnt 124 | is_limited = physics.model.jnt_limited == 1 125 | lower, upper = physics.model.jnt_range[is_limited].T 126 | physics.data.qpos[is_limited] = self.random.uniform(lower, upper) 127 | 128 | # Stabilize the model before the actual simulation. 
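# Stepping the physics a couple hundred times lets the randomized joint
# configuration settle into a stable pose; the simulation clock is then reset
# so the episode itself still starts at time 0.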
129 | for _ in range(200): 130 | physics.step() 131 | 132 | physics.data.time = 0 133 | self._timeout_progress = 0 134 | super().initialize_episode(physics) 135 | 136 | def get_observation(self, physics): 137 | """Returns an observation of the state, ignoring horizontal position.""" 138 | obs = collections.OrderedDict() 139 | # Ignores horizontal position to maintain translational invariance. 140 | obs['position'] = physics.data.qpos[1:].copy() 141 | obs['velocity'] = physics.velocity() 142 | return obs 143 | 144 | def get_reward(self, physics): 145 | """Returns a reward to the agent.""" 146 | if self._flip: 147 | reward = rewards.tolerance(self._forward * physics.angmomentum(), 148 | bounds=(_SPIN_SPEED, float('inf')), 149 | margin=_SPIN_SPEED, 150 | value_at_margin=0, 151 | sigmoid='linear') 152 | 153 | else: 154 | reward = rewards.tolerance(self._forward * physics.speed(), 155 | bounds=(_RUN_SPEED, float('inf')), 156 | margin=_RUN_SPEED, 157 | value_at_margin=0, 158 | sigmoid='linear') 159 | return reward 160 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import datetime 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | import torch 7 | import torchvision 8 | import wandb 9 | from termcolor import colored 10 | from torch.utils.tensorboard import SummaryWriter 11 | 12 | COMMON_TRAIN_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'), 13 | ('episode', 'E', 'int'), ('episode_length', 'L', 'int'), 14 | ('episode_reward', 'R', 'float'), 15 | ('fps', 'FPS', 'float'), ('total_time', 'T', 'time')] 16 | 17 | COMMON_EVAL_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'), 18 | ('episode', 'E', 'int'), ('episode_length', 'L', 'int'), 19 | ('episode_reward', 'R', 'float'), 20 | ('total_time', 'T', 'time')] 21 | 22 | 23 | class AverageMeter(object): 24 | def __init__(self): 25 | self._sum = 0 26 | self._count = 0 27 | 28 | def update(self, value, n=1): 29 | self._sum += value 30 | self._count += n 31 | 32 | def value(self): 33 | return self._sum / max(1, self._count) 34 | 35 | 36 | class MetersGroup(object): 37 | def __init__(self, csv_file_name, formating, use_wandb): 38 | self._csv_file_name = csv_file_name 39 | self._formating = formating 40 | self._meters = defaultdict(AverageMeter) 41 | self._csv_file = None 42 | self._csv_writer = None 43 | self.use_wandb = use_wandb 44 | 45 | def log(self, key, value, n=1): 46 | self._meters[key].update(value, n) 47 | 48 | def _prime_meters(self): 49 | data = dict() 50 | for key, meter in self._meters.items(): 51 | if key.startswith('train'): 52 | key = key[len('train') + 1:] 53 | else: 54 | key = key[len('eval') + 1:] 55 | key = key.replace('/', '_') 56 | data[key] = meter.value() 57 | return data 58 | 59 | def _remove_old_entries(self, data): 60 | rows = [] 61 | with self._csv_file_name.open('r') as f: 62 | reader = csv.DictReader(f) 63 | for row in reader: 64 | if float(row['episode']) >= data['episode']: 65 | break 66 | rows.append(row) 67 | with self._csv_file_name.open('w') as f: 68 | writer = csv.DictWriter(f, 69 | fieldnames=sorted(data.keys()), 70 | restval=0.0) 71 | writer.writeheader() 72 | for row in rows: 73 | writer.writerow(row) 74 | 75 | def _dump_to_csv(self, data): 76 | if self._csv_writer is None: 77 | should_write_header = True 78 | if self._csv_file_name.exists(): 79 | self._remove_old_entries(data) 80 | should_write_header = False 81 | 82 | self._csv_file = 
self._csv_file_name.open('a')
83 |             self._csv_writer = csv.DictWriter(self._csv_file,
84 |                                               fieldnames=sorted(data.keys()),
85 |                                               restval=0.0)
86 |             if should_write_header:
87 |                 self._csv_writer.writeheader()
88 | 
89 |         self._csv_writer.writerow(data)
90 |         self._csv_file.flush()
91 | 
92 |     def _format(self, key, value, ty):
93 |         if ty == 'int':
94 |             value = int(value)
95 |             return f'{key}: {value}'
96 |         elif ty == 'float':
97 |             return f'{key}: {value:.04f}'
98 |         elif ty == 'time':
99 |             value = str(datetime.timedelta(seconds=int(value)))
100 |             return f'{key}: {value}'
101 |         else:
102 |             raise ValueError(f'invalid format type: {ty}')
103 | 
104 |     def _dump_to_console(self, data, prefix):
105 |         prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green')
106 |         pieces = [f'| {prefix: <14}']
107 |         for key, disp_key, ty in self._formating:
108 |             value = data.get(key, 0)
109 |             pieces.append(self._format(disp_key, value, ty))
110 |         print(' | '.join(pieces))
111 | 
112 |     def _dump_to_wandb(self, data):
113 |         wandb.log(data)
114 | 
115 |     def dump(self, step, prefix):
116 |         if len(self._meters) == 0:
117 |             return
118 |         data = self._prime_meters()
119 |         data['frame'] = step
120 |         if self.use_wandb:
121 |             wandb_data = {prefix + '/' + key: val for key, val in data.items()}
122 |             self._dump_to_wandb(data=wandb_data)
123 |         self._dump_to_csv(data)
124 |         self._dump_to_console(data, prefix)
125 |         self._meters.clear()
126 | 
127 | 
128 | class Logger(object):
129 |     def __init__(self, log_dir, use_tb, use_wandb):
130 |         self._log_dir = log_dir
131 |         self._train_mg = MetersGroup(log_dir / 'train.csv',
132 |                                      formating=COMMON_TRAIN_FORMAT,
133 |                                      use_wandb=use_wandb)
134 |         self._eval_mg = MetersGroup(log_dir / 'eval.csv',
135 |                                     formating=COMMON_EVAL_FORMAT,
136 |                                     use_wandb=use_wandb)
137 |         if use_tb:
138 |             self._sw = SummaryWriter(str(log_dir / 'tb'))
139 |         else:
140 |             self._sw = None
141 |         self.use_wandb = use_wandb
142 | 
143 |     def _try_sw_log(self, key, value, step):
144 |         if self._sw is not None:
145 |             self._sw.add_scalar(key, value, step)
146 | 
147 |     def log(self, key, value, step):
148 |         assert key.startswith('train') or key.startswith('eval')
149 |         if type(value) == torch.Tensor:
150 |             value = value.item()
151 |         self._try_sw_log(key, value, step)
152 |         mg = self._train_mg if key.startswith('train') else self._eval_mg
153 |         mg.log(key, value)
154 | 
155 |     def log_metrics(self, metrics, step, ty):
156 |         for key, value in metrics.items():
157 |             self.log(f'{ty}/{key}', value, step)
158 | 
159 |     def dump(self, step, ty=None):
160 |         if ty is None or ty == 'eval':
161 |             self._eval_mg.dump(step, 'eval')
162 |         if ty is None or ty == 'train':
163 |             self._train_mg.dump(step, 'train')
164 | 
165 |     def log_and_dump_ctx(self, step, ty):
166 |         return LogAndDumpCtx(self, step, ty)
167 | 
168 | 
169 | class LogAndDumpCtx:
170 |     def __init__(self, logger, step, ty):
171 |         self._logger = logger
172 |         self._step = step
173 |         self._ty = ty
174 | 
175 |     def __enter__(self):
176 |         return self
177 | 
178 |     def __call__(self, key, value):
179 |         self._logger.log(f'{self._ty}/{key}', value, self._step)
180 | 
181 |     def __exit__(self, *args):
182 |         self._logger.dump(self._step, self._ty)
183 | 
--------------------------------------------------------------------------------
/custom_dmc_tasks/walker.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Planar Walker Domain.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import os 23 | 24 | from dm_control import mujoco 25 | from dm_control.rl import control 26 | from dm_control.suite import base 27 | from dm_control.suite import common 28 | from dm_control.suite.utils import randomizers 29 | from dm_control.utils import containers 30 | from dm_control.utils import rewards 31 | from dm_control.utils import io as resources 32 | from dm_control import suite 33 | 34 | _DEFAULT_TIME_LIMIT = 25 35 | _CONTROL_TIMESTEP = .025 36 | 37 | # Minimal height of torso over foot above which stand reward is 1. 38 | _STAND_HEIGHT = 1.2 39 | 40 | # Horizontal speeds (meters/second) above which move reward is 1. 41 | _WALK_SPEED = 1 42 | _RUN_SPEED = 8 43 | _SPIN_SPEED = 5 44 | 45 | SUITE = containers.TaggedTasks() 46 | 47 | def make(task, 48 | task_kwargs=None, 49 | environment_kwargs=None, 50 | visualize_reward=False): 51 | task_kwargs = task_kwargs or {} 52 | if environment_kwargs is not None: 53 | task_kwargs = task_kwargs.copy() 54 | task_kwargs['environment_kwargs'] = environment_kwargs 55 | env = SUITE[task](**task_kwargs) 56 | env.task.visualize_reward = visualize_reward 57 | return env 58 | 59 | def get_model_and_assets(): 60 | """Returns a tuple containing the model XML string and a dict of assets.""" 61 | root_dir = os.path.dirname(os.path.dirname(__file__)) 62 | xml = resources.GetResource(os.path.join(root_dir, 'custom_dmc_tasks', 63 | 'walker.xml')) 64 | return xml, common.ASSETS 65 | 66 | 67 | 68 | 69 | 70 | 71 | @SUITE.add('benchmarking') 72 | def flip(time_limit=_DEFAULT_TIME_LIMIT, 73 | random=None, 74 | environment_kwargs=None): 75 | """Returns the Run task.""" 76 | physics = Physics.from_xml_string(*get_model_and_assets()) 77 | task = PlanarWalker(move_speed=_RUN_SPEED, 78 | forward=True, 79 | flip=True, 80 | random=random) 81 | environment_kwargs = environment_kwargs or {} 82 | return control.Environment(physics, 83 | task, 84 | time_limit=time_limit, 85 | control_timestep=_CONTROL_TIMESTEP, 86 | **environment_kwargs) 87 | 88 | 89 | class Physics(mujoco.Physics): 90 | """Physics simulation with additional features for the Walker domain.""" 91 | def torso_upright(self): 92 | """Returns projection from z-axes of torso to the z-axes of world.""" 93 | return self.named.data.xmat['torso', 'zz'] 94 | 95 | def torso_height(self): 96 | """Returns the height of the torso.""" 97 | return self.named.data.xpos['torso', 'z'] 98 | 99 | def horizontal_velocity(self): 100 | """Returns the horizontal velocity of the center-of-mass.""" 101 | return self.named.data.sensordata['torso_subtreelinvel'][0] 102 | 103 | def orientations(self): 104 | """Returns planar orientations of all bodies.""" 105 | return 
self.named.data.xmat[1:, ['xx', 'xz']].ravel() 106 | 107 | def angmomentum(self): 108 | """Returns the angular momentum of torso of the Cheetah about Y axis.""" 109 | return self.named.data.subtree_angmom['torso'][1] 110 | 111 | 112 | class PlanarWalker(base.Task): 113 | """A planar walker task.""" 114 | def __init__(self, move_speed, forward=True, flip=False, random=None): 115 | """Initializes an instance of `PlanarWalker`. 116 | 117 | Args: 118 | move_speed: A float. If this value is zero, reward is given simply for 119 | standing up. Otherwise this specifies a target horizontal velocity for 120 | the walking task. 121 | random: Optional, either a `numpy.random.RandomState` instance, an 122 | integer seed for creating a new `RandomState`, or None to select a seed 123 | automatically (default). 124 | """ 125 | self._move_speed = move_speed 126 | self._forward = 1 if forward else -1 127 | self._flip = flip 128 | super(PlanarWalker, self).__init__(random=random) 129 | 130 | def initialize_episode(self, physics): 131 | """Sets the state of the environment at the start of each episode. 132 | 133 | In 'standing' mode, use initial orientation and small velocities. 134 | In 'random' mode, randomize joint angles and let fall to the floor. 135 | 136 | Args: 137 | physics: An instance of `Physics`. 138 | 139 | """ 140 | randomizers.randomize_limited_and_rotational_joints( 141 | physics, self.random) 142 | super(PlanarWalker, self).initialize_episode(physics) 143 | 144 | def get_observation(self, physics): 145 | """Returns an observation of body orientations, height and velocites.""" 146 | obs = collections.OrderedDict() 147 | obs['orientations'] = physics.orientations() 148 | obs['height'] = physics.torso_height() 149 | obs['velocity'] = physics.velocity() 150 | return obs 151 | 152 | def get_reward(self, physics): 153 | """Returns a reward to the agent.""" 154 | standing = rewards.tolerance(physics.torso_height(), 155 | bounds=(_STAND_HEIGHT, float('inf')), 156 | margin=_STAND_HEIGHT / 2) 157 | upright = (1 + physics.torso_upright()) / 2 158 | stand_reward = (3 * standing + upright) / 4 159 | 160 | if self._flip: 161 | move_reward = rewards.tolerance(self._forward * 162 | physics.angmomentum(), 163 | bounds=(_SPIN_SPEED, float('inf')), 164 | margin=_SPIN_SPEED, 165 | value_at_margin=0, 166 | sigmoid='linear') 167 | else: 168 | move_reward = rewards.tolerance( 169 | self._forward * physics.horizontal_velocity(), 170 | bounds=(self._move_speed, float('inf')), 171 | margin=self._move_speed / 2, 172 | value_at_margin=0.5, 173 | sigmoid='linear') 174 | 175 | return stand_reward * (5 * move_reward + 1) / 6 176 | -------------------------------------------------------------------------------- /replay_buffer.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import io 3 | import random 4 | import traceback 5 | from collections import defaultdict 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | from torch.utils.data import IterableDataset 11 | 12 | 13 | def episode_len(episode): 14 | # subtract -1 because the dummy first transition 15 | return next(iter(episode.values())).shape[0] - 1 16 | 17 | 18 | def save_episode(episode, fn): 19 | with io.BytesIO() as bs: 20 | np.savez_compressed(bs, **episode) 21 | bs.seek(0) 22 | with fn.open('wb') as f: 23 | f.write(bs.read()) 24 | 25 | 26 | def load_episode(fn): 27 | with fn.open('rb') as f: 28 | episode = np.load(f) 29 | episode = {k: episode[k] for k in episode.keys()} 30 | return 
episode 31 | 32 | 33 | class ReplayBufferStorage: 34 | def __init__(self, data_specs, meta_specs, replay_dir): 35 | self._data_specs = data_specs 36 | self._meta_specs = meta_specs 37 | self._replay_dir = replay_dir 38 | replay_dir.mkdir(exist_ok=True) 39 | self._current_episode = defaultdict(list) 40 | self._preload() 41 | 42 | def __len__(self): 43 | return self._num_transitions 44 | 45 | def add(self, time_step, meta): 46 | for key, value in meta.items(): 47 | self._current_episode[key].append(value) 48 | for spec in self._data_specs: 49 | value = time_step[spec.name] 50 | if np.isscalar(value): 51 | value = np.full(spec.shape, value, spec.dtype) 52 | assert spec.shape == value.shape and spec.dtype == value.dtype 53 | self._current_episode[spec.name].append(value) 54 | if time_step.last(): 55 | episode = dict() 56 | for spec in self._data_specs: 57 | value = self._current_episode[spec.name] 58 | episode[spec.name] = np.array(value, spec.dtype) 59 | for spec in self._meta_specs: 60 | value = self._current_episode[spec.name] 61 | episode[spec.name] = np.array(value, spec.dtype) 62 | self._current_episode = defaultdict(list) 63 | self._store_episode(episode) 64 | 65 | def _preload(self): 66 | self._num_episodes = 0 67 | self._num_transitions = 0 68 | for fn in self._replay_dir.glob('*.npz'): 69 | _, _, eps_len = fn.stem.split('_') 70 | self._num_episodes += 1 71 | self._num_transitions += int(eps_len) 72 | 73 | def _store_episode(self, episode): 74 | eps_idx = self._num_episodes 75 | eps_len = episode_len(episode) 76 | self._num_episodes += 1 77 | self._num_transitions += eps_len 78 | ts = datetime.datetime.now().strftime('%Y%m%dT%H%M%S') 79 | eps_fn = f'{ts}_{eps_idx}_{eps_len}.npz' 80 | save_episode(episode, self._replay_dir / eps_fn) 81 | 82 | 83 | class ReplayBuffer(IterableDataset): 84 | def __init__(self, storage, max_size, num_workers, nstep, discount, 85 | fetch_every, save_snapshot): 86 | self._storage = storage 87 | self._size = 0 88 | self._max_size = max_size 89 | self._num_workers = max(1, num_workers) 90 | self._episode_fns = [] 91 | self._episodes = dict() 92 | self._nstep = nstep 93 | self._discount = discount 94 | self._fetch_every = fetch_every 95 | self._samples_since_last_fetch = fetch_every 96 | self._save_snapshot = save_snapshot 97 | 98 | def _sample_episode(self): 99 | eps_fn = random.choice(self._episode_fns) 100 | return self._episodes[eps_fn] 101 | 102 | def _store_episode(self, eps_fn): 103 | try: 104 | episode = load_episode(eps_fn) 105 | except: 106 | return False 107 | eps_len = episode_len(episode) 108 | while eps_len + self._size > self._max_size: 109 | early_eps_fn = self._episode_fns.pop(0) 110 | early_eps = self._episodes.pop(early_eps_fn) 111 | self._size -= episode_len(early_eps) 112 | early_eps_fn.unlink(missing_ok=True) 113 | self._episode_fns.append(eps_fn) 114 | self._episode_fns.sort() 115 | self._episodes[eps_fn] = episode 116 | self._size += eps_len 117 | 118 | if not self._save_snapshot: 119 | eps_fn.unlink(missing_ok=True) 120 | return True 121 | 122 | def _try_fetch(self): 123 | if self._samples_since_last_fetch < self._fetch_every: 124 | return 125 | self._samples_since_last_fetch = 0 126 | try: 127 | worker_id = torch.utils.data.get_worker_info().id 128 | except: 129 | worker_id = 0 130 | eps_fns = sorted(self._storage._replay_dir.glob('*.npz'), reverse=True) 131 | fetched_size = 0 132 | for eps_fn in eps_fns: 133 | eps_idx, eps_len = [int(x) for x in eps_fn.stem.split('_')[1:]] 134 | if eps_idx % self._num_workers != worker_id: 135 | 
continue 136 | if eps_fn in self._episodes.keys(): 137 | break 138 | if fetched_size + eps_len > self._max_size: 139 | break 140 | fetched_size += eps_len 141 | if not self._store_episode(eps_fn): 142 | break 143 | 144 | def _sample(self): 145 | try: 146 | self._try_fetch() 147 | except: 148 | traceback.print_exc() 149 | self._samples_since_last_fetch += 1 150 | episode = self._sample_episode() 151 | # add +1 for the first dummy transition 152 | idx = np.random.randint(0, episode_len(episode) - self._nstep + 1) + 1 153 | meta = [] 154 | for spec in self._storage._meta_specs: 155 | meta.append(episode[spec.name][idx - 1]) 156 | obs = episode['observation'][idx - 1] 157 | action = episode['action'][idx] 158 | next_obs = episode['observation'][idx + self._nstep - 1] 159 | reward = np.zeros_like(episode['reward'][idx]) 160 | discount = np.ones_like(episode['discount'][idx]) 161 | for i in range(self._nstep): 162 | step_reward = episode['reward'][idx + i] 163 | reward += discount * step_reward 164 | discount *= episode['discount'][idx + i] * self._discount 165 | return (obs, action, reward, discount, next_obs, *meta) 166 | 167 | def __iter__(self): 168 | while True: 169 | yield self._sample() 170 | 171 | 172 | def _worker_init_fn(worker_id): 173 | seed = np.random.get_state()[1][0] + worker_id 174 | np.random.seed(seed) 175 | random.seed(seed) 176 | 177 | 178 | def make_replay_loader(storage, max_size, batch_size, num_workers, 179 | save_snapshot, nstep, discount): 180 | max_size_per_worker = max_size // max(1, num_workers) 181 | 182 | iterable = ReplayBuffer(storage, 183 | max_size_per_worker, 184 | num_workers, 185 | nstep, 186 | discount, 187 | fetch_every=1000, 188 | save_snapshot=save_snapshot) 189 | 190 | loader = torch.utils.data.DataLoader(iterable, 191 | batch_size=batch_size, 192 | num_workers=num_workers, 193 | pin_memory=True, 194 | worker_init_fn=_worker_init_fn) 195 | return loader 196 | -------------------------------------------------------------------------------- /custom_dmc_tasks/hopper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """Hopper domain.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import os 23 | 24 | from dm_control import mujoco 25 | from dm_control.rl import control 26 | from dm_control.suite import base 27 | from dm_control.suite import common 28 | from dm_control.suite.utils import randomizers 29 | from dm_control.utils import containers 30 | from dm_control.utils import rewards 31 | from dm_control.utils import io as resources 32 | import numpy as np 33 | 34 | SUITE = containers.TaggedTasks() 35 | 36 | _CONTROL_TIMESTEP = .02 # (Seconds) 37 | 38 | # Default duration of an episode, in seconds. 39 | _DEFAULT_TIME_LIMIT = 20 40 | 41 | # Minimal height of torso over foot above which stand reward is 1. 42 | _STAND_HEIGHT = 0.6 43 | 44 | # Hopping speed above which hop reward is 1. 45 | _HOP_SPEED = 2 46 | _SPIN_SPEED = 5 47 | 48 | 49 | def make(task, 50 | task_kwargs=None, 51 | environment_kwargs=None, 52 | visualize_reward=False): 53 | task_kwargs = task_kwargs or {} 54 | if environment_kwargs is not None: 55 | task_kwargs = task_kwargs.copy() 56 | task_kwargs['environment_kwargs'] = environment_kwargs 57 | env = SUITE[task](**task_kwargs) 58 | env.task.visualize_reward = visualize_reward 59 | return env 60 | 61 | def get_model_and_assets(): 62 | """Returns a tuple containing the model XML string and a dict of assets.""" 63 | root_dir = os.path.dirname(os.path.dirname(__file__)) 64 | xml = resources.GetResource( 65 | os.path.join(root_dir, 'custom_dmc_tasks', 'hopper.xml')) 66 | return xml, common.ASSETS 67 | 68 | 69 | 70 | @SUITE.add('benchmarking') 71 | def hop_backward(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): 72 | """Returns a Hopper that strives to hop forward.""" 73 | physics = Physics.from_xml_string(*get_model_and_assets()) 74 | task = Hopper(hopping=True, forward=False, flip=False, random=random) 75 | environment_kwargs = environment_kwargs or {} 76 | return control.Environment(physics, 77 | task, 78 | time_limit=time_limit, 79 | control_timestep=_CONTROL_TIMESTEP, 80 | **environment_kwargs) 81 | 82 | @SUITE.add('benchmarking') 83 | def flip(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): 84 | """Returns a Hopper that strives to hop forward.""" 85 | physics = Physics.from_xml_string(*get_model_and_assets()) 86 | task = Hopper(hopping=True, forward=True, flip=True, random=random) 87 | environment_kwargs = environment_kwargs or {} 88 | return control.Environment(physics, 89 | task, 90 | time_limit=time_limit, 91 | control_timestep=_CONTROL_TIMESTEP, 92 | **environment_kwargs) 93 | 94 | @SUITE.add('benchmarking') 95 | def flip_backward(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): 96 | """Returns a Hopper that strives to hop forward.""" 97 | physics = Physics.from_xml_string(*get_model_and_assets()) 98 | task = Hopper(hopping=True, forward=False, flip=True, random=random) 99 | environment_kwargs = environment_kwargs or {} 100 | return control.Environment(physics, 101 | task, 102 | time_limit=time_limit, 103 | control_timestep=_CONTROL_TIMESTEP, 104 | **environment_kwargs) 105 | 106 | 107 | class Physics(mujoco.Physics): 108 | """Physics simulation with additional features for the Hopper domain.""" 109 | def height(self): 110 | """Returns height of torso with respect to foot.""" 111 | return 
(self.named.data.xipos['torso', 'z'] - 112 | self.named.data.xipos['foot', 'z']) 113 | 114 | def speed(self): 115 | """Returns horizontal speed of the Hopper.""" 116 | return self.named.data.sensordata['torso_subtreelinvel'][0] 117 | 118 | def touch(self): 119 | """Returns the signals from two foot touch sensors.""" 120 | return np.log1p(self.named.data.sensordata[['touch_toe', 121 | 'touch_heel']]) 122 | 123 | def angmomentum(self): 124 | """Returns the angular momentum of torso of the Cheetah about Y axis.""" 125 | return self.named.data.subtree_angmom['torso'][1] 126 | 127 | 128 | 129 | class Hopper(base.Task): 130 | """A Hopper's `Task` to train a standing and a jumping Hopper.""" 131 | def __init__(self, hopping, forward=True, flip=False, random=None): 132 | """Initialize an instance of `Hopper`. 133 | 134 | Args: 135 | hopping: Boolean, if True the task is to hop forwards, otherwise it is to 136 | balance upright. 137 | random: Optional, either a `numpy.random.RandomState` instance, an 138 | integer seed for creating a new `RandomState`, or None to select a seed 139 | automatically (default). 140 | """ 141 | self._hopping = hopping 142 | self._forward = 1 if forward else -1 143 | self._flip = flip 144 | super(Hopper, self).__init__(random=random) 145 | 146 | def initialize_episode(self, physics): 147 | """Sets the state of the environment at the start of each episode.""" 148 | randomizers.randomize_limited_and_rotational_joints( 149 | physics, self.random) 150 | self._timeout_progress = 0 151 | super(Hopper, self).initialize_episode(physics) 152 | 153 | def get_observation(self, physics): 154 | """Returns an observation of positions, velocities and touch sensors.""" 155 | obs = collections.OrderedDict() 156 | # Ignores horizontal position to maintain translational invariance: 157 | obs['position'] = physics.data.qpos[1:].copy() 158 | obs['velocity'] = physics.velocity() 159 | obs['touch'] = physics.touch() 160 | return obs 161 | 162 | def get_reward(self, physics): 163 | """Returns a reward applicable to the performed task.""" 164 | standing = rewards.tolerance(physics.height(), (_STAND_HEIGHT, 2)) 165 | assert self._hopping 166 | if self._flip: 167 | hopping = rewards.tolerance(self._forward * physics.angmomentum(), 168 | bounds=(_SPIN_SPEED, float('inf')), 169 | margin=_SPIN_SPEED, 170 | value_at_margin=0, 171 | sigmoid='linear') 172 | else: 173 | hopping = rewards.tolerance(self._forward * physics.speed(), 174 | bounds=(_HOP_SPEED, float('inf')), 175 | margin=_HOP_SPEED / 2, 176 | value_at_margin=0.5, 177 | sigmoid='linear') 178 | return standing * hopping -------------------------------------------------------------------------------- /agent/becl.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import math 3 | from collections import OrderedDict 4 | 5 | import hydra 6 | import random 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from dm_env import specs 12 | 13 | import utils 14 | from agent.ddpg import DDPGAgent 15 | 16 | 17 | class BECL(nn.Module): 18 | def __init__(self, tau_dim, feature_dim, hidden_dim): 19 | super().__init__() 20 | 21 | self.embed = nn.Sequential(nn.Linear(tau_dim, hidden_dim), 22 | nn.ReLU(), 23 | nn.Linear(hidden_dim, hidden_dim), 24 | nn.ReLU(), 25 | nn.Linear(hidden_dim, feature_dim)) 26 | 27 | self.project_head = nn.Sequential(nn.Linear(feature_dim, hidden_dim), 28 | nn.ReLU(), 29 | nn.Linear(hidden_dim, feature_dim)) 30 | 
self.apply(utils.weight_init) 31 | 32 | def forward(self, tau): 33 | features = self.embed(tau) 34 | features = self.project_head(features) 35 | return features 36 | 37 | 38 | class BECLAgent(DDPGAgent): 39 | def __init__(self, update_skill_every_step, skill_dim, 40 | update_encoder, contrastive_update_rate, temperature, skill, **kwargs): 41 | self.skill_dim = skill_dim 42 | self.update_skill_every_step = update_skill_every_step 43 | self.update_encoder = update_encoder 44 | self.contrastive_update_rate = contrastive_update_rate 45 | self.temperature = temperature 46 | # specify skill in fine-tuning stage if needed 47 | self.skill = int(skill) if skill >= 0 else np.random.choice(self.skill_dim) 48 | # increase obs shape to include skill dim 49 | kwargs["meta_dim"] = self.skill_dim 50 | self.batch_size = kwargs['batch_size'] 51 | # create actor and critic 52 | super().__init__(**kwargs) 53 | 54 | # net 55 | self.becl = BECL(self.obs_dim - self.skill_dim, 56 | self.skill_dim, 57 | kwargs['hidden_dim']).to(kwargs['device']) 58 | 59 | # optimizers 60 | self.becl_opt = torch.optim.Adam(self.becl.parameters(), lr=self.lr) 61 | 62 | self.becl.train() 63 | 64 | 65 | def get_meta_specs(self): 66 | return specs.Array((self.skill_dim,), np.float32, 'skill'), 67 | 68 | def init_meta(self): 69 | skill = np.zeros(self.skill_dim).astype(np.float32) 70 | if not self.reward_free: 71 | skill[self.skill] = 1.0 72 | else: 73 | skill[np.random.choice(self.skill_dim)] = 1.0 74 | meta = OrderedDict() 75 | meta['skill'] = skill 76 | return meta 77 | 78 | def update_meta(self, meta, global_step, time_step, finetune=False): 79 | if global_step % self.update_skill_every_step == 0: 80 | return self.init_meta() 81 | return meta 82 | 83 | def update_contrastive(self, state, skills): 84 | metrics = dict() 85 | features = self.becl(state) 86 | logits = self.compute_info_nce_loss(features, skills) 87 | loss = logits.mean() 88 | 89 | self.becl_opt.zero_grad() 90 | if self.encoder_opt is not None: 91 | self.encoder_opt.zero_grad(set_to_none=True) 92 | loss.backward() 93 | self.becl_opt.step() 94 | if self.encoder_opt is not None: 95 | self.encoder_opt.step() 96 | 97 | if self.use_tb or self.use_wandb: 98 | metrics['contrastive_loss'] = loss.item() 99 | 100 | return metrics 101 | 102 | def compute_intr_reward(self, skills, state, metrics): 103 | 104 | # compute contrastive reward 105 | features = self.becl(state) 106 | contrastive_reward = torch.exp(-self.compute_info_nce_loss(features, skills)) 107 | 108 | intr_reward = contrastive_reward 109 | if self.use_tb or self.use_wandb: 110 | metrics['contrastive_reward'] = contrastive_reward.mean().item() 111 | 112 | return intr_reward 113 | 114 | def compute_info_nce_loss(self, features, skills): 115 | # features: (b,c), skills :(b, skill_dim) 116 | # label positives samples 117 | labels = torch.argmax(skills, dim=-1) #(b, 1) 118 | labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).long() #(b,b) 119 | labels = labels.to(self.device) 120 | 121 | features = F.normalize(features, dim=1) #(b,c) 122 | similarity_matrix = torch.matmul(features, features.T) #(b,b) 123 | 124 | # discard the main diagonal from both: labels and similarities matrix 125 | mask = torch.eye(labels.shape[0], dtype=torch.bool).to(self.device) 126 | labels = labels[~mask].view(labels.shape[0], -1) #(b,b-1) 127 | similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1) #(b,b-1) 128 | 129 | similarity_matrix = similarity_matrix / self.temperature 130 | similarity_matrix -= 
torch.max(similarity_matrix, 1)[0][:, None] 131 | similarity_matrix = torch.exp(similarity_matrix) 132 | 133 | pick_one_positive_sample_idx = torch.argmax(labels, dim=-1, keepdim=True) 134 | pick_one_positive_sample_idx = torch.zeros_like(labels).scatter_(-1, pick_one_positive_sample_idx, 1) 135 | 136 | positives = torch.sum(similarity_matrix * pick_one_positive_sample_idx, dim=-1, keepdim=True) #(b,1) 137 | negatives = torch.sum(similarity_matrix, dim=-1, keepdim=True) #(b,1) 138 | eps = torch.as_tensor(1e-6) 139 | loss = -torch.log(positives / (negatives + eps) + eps) #(b,1) 140 | 141 | return loss 142 | 143 | 144 | def update(self, replay_iter, step): 145 | metrics = dict() 146 | 147 | if step % self.update_every_steps != 0: 148 | return metrics 149 | 150 | if self.reward_free: 151 | 152 | batch = next(replay_iter) 153 | obs, action, reward, discount, next_obs, skill = utils.to_torch(batch, self.device) 154 | obs = self.aug_and_encode(obs) 155 | next_obs = self.aug_and_encode(next_obs) 156 | 157 | metrics.update(self.update_contrastive(next_obs, skill)) 158 | 159 | for _ in range(self.contrastive_update_rate - 1): 160 | batch = next(replay_iter) 161 | obs, action, reward, discount, next_obs, skill = utils.to_torch(batch, self.device) 162 | obs = self.aug_and_encode(obs) 163 | next_obs = self.aug_and_encode(next_obs) 164 | 165 | metrics.update(self.update_contrastive(next_obs, skill)) 166 | 167 | with torch.no_grad(): 168 | intr_reward = self.compute_intr_reward(skill, next_obs, metrics) 169 | 170 | if self.use_tb or self.use_wandb: 171 | metrics['intr_reward'] = intr_reward.mean().item() 172 | 173 | reward = intr_reward 174 | else: 175 | batch = next(replay_iter) 176 | 177 | obs, action, extr_reward, discount, next_obs, skill = utils.to_torch( 178 | batch, self.device) 179 | obs = self.aug_and_encode(obs) 180 | next_obs = self.aug_and_encode(next_obs) 181 | reward = extr_reward 182 | 183 | if self.use_tb or self.use_wandb: 184 | metrics['batch_reward'] = reward.mean().item() 185 | 186 | if not self.update_encoder: 187 | obs = obs.detach() 188 | next_obs = next_obs.detach() 189 | 190 | # extend observations with skill 191 | obs = torch.cat([obs, skill], dim=1) 192 | next_obs = torch.cat([next_obs, skill], dim=1) 193 | 194 | # update critic 195 | metrics.update( 196 | self.update_critic(obs.detach(), action, reward, discount, 197 | next_obs.detach(), step)) 198 | 199 | # update actor 200 | metrics.update(self.update_actor(obs.detach(), step)) 201 | 202 | # update critic target 203 | utils.soft_update_params(self.critic, self.critic_target, 204 | self.critic_target_tau) 205 | 206 | return metrics 207 | -------------------------------------------------------------------------------- /custom_dmc_tasks/jaco.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | 16 | """A task where the goal is to move the hand close to a target prop or site.""" 17 | 18 | import collections 19 | 20 | from dm_control import composer 21 | from dm_control.composer import initializers 22 | from dm_control.composer.observation import observable 23 | from dm_control.composer.variation import distributions 24 | from dm_control.entities import props 25 | from dm_control.manipulation.shared import arenas 26 | from dm_control.manipulation.shared import cameras 27 | from dm_control.manipulation.shared import constants 28 | from dm_control.manipulation.shared import observations 29 | from dm_control.manipulation.shared import registry 30 | from dm_control.manipulation.shared import robots 31 | from dm_control.manipulation.shared import tags 32 | from dm_control.manipulation.shared import workspaces 33 | from dm_control.utils import rewards 34 | import numpy as np 35 | 36 | 37 | _ReachWorkspace = collections.namedtuple( 38 | '_ReachWorkspace', ['target_bbox', 'tcp_bbox', 'arm_offset']) 39 | 40 | # Ensures that the props are not touching the table before settling. 41 | _PROP_Z_OFFSET = 0.001 42 | 43 | _DUPLO_WORKSPACE = _ReachWorkspace( 44 | target_bbox=workspaces.BoundingBox( 45 | lower=(-0.1, -0.1, _PROP_Z_OFFSET), 46 | upper=(0.1, 0.1, _PROP_Z_OFFSET)), 47 | tcp_bbox=workspaces.BoundingBox( 48 | lower=(-0.1, -0.1, 0.2), 49 | upper=(0.1, 0.1, 0.4)), 50 | arm_offset=robots.ARM_OFFSET) 51 | 52 | _SITE_WORKSPACE = _ReachWorkspace( 53 | target_bbox=workspaces.BoundingBox( 54 | lower=(-0.2, -0.2, 0.02), 55 | upper=(0.2, 0.2, 0.4)), 56 | tcp_bbox=workspaces.BoundingBox( 57 | lower=(-0.2, -0.2, 0.02), 58 | upper=(0.2, 0.2, 0.4)), 59 | arm_offset=robots.ARM_OFFSET) 60 | 61 | _TARGET_RADIUS = 0.05 62 | _TIME_LIMIT = 10. 63 | 64 | TASKS = { 65 | 'reach_top_left': workspaces.BoundingBox( 66 | lower=(-0.09, 0.09, _PROP_Z_OFFSET), 67 | upper=(-0.09, 0.09, _PROP_Z_OFFSET)), 68 | 'reach_top_right': workspaces.BoundingBox( 69 | lower=(0.09, 0.09, _PROP_Z_OFFSET), 70 | upper=(0.09, 0.09, _PROP_Z_OFFSET)), 71 | 'reach_bottom_left': workspaces.BoundingBox( 72 | lower=(-0.09, -0.09, _PROP_Z_OFFSET), 73 | upper=(-0.09, -0.09, _PROP_Z_OFFSET)), 74 | 'reach_bottom_right': workspaces.BoundingBox( 75 | lower=(0.09, -0.09, _PROP_Z_OFFSET), 76 | upper=(0.09, -0.09, _PROP_Z_OFFSET)), 77 | } 78 | 79 | 80 | def make(task_id, obs_type, seed): 81 | obs_settings = observations.VISION if obs_type == 'pixels' else observations.PERFECT_FEATURES 82 | task = _reach(task_id, obs_settings=obs_settings, use_site=False) 83 | return composer.Environment(task, time_limit=_TIME_LIMIT, random_state=seed) 84 | 85 | 86 | 87 | class MTReach(composer.Task): 88 | """Bring the hand close to a target prop or site.""" 89 | 90 | def __init__( 91 | self, task_id, arena, arm, hand, prop, obs_settings, workspace, control_timestep): 92 | """Initializes a new `Reach` task. 93 | 94 | Args: 95 | arena: `composer.Entity` instance. 96 | arm: `robot_base.RobotArm` instance. 97 | hand: `robot_base.RobotHand` instance. 98 | prop: `composer.Entity` instance specifying the prop to reach to, or None 99 | in which case the target is a fixed site whose position is specified by 100 | the workspace. 101 | obs_settings: `observations.ObservationSettings` instance. 102 | workspace: `_ReachWorkspace` specifying the placement of the prop and TCP. 103 | control_timestep: Float specifying the control timestep in seconds. 
104 | """ 105 | self._arena = arena 106 | self._arm = arm 107 | self._hand = hand 108 | self._arm.attach(self._hand) 109 | self._arena.attach_offset(self._arm, offset=workspace.arm_offset) 110 | self.control_timestep = control_timestep 111 | self._tcp_initializer = initializers.ToolCenterPointInitializer( 112 | self._hand, self._arm, 113 | position=distributions.Uniform(*workspace.tcp_bbox), 114 | quaternion=workspaces.DOWN_QUATERNION) 115 | 116 | # Add custom camera observable. 117 | self._task_observables = cameras.add_camera_observables( 118 | arena, obs_settings, cameras.FRONT_CLOSE) 119 | 120 | target_pos_distribution = distributions.Uniform(*TASKS[task_id]) 121 | self._prop = prop 122 | if prop: 123 | # The prop itself is used to visualize the target location. 124 | self._make_target_site(parent_entity=prop, visible=False) 125 | self._target = self._arena.add_free_entity(prop) 126 | self._prop_placer = initializers.PropPlacer( 127 | props=[prop], 128 | position=target_pos_distribution, 129 | quaternion=workspaces.uniform_z_rotation, 130 | settle_physics=True) 131 | else: 132 | self._target = self._make_target_site(parent_entity=arena, visible=True) 133 | self._target_placer = target_pos_distribution 134 | 135 | obs = observable.MJCFFeature('pos', self._target) 136 | obs.configure(**obs_settings.prop_pose._asdict()) 137 | self._task_observables['target_position'] = obs 138 | 139 | # Add sites for visualizing the prop and target bounding boxes. 140 | workspaces.add_bbox_site( 141 | body=self.root_entity.mjcf_model.worldbody, 142 | lower=workspace.tcp_bbox.lower, upper=workspace.tcp_bbox.upper, 143 | rgba=constants.GREEN, name='tcp_spawn_area') 144 | workspaces.add_bbox_site( 145 | body=self.root_entity.mjcf_model.worldbody, 146 | lower=workspace.target_bbox.lower, upper=workspace.target_bbox.upper, 147 | rgba=constants.BLUE, name='target_spawn_area') 148 | 149 | def _make_target_site(self, parent_entity, visible): 150 | return workspaces.add_target_site( 151 | body=parent_entity.mjcf_model.worldbody, 152 | radius=_TARGET_RADIUS, visible=visible, 153 | rgba=constants.RED, name='target_site') 154 | 155 | @property 156 | def root_entity(self): 157 | return self._arena 158 | 159 | @property 160 | def arm(self): 161 | return self._arm 162 | 163 | @property 164 | def hand(self): 165 | return self._hand 166 | 167 | @property 168 | def task_observables(self): 169 | return self._task_observables 170 | 171 | def get_reward(self, physics): 172 | hand_pos = physics.bind(self._hand.tool_center_point).xpos 173 | target_pos = physics.bind(self._target).xpos 174 | distance = np.linalg.norm(hand_pos - target_pos) 175 | return rewards.tolerance( 176 | distance, bounds=(0, _TARGET_RADIUS), margin=_TARGET_RADIUS) 177 | 178 | def initialize_episode(self, physics, random_state): 179 | self._hand.set_grasp(physics, close_factors=random_state.uniform()) 180 | self._tcp_initializer(physics, random_state) 181 | if self._prop: 182 | self._prop_placer(physics, random_state) 183 | else: 184 | physics.bind(self._target).pos = ( 185 | self._target_placer(random_state=random_state)) 186 | 187 | 188 | def _reach(task_id, obs_settings, use_site): 189 | """Configure and instantiate a `Reach` task. 190 | 191 | Args: 192 | obs_settings: An `observations.ObservationSettings` instance. 193 | use_site: Boolean, if True then the target will be a fixed site, otherwise 194 | it will be a moveable Duplo brick. 195 | 196 | Returns: 197 | An instance of `reach.Reach`. 
198 | """ 199 | arena = arenas.Standard() 200 | arm = robots.make_arm(obs_settings=obs_settings) 201 | hand = robots.make_hand(obs_settings=obs_settings) 202 | if use_site: 203 | workspace = _SITE_WORKSPACE 204 | prop = None 205 | else: 206 | workspace = _DUPLO_WORKSPACE 207 | prop = props.Duplo(observable_options=observations.make_options( 208 | obs_settings, observations.FREEPROP_OBSERVABLES)) 209 | task = MTReach(task_id, arena=arena, arm=arm, hand=hand, prop=prop, 210 | obs_settings=obs_settings, 211 | workspace=workspace, 212 | control_timestep=constants.CONTROL_TIMESTEP) 213 | return task -------------------------------------------------------------------------------- /pretrain.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.filterwarnings('ignore', category=DeprecationWarning) 4 | 5 | import os 6 | 7 | os.environ['MKL_SERVICE_FORCE_INTEL'] = '1' 8 | os.environ['MUJOCO_GL'] = 'egl' 9 | 10 | from pathlib import Path 11 | 12 | import hydra 13 | import numpy as np 14 | import torch 15 | import wandb 16 | from dm_env import specs 17 | 18 | import dmc 19 | import utils 20 | from logger import Logger 21 | from replay_buffer import ReplayBufferStorage, make_replay_loader 22 | from video import TrainVideoRecorder, VideoRecorder 23 | 24 | torch.backends.cudnn.benchmark = True 25 | 26 | from dmc_benchmark import PRIMAL_TASKS 27 | 28 | 29 | def make_agent(obs_type, obs_spec, action_spec, num_expl_steps, cfg): 30 | cfg.obs_type = obs_type 31 | cfg.obs_shape = obs_spec.shape 32 | cfg.action_shape = action_spec.shape 33 | cfg.num_expl_steps = num_expl_steps 34 | return hydra.utils.instantiate(cfg) 35 | 36 | 37 | class Workspace: 38 | def __init__(self, cfg): 39 | self.work_dir = Path.cwd() 40 | print(f'workspace: {self.work_dir}') 41 | 42 | self.cfg = cfg 43 | utils.set_seed_everywhere(cfg.seed) 44 | self.device = torch.device(cfg.device) 45 | 46 | # create logger 47 | if cfg.use_wandb: 48 | exp_name = '_'.join([ 49 | cfg.experiment, cfg.agent.name, cfg.domain, cfg.obs_type, 50 | str(cfg.seed) 51 | ]) 52 | wandb.init(project="urlb", group=cfg.agent.name, name=exp_name) 53 | 54 | self.logger = Logger(self.work_dir, 55 | use_tb=cfg.use_tb, 56 | use_wandb=cfg.use_wandb) 57 | # create envs 58 | task = PRIMAL_TASKS[self.cfg.domain] 59 | self.train_env = dmc.make(task, cfg.obs_type, cfg.frame_stack, 60 | cfg.action_repeat, cfg.seed) 61 | self.eval_env = dmc.make(task, cfg.obs_type, cfg.frame_stack, 62 | cfg.action_repeat, cfg.seed) 63 | 64 | # create agent 65 | self.agent = make_agent(cfg.obs_type, 66 | self.train_env.observation_spec(), 67 | self.train_env.action_spec(), 68 | cfg.num_seed_frames // cfg.action_repeat, 69 | cfg.agent) 70 | 71 | # get meta specs 72 | meta_specs = self.agent.get_meta_specs() 73 | # create replay buffer 74 | data_specs = (self.train_env.observation_spec(), 75 | self.train_env.action_spec(), 76 | specs.Array((1,), np.float32, 'reward'), 77 | specs.Array((1,), np.float32, 'discount')) 78 | 79 | # create data storage 80 | self.replay_storage = ReplayBufferStorage(data_specs, meta_specs, 81 | self.work_dir / 'buffer') 82 | 83 | # create replay buffer 84 | self.replay_loader = make_replay_loader(self.replay_storage, 85 | cfg.replay_buffer_size, 86 | cfg.batch_size, 87 | cfg.replay_buffer_num_workers, 88 | False, cfg.nstep, cfg.discount) 89 | self._replay_iter = None 90 | 91 | # create video recorders 92 | self.video_recorder = VideoRecorder( 93 | self.work_dir if cfg.save_video else None, 94 | camera_id=0 
if 'quadruped' not in self.cfg.domain else 2, 95 | use_wandb=self.cfg.use_wandb) 96 | self.train_video_recorder = TrainVideoRecorder( 97 | self.work_dir if cfg.save_train_video else None, 98 | camera_id=0 if 'quadruped' not in self.cfg.domain else 2, 99 | use_wandb=self.cfg.use_wandb) 100 | 101 | self.timer = utils.Timer() 102 | self._global_step = 0 103 | self._global_episode = 0 104 | 105 | @property 106 | def global_step(self): 107 | return self._global_step 108 | 109 | @property 110 | def global_episode(self): 111 | return self._global_episode 112 | 113 | @property 114 | def global_frame(self): 115 | return self.global_step * self.cfg.action_repeat 116 | 117 | @property 118 | def replay_iter(self): 119 | if self._replay_iter is None: 120 | self._replay_iter = iter(self.replay_loader) 121 | return self._replay_iter 122 | 123 | def eval(self): 124 | step, episode, total_reward = 0, 0, 0 125 | eval_until_episode = utils.Until(self.cfg.num_eval_episodes) 126 | meta = self.agent.init_meta() 127 | while eval_until_episode(episode): 128 | time_step = self.eval_env.reset() 129 | self.video_recorder.init(self.eval_env, enabled=(episode == 0)) 130 | while not time_step.last(): 131 | with torch.no_grad(), utils.eval_mode(self.agent): 132 | action = self.agent.act(time_step.observation, 133 | meta, 134 | self.global_step, 135 | eval_mode=True) 136 | time_step = self.eval_env.step(action) 137 | self.video_recorder.record(self.eval_env) 138 | total_reward += time_step.reward 139 | step += 1 140 | 141 | episode += 1 142 | self.video_recorder.save(f'{self.global_frame}.mp4') 143 | 144 | with self.logger.log_and_dump_ctx(self.global_frame, ty='eval') as log: 145 | log('episode_reward', total_reward / episode) 146 | log('episode_length', step * self.cfg.action_repeat / episode) 147 | log('episode', self.global_episode) 148 | log('step', self.global_step) 149 | 150 | def train(self): 151 | # predicates 152 | train_until_step = utils.Until(self.cfg.num_train_frames, 153 | self.cfg.action_repeat) 154 | seed_until_step = utils.Until(self.cfg.num_seed_frames, 155 | self.cfg.action_repeat) 156 | eval_every_step = utils.Every(self.cfg.eval_every_frames, 157 | self.cfg.action_repeat) 158 | 159 | episode_step, episode_reward = 0, 0 160 | time_step = self.train_env.reset() 161 | meta = self.agent.init_meta() 162 | self.replay_storage.add(time_step, meta) 163 | self.train_video_recorder.init(time_step.observation) 164 | metrics = None 165 | while train_until_step(self.global_step): 166 | if time_step.last(): 167 | self._global_episode += 1 168 | self.train_video_recorder.save(f'{self.global_frame}.mp4') 169 | # wait until all the metrics schema is populated 170 | if metrics is not None: 171 | # log stats 172 | elapsed_time, total_time = self.timer.reset() 173 | episode_frame = episode_step * self.cfg.action_repeat 174 | with self.logger.log_and_dump_ctx(self.global_frame, 175 | ty='train') as log: 176 | log('fps', episode_frame / elapsed_time) 177 | log('total_time', total_time) 178 | log('episode_reward', episode_reward) 179 | log('episode_length', episode_frame) 180 | log('episode', self.global_episode) 181 | log('buffer_size', len(self.replay_storage)) 182 | log('step', self.global_step) 183 | 184 | # reset env 185 | time_step = self.train_env.reset() 186 | meta = self.agent.init_meta() 187 | self.replay_storage.add(time_step, meta) 188 | self.train_video_recorder.init(time_step.observation) 189 | # try to save snapshot 190 | episode_step = 0 191 | episode_reward = 0 192 | if self.global_frame in 
self.cfg.snapshots: 193 | self.save_snapshot() 194 | # try to evaluate 195 | if eval_every_step(self.global_step): 196 | self.logger.log('eval_total_time', self.timer.total_time(), 197 | self.global_frame) 198 | self.eval() 199 | 200 | meta = self.agent.update_meta(meta, self.global_step, time_step) 201 | # sample action 202 | with torch.no_grad(), utils.eval_mode(self.agent): 203 | action = self.agent.act(time_step.observation, 204 | meta, 205 | self.global_step, 206 | eval_mode=False) 207 | 208 | # try to update the agent 209 | if not seed_until_step(self.global_step): 210 | metrics = self.agent.update(self.replay_iter, self.global_step) 211 | self.logger.log_metrics(metrics, self.global_frame, ty='train') 212 | 213 | # take env step 214 | time_step = self.train_env.step(action) 215 | episode_reward += time_step.reward 216 | self.replay_storage.add(time_step, meta) 217 | self.train_video_recorder.record(time_step.observation) 218 | episode_step += 1 219 | self._global_step += 1 220 | 221 | def save_snapshot(self): 222 | snapshot_dir = self.work_dir / Path(self.cfg.snapshot_dir) 223 | snapshot_dir.mkdir(exist_ok=True, parents=True) 224 | snapshot = snapshot_dir / f'snapshot_{self.global_frame}.pt' 225 | keys_to_save = ['agent', '_global_step', '_global_episode'] 226 | payload = {k: self.__dict__[k] for k in keys_to_save} 227 | with snapshot.open('wb') as f: 228 | torch.save(payload, f) 229 | 230 | 231 | @hydra.main(config_path='.', config_name='pretrain') 232 | def main(cfg): 233 | from pretrain import Workspace as W 234 | root_dir = Path.cwd() 235 | workspace = W(cfg) 236 | snapshot = root_dir / 'snapshot.pt' 237 | if snapshot.exists(): 238 | print(f'resuming: {snapshot}') 239 | workspace.load_snapshot() 240 | workspace.train() 241 | 242 | 243 | if __name__ == '__main__': 244 | main() 245 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.filterwarnings('ignore', category=DeprecationWarning) 4 | 5 | import os 6 | 7 | os.environ['MKL_SERVICE_FORCE_INTEL'] = '1' 8 | os.environ['MUJOCO_GL'] = 'egl' 9 | 10 | from pathlib import Path 11 | 12 | import hydra 13 | import numpy as np 14 | import torch 15 | from dm_env import specs 16 | 17 | import dmc 18 | import utils 19 | from logger import Logger 20 | from replay_buffer import ReplayBufferStorage, make_replay_loader 21 | from video import TrainVideoRecorder, VideoRecorder 22 | 23 | torch.backends.cudnn.benchmark = True 24 | 25 | import logging 26 | 27 | def make_agent(obs_type, obs_spec, action_spec, num_expl_steps, cfg): 28 | cfg.obs_type = obs_type 29 | cfg.obs_shape = obs_spec.shape 30 | cfg.action_shape = action_spec.shape 31 | cfg.num_expl_steps = num_expl_steps 32 | return hydra.utils.instantiate(cfg) 33 | 34 | 35 | class Workspace: 36 | def __init__(self, cfg): 37 | self.work_dir = Path.cwd() 38 | print(f'workspace: {self.work_dir}') 39 | 40 | self.cfg = cfg 41 | utils.set_seed_everywhere(cfg.seed) 42 | self.device = torch.device(cfg.device) 43 | 44 | # create logger 45 | self.logger = Logger(self.work_dir, 46 | use_tb=cfg.use_tb, 47 | use_wandb=cfg.use_wandb) 48 | # create envs 49 | self.train_env = dmc.make(cfg.task, cfg.obs_type, cfg.frame_stack, 50 | cfg.action_repeat, cfg.seed) 51 | self.eval_env = dmc.make(cfg.task, cfg.obs_type, cfg.frame_stack, 52 | cfg.action_repeat, cfg.seed) 53 | 54 | # create agent 55 | self.agent = make_agent(cfg.obs_type, 56 | 
self.train_env.observation_spec(), 57 | self.train_env.action_spec(), 58 | cfg.num_seed_frames // cfg.action_repeat, 59 | cfg.agent) 60 | 61 | # initialize from pretrained 62 | if cfg.snapshot_ts > 0: 63 | pretrained_agent = self.load_snapshot()['agent'] 64 | self.agent.init_from(pretrained_agent) 65 | 66 | # get meta specs 67 | meta_specs = self.agent.get_meta_specs() 68 | # create replay buffer 69 | data_specs = (self.train_env.observation_spec(), 70 | self.train_env.action_spec(), 71 | specs.Array((1,), np.float32, 'reward'), 72 | specs.Array((1,), np.float32, 'discount')) 73 | 74 | # create data storage 75 | self.replay_storage = ReplayBufferStorage(data_specs, meta_specs, 76 | self.work_dir / 'buffer') 77 | 78 | # create replay buffer 79 | self.replay_loader = make_replay_loader(self.replay_storage, 80 | cfg.replay_buffer_size, 81 | cfg.batch_size, 82 | cfg.replay_buffer_num_workers, 83 | False, cfg.nstep, cfg.discount) 84 | self._replay_iter = None 85 | 86 | # create video recorders 87 | self.video_recorder = VideoRecorder( 88 | self.work_dir if cfg.save_video else None) 89 | self.train_video_recorder = TrainVideoRecorder( 90 | self.work_dir if cfg.save_train_video else None) 91 | 92 | self.timer = utils.Timer() 93 | self._global_step = 0 94 | self._global_episode = 0 95 | 96 | @property 97 | def global_step(self): 98 | return self._global_step 99 | 100 | @property 101 | def global_episode(self): 102 | return self._global_episode 103 | 104 | @property 105 | def global_frame(self): 106 | return self.global_step * self.cfg.action_repeat 107 | 108 | @property 109 | def replay_iter(self): 110 | if self._replay_iter is None: 111 | self._replay_iter = iter(self.replay_loader) 112 | return self._replay_iter 113 | 114 | def eval(self): 115 | step, episode, total_reward = 0, 0, 0 116 | eval_until_episode = utils.Until(self.cfg.num_eval_episodes) 117 | meta = self.agent.init_meta() 118 | while eval_until_episode(episode): 119 | time_step = self.eval_env.reset() 120 | self.video_recorder.init(self.eval_env, enabled=(episode == 0)) 121 | while not time_step.last(): 122 | with torch.no_grad(), utils.eval_mode(self.agent): 123 | action = self.agent.act(time_step.observation, 124 | meta, 125 | self.global_step, 126 | eval_mode=True) 127 | time_step = self.eval_env.step(action) 128 | self.video_recorder.record(self.eval_env) 129 | total_reward += time_step.reward 130 | step += 1 131 | 132 | episode += 1 133 | self.video_recorder.save(f'{self.global_frame}.mp4') 134 | 135 | with self.logger.log_and_dump_ctx(self.global_frame, ty='eval') as log: 136 | log('episode_reward', total_reward / episode) 137 | log('episode_length', step * self.cfg.action_repeat / episode) 138 | log('episode', self.global_episode) 139 | log('step', self.global_step) 140 | if 'skill' in meta.keys(): 141 | log('skill', meta['skill'].argmax()) 142 | 143 | def train(self): 144 | # predicates 145 | train_until_step = utils.Until(self.cfg.num_train_frames, 146 | self.cfg.action_repeat) 147 | seed_until_step = utils.Until(self.cfg.num_seed_frames, 148 | self.cfg.action_repeat) 149 | eval_every_step = utils.Every(self.cfg.eval_every_frames, 150 | self.cfg.action_repeat) 151 | 152 | episode_step, episode_reward = 0, 0 153 | time_step = self.train_env.reset() 154 | meta = self.agent.init_meta() 155 | self.replay_storage.add(time_step, meta) 156 | self.train_video_recorder.init(time_step.observation) 157 | metrics = None 158 | while train_until_step(self.global_step): 159 | if time_step.last(): 160 | self._global_episode += 1 161 | 
self.train_video_recorder.save(f'{self.global_frame}.mp4') 162 | # wait until all the metrics schema is populated 163 | if metrics is not None: 164 | # log stats 165 | elapsed_time, total_time = self.timer.reset() 166 | episode_frame = episode_step * self.cfg.action_repeat 167 | with self.logger.log_and_dump_ctx(self.global_frame, 168 | ty='train') as log: 169 | log('fps', episode_frame / elapsed_time) 170 | log('total_time', total_time) 171 | log('episode_reward', episode_reward) 172 | log('episode_length', episode_frame) 173 | log('episode', self.global_episode) 174 | log('buffer_size', len(self.replay_storage)) 175 | log('step', self.global_step) 176 | 177 | # reset env 178 | time_step = self.train_env.reset() 179 | meta = self.agent.init_meta() 180 | self.replay_storage.add(time_step, meta) 181 | self.train_video_recorder.init(time_step.observation) 182 | 183 | episode_step = 0 184 | episode_reward = 0 185 | 186 | # try to evaluate 187 | if eval_every_step(self.global_step): 188 | self.logger.log('eval_total_time', self.timer.total_time(), 189 | self.global_frame) 190 | self.eval() 191 | 192 | meta = self.agent.update_meta(meta, self.global_step, time_step) 193 | 194 | if hasattr(self.agent, "regress_meta"): 195 | repeat = self.cfg.action_repeat 196 | every = self.agent.update_task_every_step // repeat 197 | init_step = self.agent.num_init_steps 198 | if self.global_step > ( 199 | init_step // repeat) and self.global_step % every == 0: 200 | meta = self.agent.regress_meta(self.replay_iter, 201 | self.global_step) 202 | 203 | # sample action 204 | with torch.no_grad(), utils.eval_mode(self.agent): 205 | action = self.agent.act(time_step.observation, 206 | meta, 207 | self.global_step, 208 | eval_mode=False) 209 | 210 | # try to update the agent 211 | if not seed_until_step(self.global_step): 212 | metrics = self.agent.update(self.replay_iter, self.global_step) 213 | self.logger.log_metrics(metrics, self.global_frame, ty='train') 214 | 215 | # take env step 216 | time_step = self.train_env.step(action) 217 | episode_reward += time_step.reward 218 | self.replay_storage.add(time_step, meta) 219 | self.train_video_recorder.record(time_step.observation) 220 | episode_step += 1 221 | self._global_step += 1 222 | 223 | def load_snapshot(self): 224 | snapshot_base_dir = Path(self.cfg.snapshot_base_dir) 225 | domain, _ = self.cfg.task.split('_', 1) 226 | snapshot_dir = snapshot_base_dir / self.cfg.obs_type / domain / self.cfg.agent.name 227 | 228 | def try_load(seed): 229 | snapshot = '../../../../../' / snapshot_dir / str( 230 | seed) / f'snapshot_{self.cfg.snapshot_ts}.pt' 231 | logging.info("loading model: {}, cwd is {}".format(str(snapshot), str(Path.cwd()))) 232 | if not snapshot.exists(): 233 | logging.error("no such pretrained model: {}".format(snapshot)) 234 | return None 235 | with snapshot.open('rb') as f: 236 | payload = torch.load(f, map_location='cuda:0') 237 | return payload 238 | 239 | # try to load current seed 240 | payload = try_load(self.cfg.seed) 241 | assert payload is not None 242 | 243 | return payload 244 | 245 | 246 | @hydra.main(config_path='.', config_name='finetune') 247 | def main(cfg): 248 | from finetune import Workspace as W 249 | root_dir = Path.cwd() 250 | logging.basicConfig(level=logging.DEBUG) 251 | workspace = W(cfg) 252 | snapshot = root_dir / 'snapshot.pt' 253 | if snapshot.exists(): 254 | print(f'resuming: {snapshot}') 255 | workspace.load_snapshot() 256 | workspace.train() 257 | 258 | 259 | if __name__ == '__main__': 260 | main() 261 |
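As a quick reference, the sketch below (an editorial addition, not part of the repository) reconstructs the snapshot path that `load_snapshot` above expects. The argument names mirror the config keys referenced in the code; the example values in the comment are hypothetical, and note that `try_load` additionally prefixes `'../../../../../'` so the path is resolved relative to the Hydra run directory.

from pathlib import Path

def expected_snapshot_path(snapshot_base_dir, obs_type, task, agent_name, seed, snapshot_ts):
    # Mirrors load_snapshot(): <base>/<obs_type>/<domain>/<agent>/<seed>/snapshot_<ts>.pt
    domain, _ = task.split('_', 1)
    return (Path(snapshot_base_dir) / obs_type / domain / agent_name /
            str(seed) / f'snapshot_{snapshot_ts}.pt')

# Hypothetical example:
# expected_snapshot_path('pretrained_models', 'states', 'walker_run', 'becl', 1, 2000000)
# -> pretrained_models/states/walker/becl/1/snapshot_2000000.pt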
-------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import re 4 | import time 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from omegaconf import OmegaConf 11 | from torch import distributions as pyd 12 | from torch.distributions.utils import _standard_normal 13 | 14 | 15 | class eval_mode: 16 | def __init__(self, *models): 17 | self.models = models 18 | 19 | def __enter__(self): 20 | self.prev_states = [] 21 | for model in self.models: 22 | self.prev_states.append(model.training) 23 | model.train(False) 24 | 25 | def __exit__(self, *args): 26 | for model, state in zip(self.models, self.prev_states): 27 | model.train(state) 28 | return False 29 | 30 | 31 | def set_seed_everywhere(seed): 32 | torch.manual_seed(seed) 33 | if torch.cuda.is_available(): 34 | torch.cuda.manual_seed_all(seed) 35 | np.random.seed(seed) 36 | random.seed(seed) 37 | 38 | 39 | def chain(*iterables): 40 | for it in iterables: 41 | yield from it 42 | 43 | 44 | def soft_update_params(net, target_net, tau): 45 | for param, target_param in zip(net.parameters(), target_net.parameters()): 46 | target_param.data.copy_(tau * param.data + 47 | (1 - tau) * target_param.data) 48 | 49 | 50 | def hard_update_params(net, target_net): 51 | for param, target_param in zip(net.parameters(), target_net.parameters()): 52 | target_param.data.copy_(param.data) 53 | 54 | 55 | def to_torch(xs, device): 56 | return tuple(torch.as_tensor(x, device=device) for x in xs) 57 | 58 | 59 | def weight_init(m): 60 | """Custom weight init for Conv2D and Linear layers.""" 61 | if isinstance(m, nn.Linear): 62 | nn.init.orthogonal_(m.weight.data) 63 | if hasattr(m.bias, 'data'): 64 | m.bias.data.fill_(0.0) 65 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 66 | gain = nn.init.calculate_gain('relu') 67 | nn.init.orthogonal_(m.weight.data, gain) 68 | if hasattr(m.bias, 'data'): 69 | m.bias.data.fill_(0.0) 70 | 71 | 72 | def grad_norm(params, norm_type=2.0): 73 | params = [p for p in params if p.grad is not None] 74 | total_norm = torch.norm( 75 | torch.stack([torch.norm(p.grad.detach(), norm_type) for p in params]), 76 | norm_type) 77 | return total_norm.item() 78 | 79 | 80 | def param_norm(params, norm_type=2.0): 81 | total_norm = torch.norm( 82 | torch.stack([torch.norm(p.detach(), norm_type) for p in params]), 83 | norm_type) 84 | return total_norm.item() 85 | 86 | 87 | class Until: 88 | def __init__(self, until, action_repeat=1): 89 | self._until = until 90 | self._action_repeat = action_repeat 91 | 92 | def __call__(self, step): 93 | if self._until is None: 94 | return True 95 | until = self._until // self._action_repeat 96 | return step < until 97 | 98 | 99 | class Every: 100 | def __init__(self, every, action_repeat=1): 101 | self._every = every 102 | self._action_repeat = action_repeat 103 | 104 | def __call__(self, step): 105 | if self._every is None: 106 | return False 107 | every = self._every // self._action_repeat 108 | if step % every == 0: 109 | return True 110 | return False 111 | 112 | 113 | class Timer: 114 | def __init__(self): 115 | self._start_time = time.time() 116 | self._last_time = time.time() 117 | 118 | def reset(self): 119 | elapsed_time = time.time() - self._last_time 120 | self._last_time = time.time() 121 | total_time = time.time() - self._start_time 122 | return elapsed_time, 
total_time 123 | 124 | def total_time(self): 125 | return time.time() - self._start_time 126 | 127 | 128 | class TruncatedNormal(pyd.Normal): 129 | def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6): 130 | super().__init__(loc, scale, validate_args=False) 131 | self.low = low 132 | self.high = high 133 | self.eps = eps 134 | 135 | def _clamp(self, x): 136 | clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps) 137 | x = x - x.detach() + clamped_x.detach() 138 | return x 139 | 140 | def sample(self, clip=None, sample_shape=torch.Size()): 141 | shape = self._extended_shape(sample_shape) 142 | eps = _standard_normal(shape, 143 | dtype=self.loc.dtype, 144 | device=self.loc.device) 145 | eps *= self.scale 146 | if clip is not None: 147 | eps = torch.clamp(eps, -clip, clip) 148 | x = self.loc + eps 149 | return self._clamp(x) 150 | 151 | 152 | class TanhTransform(pyd.transforms.Transform): 153 | domain = pyd.constraints.real 154 | codomain = pyd.constraints.interval(-1.0, 1.0) 155 | bijective = True 156 | sign = +1 157 | 158 | def __init__(self, cache_size=1): 159 | super().__init__(cache_size=cache_size) 160 | 161 | @staticmethod 162 | def atanh(x): 163 | return 0.5 * (x.log1p() - (-x).log1p()) 164 | 165 | def __eq__(self, other): 166 | return isinstance(other, TanhTransform) 167 | 168 | def _call(self, x): 169 | return x.tanh() 170 | 171 | def _inverse(self, y): 172 | # We do not clamp to the boundary here as it may degrade the performance of certain algorithms. 173 | # one should use `cache_size=1` instead 174 | return self.atanh(y) 175 | 176 | def log_abs_det_jacobian(self, x, y): 177 | # We use a formula that is more numerically stable, see details in the following link 178 | # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7 179 | return 2. * (math.log(2.) - x - F.softplus(-2. 
* x)) 180 | 181 | 182 | class SquashedNormal(pyd.transformed_distribution.TransformedDistribution): 183 | def __init__(self, loc, scale): 184 | self.loc = loc 185 | self.scale = scale 186 | 187 | self.base_dist = pyd.Normal(loc, scale) 188 | transforms = [TanhTransform()] 189 | super().__init__(self.base_dist, transforms) 190 | 191 | @property 192 | def mean(self): 193 | mu = self.loc 194 | for tr in self.transforms: 195 | mu = tr(mu) 196 | return mu 197 | 198 | 199 | def schedule(schdl, step): 200 | try: 201 | return float(schdl) 202 | except ValueError: 203 | match = re.match(r'linear\((.+),(.+),(.+)\)', schdl) 204 | if match: 205 | init, final, duration = [float(g) for g in match.groups()] 206 | mix = np.clip(step / duration, 0.0, 1.0) 207 | return (1.0 - mix) * init + mix * final 208 | match = re.match(r'step_linear\((.+),(.+),(.+),(.+),(.+)\)', schdl) 209 | if match: 210 | init, final1, duration1, final2, duration2 = [ 211 | float(g) for g in match.groups() 212 | ] 213 | if step <= duration1: 214 | mix = np.clip(step / duration1, 0.0, 1.0) 215 | return (1.0 - mix) * init + mix * final1 216 | else: 217 | mix = np.clip((step - duration1) / duration2, 0.0, 1.0) 218 | return (1.0 - mix) * final1 + mix * final2 219 | raise NotImplementedError(schdl) 220 | 221 | 222 | class RandomShiftsAug(nn.Module): 223 | def __init__(self, pad): 224 | super().__init__() 225 | self.pad = pad 226 | 227 | def forward(self, x): 228 | x = x.float() 229 | n, c, h, w = x.size() 230 | assert h == w 231 | padding = tuple([self.pad] * 4) 232 | x = F.pad(x, padding, 'replicate') 233 | eps = 1.0 / (h + 2 * self.pad) 234 | arange = torch.linspace(-1.0 + eps, 235 | 1.0 - eps, 236 | h + 2 * self.pad, 237 | device=x.device, 238 | dtype=x.dtype)[:h] 239 | arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2) 240 | base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2) 241 | base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1) 242 | 243 | shift = torch.randint(0, 244 | 2 * self.pad + 1, 245 | size=(n, 1, 1, 2), 246 | device=x.device, 247 | dtype=x.dtype) 248 | shift *= 2.0 / (h + 2 * self.pad) 249 | 250 | grid = base_grid + shift 251 | return F.grid_sample(x, 252 | grid, 253 | padding_mode='zeros', 254 | align_corners=False) 255 | 256 | 257 | class RMS(object): 258 | """running mean and std """ 259 | def __init__(self, device, epsilon=1e-4, shape=(1,)): 260 | self.M = torch.zeros(shape).to(device) 261 | self.S = torch.ones(shape).to(device) 262 | self.n = epsilon 263 | 264 | def __call__(self, x): 265 | bs = x.size(0) 266 | delta = torch.mean(x, dim=0) - self.M 267 | new_M = self.M + delta * bs / (self.n + bs) 268 | new_S = (self.S * self.n + torch.var(x, dim=0) * bs + 269 | torch.square(delta) * self.n * bs / 270 | (self.n + bs)) / (self.n + bs) 271 | 272 | self.M = new_M 273 | self.S = new_S 274 | self.n += bs 275 | 276 | return self.M, self.S 277 | 278 | 279 | class PBE(object): 280 | """particle-based entropy based on knn normalized by running mean """ 281 | def __init__(self, rms, knn_clip, knn_k, knn_avg, knn_rms, device): 282 | self.rms = rms 283 | self.knn_rms = knn_rms 284 | self.knn_k = knn_k 285 | self.knn_avg = knn_avg 286 | self.knn_clip = knn_clip 287 | self.device = device 288 | 289 | def __call__(self, rep): 290 | source = target = rep 291 | b1, b2 = source.size(0), target.size(0) 292 | # (b1, 1, c) - (1, b2, c) -> (b1, 1, c) - (1, b2, c) -> (b1, b2, c) -> (b1, b2) 293 | sim_matrix = torch.norm(source[:, None, :].view(b1, 1, -1) - 294 | target[None, :, :].view(1, b2, -1), 295 | dim=-1, 296 | p=2) 
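# Note on the lines around this point: sim_matrix[i, j] above is the pairwise L2
# distance between the i-th and j-th encoded states in the batch. The
# topk(..., largest=False) call below keeps each state's k nearest neighbours;
# their (optionally running-mean-normalised and clipped) distances are passed
# through log(reward + 1.0) to give the particle-based k-NN entropy reward
# described in the class docstring.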
297 | reward, _ = sim_matrix.topk(self.knn_k, 298 | dim=1, 299 | largest=False, 300 | sorted=True) # (b1, k) 301 | if not self.knn_avg: # only keep k-th nearest neighbor 302 | reward = reward[:, -1] 303 | reward = reward.reshape(-1, 1) # (b1, 1) 304 | reward /= self.rms(reward)[0] if self.knn_rms else 1.0 305 | reward = torch.maximum( 306 | reward - self.knn_clip, 307 | torch.zeros_like(reward).to(self.device) 308 | ) if self.knn_clip >= 0.0 else reward # (b1, 1) 309 | else: # average over all k nearest neighbors 310 | reward = reward.reshape(-1, 1) # (b1 * k, 1) 311 | reward /= self.rms(reward)[0] if self.knn_rms else 1.0 312 | reward = torch.maximum( 313 | reward - self.knn_clip, 314 | torch.zeros_like(reward).to( 315 | self.device)) if self.knn_clip >= 0.0 else reward 316 | reward = reward.reshape((b1, self.knn_k)) # (b1, k) 317 | reward = reward.mean(dim=1, keepdim=True) # (b1, 1) 318 | reward = torch.log(reward + 1.0) 319 | return reward 320 | -------------------------------------------------------------------------------- /agent/ddpg.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import hydra 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | import utils 10 | 11 | 12 | class Encoder(nn.Module): 13 | def __init__(self, obs_shape): 14 | super().__init__() 15 | 16 | assert len(obs_shape) == 3 17 | self.repr_dim = 32 * 35 * 35 18 | 19 | self.convnet = nn.Sequential(nn.Conv2d(obs_shape[0], 32, 3, stride=2), 20 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1), 21 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1), 22 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1), 23 | nn.ReLU()) 24 | 25 | self.apply(utils.weight_init) 26 | 27 | def forward(self, obs): 28 | obs = obs / 255.0 - 0.5 29 | h = self.convnet(obs) 30 | h = h.view(h.shape[0], -1) 31 | return h 32 | 33 | 34 | class Actor(nn.Module): 35 | def __init__(self, obs_type, obs_dim, action_dim, feature_dim, hidden_dim): 36 | super().__init__() 37 | 38 | feature_dim = feature_dim if obs_type == 'pixels' else hidden_dim 39 | 40 | self.trunk = nn.Sequential(nn.Linear(obs_dim, feature_dim), 41 | nn.LayerNorm(feature_dim), nn.Tanh()) 42 | 43 | policy_layers = [] 44 | policy_layers += [ 45 | nn.Linear(feature_dim, hidden_dim), 46 | nn.ReLU(inplace=True) 47 | ] 48 | # add additional hidden layer for pixels 49 | if obs_type == 'pixels': 50 | policy_layers += [ 51 | nn.Linear(hidden_dim, hidden_dim), 52 | nn.ReLU(inplace=True) 53 | ] 54 | policy_layers += [nn.Linear(hidden_dim, action_dim)] 55 | 56 | self.policy = nn.Sequential(*policy_layers) 57 | 58 | self.apply(utils.weight_init) 59 | 60 | def forward(self, obs, std): 61 | h = self.trunk(obs) 62 | 63 | mu = self.policy(h) 64 | mu = torch.tanh(mu) 65 | std = torch.ones_like(mu) * std 66 | 67 | dist = utils.TruncatedNormal(mu, std) 68 | return dist 69 | 70 | 71 | class Critic(nn.Module): 72 | def __init__(self, obs_type, obs_dim, action_dim, feature_dim, hidden_dim): 73 | super().__init__() 74 | 75 | self.obs_type = obs_type 76 | 77 | if obs_type == 'pixels': 78 | # for pixels actions will be added after trunk 79 | self.trunk = nn.Sequential(nn.Linear(obs_dim, feature_dim), 80 | nn.LayerNorm(feature_dim), nn.Tanh()) 81 | trunk_dim = feature_dim + action_dim 82 | else: 83 | # for states actions come in the beginning 84 | self.trunk = nn.Sequential( 85 | nn.Linear(obs_dim + action_dim, hidden_dim), 86 | nn.LayerNorm(hidden_dim), nn.Tanh()) 87 | trunk_dim = hidden_dim 88 
| 89 | def make_q(): 90 | q_layers = [] 91 | q_layers += [ 92 | nn.Linear(trunk_dim, hidden_dim), 93 | nn.ReLU(inplace=True) 94 | ] 95 | if obs_type == 'pixels': 96 | q_layers += [ 97 | nn.Linear(hidden_dim, hidden_dim), 98 | nn.ReLU(inplace=True) 99 | ] 100 | q_layers += [nn.Linear(hidden_dim, 1)] 101 | return nn.Sequential(*q_layers) 102 | 103 | self.Q1 = make_q() 104 | self.Q2 = make_q() 105 | 106 | self.apply(utils.weight_init) 107 | 108 | def forward(self, obs, action): 109 | inpt = obs if self.obs_type == 'pixels' else torch.cat([obs, action], 110 | dim=-1) 111 | h = self.trunk(inpt) 112 | h = torch.cat([h, action], dim=-1) if self.obs_type == 'pixels' else h 113 | 114 | q1 = self.Q1(h) 115 | q2 = self.Q2(h) 116 | 117 | return q1, q2 118 | 119 | 120 | class DDPGAgent: 121 | def __init__(self, 122 | name, 123 | reward_free, 124 | obs_type, 125 | obs_shape, 126 | action_shape, 127 | device, 128 | lr, 129 | feature_dim, 130 | hidden_dim, 131 | critic_target_tau, 132 | num_expl_steps, 133 | update_every_steps, 134 | stddev_schedule, 135 | nstep, 136 | batch_size, 137 | stddev_clip, 138 | init_critic, 139 | use_tb, 140 | use_wandb, 141 | meta_dim=0): 142 | self.reward_free = reward_free 143 | self.obs_type = obs_type 144 | self.obs_shape = obs_shape 145 | self.action_dim = action_shape[0] 146 | self.hidden_dim = hidden_dim 147 | self.lr = lr 148 | self.device = device 149 | self.critic_target_tau = critic_target_tau 150 | self.update_every_steps = update_every_steps 151 | self.use_tb = use_tb 152 | self.use_wandb = use_wandb 153 | self.num_expl_steps = num_expl_steps 154 | self.stddev_schedule = stddev_schedule 155 | self.stddev_clip = stddev_clip 156 | self.init_critic = init_critic 157 | self.feature_dim = feature_dim 158 | self.solved_meta = None 159 | 160 | # models 161 | if obs_type == 'pixels': 162 | self.aug = utils.RandomShiftsAug(pad=4) 163 | self.encoder = Encoder(obs_shape).to(device) 164 | self.obs_dim = self.encoder.repr_dim + meta_dim 165 | else: 166 | self.aug = nn.Identity() 167 | self.encoder = nn.Identity() 168 | self.obs_dim = obs_shape[0] + meta_dim 169 | 170 | self.actor = Actor(obs_type, self.obs_dim, self.action_dim, 171 | feature_dim, hidden_dim).to(device) 172 | 173 | self.critic = Critic(obs_type, self.obs_dim, self.action_dim, 174 | feature_dim, hidden_dim).to(device) 175 | self.critic_target = Critic(obs_type, self.obs_dim, self.action_dim, 176 | feature_dim, hidden_dim).to(device) 177 | self.critic_target.load_state_dict(self.critic.state_dict()) 178 | 179 | # optimizers 180 | 181 | if obs_type == 'pixels': 182 | self.encoder_opt = torch.optim.Adam(self.encoder.parameters(), 183 | lr=lr) 184 | else: 185 | self.encoder_opt = None 186 | self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr) 187 | self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=lr) 188 | 189 | self.train() 190 | self.critic_target.train() 191 | 192 | def train(self, training=True): 193 | self.training = training 194 | self.encoder.train(training) 195 | self.actor.train(training) 196 | self.critic.train(training) 197 | 198 | def init_from(self, other): 199 | # copy parameters over 200 | utils.hard_update_params(other.encoder, self.encoder) 201 | utils.hard_update_params(other.actor, self.actor) 202 | if self.init_critic: 203 | utils.hard_update_params(other.critic.trunk, self.critic.trunk) 204 | 205 | def get_meta_specs(self): 206 | return tuple() 207 | 208 | def init_meta(self): 209 | return OrderedDict() 210 | 211 | def update_meta(self, meta, global_step, time_step, 
finetune=False): 212 | return meta 213 | 214 | def act(self, obs, meta, step, eval_mode): 215 | obs = torch.as_tensor(obs, device=self.device).unsqueeze(0) 216 | h = self.encoder(obs) 217 | inputs = [h] 218 | for value in meta.values(): 219 | value = torch.as_tensor(value, device=self.device).unsqueeze(0) 220 | inputs.append(value) 221 | inpt = torch.cat(inputs, dim=-1) 222 | #assert obs.shape[-1] == self.obs_shape[-1] 223 | stddev = utils.schedule(self.stddev_schedule, step) 224 | dist = self.actor(inpt, stddev) 225 | if eval_mode: 226 | action = dist.mean 227 | else: 228 | action = dist.sample(clip=None) 229 | if step < self.num_expl_steps: 230 | action.uniform_(-1.0, 1.0) 231 | return action.cpu().numpy()[0] 232 | 233 | def update_critic(self, obs, action, reward, discount, next_obs, step): 234 | metrics = dict() 235 | 236 | with torch.no_grad(): 237 | stddev = utils.schedule(self.stddev_schedule, step) 238 | dist = self.actor(next_obs, stddev) 239 | next_action = dist.sample(clip=self.stddev_clip) 240 | target_Q1, target_Q2 = self.critic_target(next_obs, next_action) 241 | target_V = torch.min(target_Q1, target_Q2) 242 | target_Q = reward + (discount * target_V) 243 | 244 | Q1, Q2 = self.critic(obs, action) 245 | critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q) 246 | 247 | if self.use_tb or self.use_wandb: 248 | metrics['critic_target_q'] = target_Q.mean().item() 249 | metrics['critic_q1'] = Q1.mean().item() 250 | metrics['critic_q2'] = Q2.mean().item() 251 | metrics['critic_loss'] = critic_loss.item() 252 | 253 | # optimize critic 254 | if self.encoder_opt is not None: 255 | self.encoder_opt.zero_grad(set_to_none=True) 256 | self.critic_opt.zero_grad(set_to_none=True) 257 | critic_loss.backward() 258 | self.critic_opt.step() 259 | if self.encoder_opt is not None: 260 | self.encoder_opt.step() 261 | return metrics 262 | 263 | def update_actor(self, obs, step): 264 | metrics = dict() 265 | 266 | stddev = utils.schedule(self.stddev_schedule, step) 267 | dist = self.actor(obs, stddev) 268 | action = dist.sample(clip=self.stddev_clip) 269 | log_prob = dist.log_prob(action).sum(-1, keepdim=True) 270 | Q1, Q2 = self.critic(obs, action) 271 | Q = torch.min(Q1, Q2) 272 | 273 | actor_loss = -Q.mean() 274 | 275 | # optimize actor 276 | self.actor_opt.zero_grad(set_to_none=True) 277 | actor_loss.backward() 278 | self.actor_opt.step() 279 | 280 | if self.use_tb or self.use_wandb: 281 | metrics['actor_loss'] = actor_loss.item() 282 | metrics['actor_logprob'] = log_prob.mean().item() 283 | metrics['actor_ent'] = dist.entropy().sum(dim=-1).mean().item() 284 | 285 | return metrics 286 | 287 | def aug_and_encode(self, obs): 288 | obs = self.aug(obs) 289 | return self.encoder(obs) 290 | 291 | def update(self, replay_iter, step): 292 | metrics = dict() 293 | #import ipdb; ipdb.set_trace() 294 | 295 | if step % self.update_every_steps != 0: 296 | return metrics 297 | 298 | batch = next(replay_iter) 299 | obs, action, reward, discount, next_obs = utils.to_torch( 300 | batch, self.device) 301 | 302 | # augment and encode 303 | obs = self.aug_and_encode(obs) 304 | with torch.no_grad(): 305 | next_obs = self.aug_and_encode(next_obs) 306 | 307 | if self.use_tb or self.use_wandb: 308 | metrics['batch_reward'] = reward.mean().item() 309 | 310 | # update critic 311 | metrics.update( 312 | self.update_critic(obs, action, reward, discount, next_obs, step)) 313 | 314 | # update actor 315 | metrics.update(self.update_actor(obs.detach(), step)) 316 | 317 | # update critic target 318 | 
utils.soft_update_params(self.critic, self.critic_target, 319 | self.critic_target_tau) 320 | 321 | return metrics 322 | -------------------------------------------------------------------------------- /dmc.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict, deque 2 | from typing import Any, NamedTuple 3 | 4 | import dm_env 5 | import numpy as np 6 | from dm_control import manipulation, suite 7 | from dm_control.suite.wrappers import action_scale, pixels 8 | from dm_env import StepType, specs 9 | 10 | import custom_dmc_tasks as cdmc 11 | 12 | 13 | class ExtendedTimeStep(NamedTuple): 14 | step_type: Any 15 | reward: Any 16 | discount: Any 17 | observation: Any 18 | action: Any 19 | 20 | def first(self): 21 | return self.step_type == StepType.FIRST 22 | 23 | def mid(self): 24 | return self.step_type == StepType.MID 25 | 26 | def last(self): 27 | return self.step_type == StepType.LAST 28 | 29 | def __getitem__(self, attr): 30 | return getattr(self, attr) 31 | 32 | 33 | class FlattenJacoObservationWrapper(dm_env.Environment): 34 | def __init__(self, env): 35 | self._env = env 36 | self._obs_spec = OrderedDict() 37 | wrapped_obs_spec = env.observation_spec().copy() 38 | if 'front_close' in wrapped_obs_spec: 39 | spec = wrapped_obs_spec['front_close'] 40 | # drop batch dim 41 | self._obs_spec['pixels'] = specs.BoundedArray(shape=spec.shape[1:], 42 | dtype=spec.dtype, 43 | minimum=spec.minimum, 44 | maximum=spec.maximum, 45 | name='pixels') 46 | wrapped_obs_spec.pop('front_close') 47 | 48 | for key, spec in wrapped_obs_spec.items(): 49 | assert spec.dtype == np.float64 50 | assert type(spec) == specs.Array 51 | dim = np.sum( 52 | np.fromiter((np.int(np.prod(spec.shape)) 53 | for spec in wrapped_obs_spec.values()), np.int32)) 54 | 55 | self._obs_spec['observations'] = specs.Array(shape=(dim,), 56 | dtype=np.float32, 57 | name='observations') 58 | 59 | def _transform_observation(self, time_step): 60 | obs = OrderedDict() 61 | 62 | if 'front_close' in time_step.observation: 63 | pixels = time_step.observation['front_close'] 64 | time_step.observation.pop('front_close') 65 | pixels = np.squeeze(pixels) 66 | obs['pixels'] = pixels 67 | 68 | features = [] 69 | for feature in time_step.observation.values(): 70 | features.append(feature.ravel()) 71 | obs['observations'] = np.concatenate(features, axis=0) 72 | return time_step._replace(observation=obs) 73 | 74 | def reset(self): 75 | time_step = self._env.reset() 76 | return self._transform_observation(time_step) 77 | 78 | def step(self, action): 79 | time_step = self._env.step(action) 80 | return self._transform_observation(time_step) 81 | 82 | def observation_spec(self): 83 | return self._obs_spec 84 | 85 | def action_spec(self): 86 | return self._env.action_spec() 87 | 88 | def __getattr__(self, name): 89 | return getattr(self._env, name) 90 | 91 | 92 | class ActionRepeatWrapper(dm_env.Environment): 93 | def __init__(self, env, num_repeats): 94 | self._env = env 95 | self._num_repeats = num_repeats 96 | 97 | def step(self, action): 98 | reward = 0.0 99 | discount = 1.0 100 | for i in range(self._num_repeats): 101 | time_step = self._env.step(action) 102 | reward += (time_step.reward or 0.0) * discount 103 | discount *= time_step.discount 104 | if time_step.last(): 105 | break 106 | 107 | return time_step._replace(reward=reward, discount=discount) 108 | 109 | def observation_spec(self): 110 | return self._env.observation_spec() 111 | 112 | def action_spec(self): 113 | return 
self._env.action_spec() 114 | 115 | def reset(self): 116 | return self._env.reset() 117 | 118 | def __getattr__(self, name): 119 | return getattr(self._env, name) 120 | 121 | 122 | class FrameStackWrapper(dm_env.Environment): 123 | def __init__(self, env, num_frames, pixels_key='pixels'): 124 | self._env = env 125 | self._num_frames = num_frames 126 | self._frames = deque([], maxlen=num_frames) 127 | self._pixels_key = pixels_key 128 | 129 | wrapped_obs_spec = env.observation_spec() 130 | assert pixels_key in wrapped_obs_spec 131 | 132 | pixels_shape = wrapped_obs_spec[pixels_key].shape 133 | # remove batch dim 134 | if len(pixels_shape) == 4: 135 | pixels_shape = pixels_shape[1:] 136 | self._obs_spec = specs.BoundedArray(shape=np.concatenate( 137 | [[pixels_shape[2] * num_frames], pixels_shape[:2]], axis=0), 138 | dtype=np.uint8, 139 | minimum=0, 140 | maximum=255, 141 | name='observation') 142 | 143 | def _transform_observation(self, time_step): 144 | assert len(self._frames) == self._num_frames 145 | obs = np.concatenate(list(self._frames), axis=0) 146 | return time_step._replace(observation=obs) 147 | 148 | def _extract_pixels(self, time_step): 149 | pixels = time_step.observation[self._pixels_key] 150 | # remove batch dim 151 | if len(pixels.shape) == 4: 152 | pixels = pixels[0] 153 | return pixels.transpose(2, 0, 1).copy() 154 | 155 | def reset(self): 156 | time_step = self._env.reset() 157 | pixels = self._extract_pixels(time_step) 158 | for _ in range(self._num_frames): 159 | self._frames.append(pixels) 160 | return self._transform_observation(time_step) 161 | 162 | def step(self, action): 163 | time_step = self._env.step(action) 164 | pixels = self._extract_pixels(time_step) 165 | self._frames.append(pixels) 166 | return self._transform_observation(time_step) 167 | 168 | def observation_spec(self): 169 | return self._obs_spec 170 | 171 | def action_spec(self): 172 | return self._env.action_spec() 173 | 174 | def __getattr__(self, name): 175 | return getattr(self._env, name) 176 | 177 | 178 | class ActionDTypeWrapper(dm_env.Environment): 179 | def __init__(self, env, dtype): 180 | self._env = env 181 | wrapped_action_spec = env.action_spec() 182 | self._action_spec = specs.BoundedArray(wrapped_action_spec.shape, 183 | dtype, 184 | wrapped_action_spec.minimum, 185 | wrapped_action_spec.maximum, 186 | 'action') 187 | 188 | def step(self, action): 189 | action = action.astype(self._env.action_spec().dtype) 190 | return self._env.step(action) 191 | 192 | def observation_spec(self): 193 | return self._env.observation_spec() 194 | 195 | def action_spec(self): 196 | return self._action_spec 197 | 198 | def reset(self): 199 | return self._env.reset() 200 | 201 | def __getattr__(self, name): 202 | return getattr(self._env, name) 203 | 204 | 205 | class ObservationDTypeWrapper(dm_env.Environment): 206 | def __init__(self, env, dtype): 207 | self._env = env 208 | self._dtype = dtype 209 | wrapped_obs_spec = env.observation_spec()['observations'] 210 | self._obs_spec = specs.Array(wrapped_obs_spec.shape, dtype, 211 | 'observation') 212 | 213 | def _transform_observation(self, time_step): 214 | obs = time_step.observation['observations'].astype(self._dtype) 215 | return time_step._replace(observation=obs) 216 | 217 | def reset(self): 218 | time_step = self._env.reset() 219 | return self._transform_observation(time_step) 220 | 221 | def step(self, action): 222 | time_step = self._env.step(action) 223 | return self._transform_observation(time_step) 224 | 225 | def observation_spec(self): 226 | 
return self._obs_spec 227 | 228 | def action_spec(self): 229 | return self._env.action_spec() 230 | 231 | def __getattr__(self, name): 232 | return getattr(self._env, name) 233 | 234 | 235 | class ExtendedTimeStepWrapper(dm_env.Environment): 236 | def __init__(self, env): 237 | self._env = env 238 | 239 | def reset(self): 240 | time_step = self._env.reset() 241 | return self._augment_time_step(time_step) 242 | 243 | def step(self, action): 244 | time_step = self._env.step(action) 245 | return self._augment_time_step(time_step, action) 246 | 247 | def _augment_time_step(self, time_step, action=None): 248 | if action is None: 249 | action_spec = self.action_spec() 250 | action = np.zeros(action_spec.shape, dtype=action_spec.dtype) 251 | return ExtendedTimeStep(observation=time_step.observation, 252 | step_type=time_step.step_type, 253 | action=action, 254 | reward=time_step.reward or 0.0, 255 | discount=time_step.discount or 1.0) 256 | 257 | def observation_spec(self): 258 | return self._env.observation_spec() 259 | 260 | def action_spec(self): 261 | return self._env.action_spec() 262 | 263 | def __getattr__(self, name): 264 | return getattr(self._env, name) 265 | 266 | 267 | def _make_jaco(obs_type, domain, task, frame_stack, action_repeat, seed): 268 | env = cdmc.make_jaco(task, obs_type, seed) 269 | env = ActionDTypeWrapper(env, np.float32) 270 | env = ActionRepeatWrapper(env, action_repeat) 271 | env = FlattenJacoObservationWrapper(env) 272 | return env 273 | 274 | 275 | def _make_dmc(obs_type, domain, task, frame_stack, action_repeat, seed): 276 | visualize_reward = False 277 | if (domain, task) in suite.ALL_TASKS: 278 | env = suite.load(domain, 279 | task, 280 | task_kwargs=dict(random=seed), 281 | environment_kwargs=dict(flat_observation=True), 282 | visualize_reward=visualize_reward) 283 | else: 284 | env = cdmc.make(domain, 285 | task, 286 | task_kwargs=dict(random=seed), 287 | environment_kwargs=dict(flat_observation=True), 288 | visualize_reward=visualize_reward) 289 | 290 | env = ActionDTypeWrapper(env, np.float32) 291 | env = ActionRepeatWrapper(env, action_repeat) 292 | if obs_type == 'pixels': 293 | # zoom in camera for quadruped 294 | camera_id = dict(quadruped=2).get(domain, 0) 295 | render_kwargs = dict(height=84, width=84, camera_id=camera_id) 296 | env = pixels.Wrapper(env, 297 | pixels_only=True, 298 | render_kwargs=render_kwargs) 299 | return env 300 | 301 | 302 | def make(name, obs_type, frame_stack, action_repeat, seed): 303 | assert obs_type in ['states', 'pixels'] 304 | domain, task = name.split('_', 1) 305 | domain = dict(cup='ball_in_cup').get(domain, domain) 306 | 307 | make_fn = _make_jaco if domain == 'jaco' else _make_dmc 308 | env = make_fn(obs_type, domain, task, frame_stack, action_repeat, seed) 309 | 310 | if obs_type == 'pixels': 311 | env = FrameStackWrapper(env, frame_stack) 312 | else: 313 | env = ObservationDTypeWrapper(env, np.float32) 314 | 315 | env = action_scale.Wrapper(env, minimum=-1.0, maximum=+1.0) 316 | env = ExtendedTimeStepWrapper(env) 317 | return env 318 | -------------------------------------------------------------------------------- /custom_dmc_tasks/quadruped.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /custom_dmc_tasks/quadruped.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The 
dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Quadruped Domain.""" 17 | 18 | import collections 19 | 20 | from dm_control import mujoco 21 | from dm_control.mujoco.wrapper import mjbindings 22 | from dm_control.rl import control 23 | from dm_control.suite import base 24 | from dm_control.suite import common 25 | from dm_control.utils import containers 26 | from dm_control.utils import rewards 27 | from dm_control.utils import xml_tools 28 | from lxml import etree 29 | import numpy as np 30 | from scipy import ndimage 31 | import os 32 | 33 | enums = mjbindings.enums 34 | mjlib = mjbindings.mjlib 35 | 36 | 37 | _DEFAULT_TIME_LIMIT = 20 38 | _CONTROL_TIMESTEP = .02 39 | 40 | # Horizontal speeds above which the move reward is 1. 41 | _RUN_SPEED = 5 42 | _WALK_SPEED = 0.5 43 | 44 | _JUMP_HEIGHT = 1.0 45 | 46 | # Constants related to terrain generation. 47 | _HEIGHTFIELD_ID = 0 48 | _TERRAIN_SMOOTHNESS = 0.15 # 0.0: maximally bumpy; 1.0: completely smooth. 49 | _TERRAIN_BUMP_SCALE = 2 # Spatial scale of terrain bumps (in meters). 50 | 51 | # Named model elements. 52 | _TOES = ['toe_front_left', 'toe_back_left', 'toe_back_right', 'toe_front_right'] 53 | _WALLS = ['wall_px', 'wall_py', 'wall_nx', 'wall_ny'] 54 | 55 | SUITE = containers.TaggedTasks() 56 | 57 | def make(task, 58 | task_kwargs=None, 59 | environment_kwargs=None, 60 | visualize_reward=False): 61 | task_kwargs = task_kwargs or {} 62 | if environment_kwargs is not None: 63 | task_kwargs = task_kwargs.copy() 64 | task_kwargs['environment_kwargs'] = environment_kwargs 65 | env = SUITE[task](**task_kwargs) 66 | env.task.visualize_reward = visualize_reward 67 | return env 68 | 69 | def get_model_and_assets(): 70 | """Returns a tuple containing the model XML string and a dict of assets.""" 71 | root_dir = os.path.dirname(os.path.dirname(__file__)) 72 | xml = resources.GetResource( 73 | os.path.join(root_dir, 'custom_dmc_tasks', 'quadruped.xml')) 74 | return xml, common.ASSETS 75 | 76 | 77 | def make_model(floor_size=None, terrain=False, rangefinders=False, 78 | walls_and_ball=False): 79 | """Returns the model XML string.""" 80 | root_dir = os.path.dirname(os.path.dirname(__file__)) 81 | xml_string = common.read_model(os.path.join(root_dir, 'custom_dmc_tasks', 'quadruped.xml')) 82 | parser = etree.XMLParser(remove_blank_text=True) 83 | mjcf = etree.XML(xml_string, parser) 84 | 85 | # Set floor size. 86 | if floor_size is not None: 87 | floor_geom = mjcf.find('.//geom[@name=\'floor\']') 88 | floor_geom.attrib['size'] = f'{floor_size} {floor_size} .5' 89 | 90 | # Remove walls, ball and target. 91 | if not walls_and_ball: 92 | for wall in _WALLS: 93 | wall_geom = xml_tools.find_element(mjcf, 'geom', wall) 94 | wall_geom.getparent().remove(wall_geom) 95 | 96 | # Remove ball. 
97 | ball_body = xml_tools.find_element(mjcf, 'body', 'ball') 98 | ball_body.getparent().remove(ball_body) 99 | 100 | # Remove target. 101 | target_site = xml_tools.find_element(mjcf, 'site', 'target') 102 | target_site.getparent().remove(target_site) 103 | 104 | # Remove terrain. 105 | if not terrain: 106 | terrain_geom = xml_tools.find_element(mjcf, 'geom', 'terrain') 107 | terrain_geom.getparent().remove(terrain_geom) 108 | 109 | # Remove rangefinders if they're not used, as range computations can be 110 | # expensive, especially in a scene with heightfields. 111 | if not rangefinders: 112 | rangefinder_sensors = mjcf.findall('.//rangefinder') 113 | for rf in rangefinder_sensors: 114 | rf.getparent().remove(rf) 115 | 116 | return etree.tostring(mjcf, pretty_print=True) 117 | 118 | 119 | @SUITE.add() 120 | def stand(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): 121 | """Returns the Walk task.""" 122 | xml_string = make_model(floor_size=_DEFAULT_TIME_LIMIT * _WALK_SPEED) 123 | physics = Physics.from_xml_string(xml_string, common.ASSETS) 124 | task = Stand(random=random) 125 | environment_kwargs = environment_kwargs or {} 126 | return control.Environment(physics, task, time_limit=time_limit, 127 | control_timestep=_CONTROL_TIMESTEP, 128 | **environment_kwargs) 129 | 130 | @SUITE.add() 131 | def jump(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): 132 | """Returns the Walk task.""" 133 | xml_string = make_model(floor_size=_DEFAULT_TIME_LIMIT * _WALK_SPEED) 134 | physics = Physics.from_xml_string(xml_string, common.ASSETS) 135 | task = Jump(desired_height=_JUMP_HEIGHT, random=random) 136 | environment_kwargs = environment_kwargs or {} 137 | return control.Environment(physics, task, time_limit=time_limit, 138 | control_timestep=_CONTROL_TIMESTEP, 139 | **environment_kwargs) 140 | 141 | @SUITE.add() 142 | def roll(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): 143 | """Returns the Walk task.""" 144 | xml_string = make_model(floor_size=_DEFAULT_TIME_LIMIT * _WALK_SPEED) 145 | physics = Physics.from_xml_string(xml_string, common.ASSETS) 146 | task = Roll(desired_speed=_WALK_SPEED, random=random) 147 | environment_kwargs = environment_kwargs or {} 148 | return control.Environment(physics, task, time_limit=time_limit, 149 | control_timestep=_CONTROL_TIMESTEP, 150 | **environment_kwargs) 151 | 152 | @SUITE.add() 153 | def roll_fast(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): 154 | """Returns the Walk task.""" 155 | xml_string = make_model(floor_size=_DEFAULT_TIME_LIMIT * _WALK_SPEED) 156 | physics = Physics.from_xml_string(xml_string, common.ASSETS) 157 | task = Roll(desired_speed=_RUN_SPEED, random=random) 158 | environment_kwargs = environment_kwargs or {} 159 | return control.Environment(physics, task, time_limit=time_limit, 160 | control_timestep=_CONTROL_TIMESTEP, 161 | **environment_kwargs) 162 | 163 | @SUITE.add() 164 | def escape(time_limit=_DEFAULT_TIME_LIMIT, random=None, 165 | environment_kwargs=None): 166 | """Returns the Escape task.""" 167 | xml_string = make_model(floor_size=40, terrain=True, rangefinders=True) 168 | physics = Physics.from_xml_string(xml_string, common.ASSETS) 169 | task = Escape(random=random) 170 | environment_kwargs = environment_kwargs or {} 171 | return control.Environment(physics, task, time_limit=time_limit, 172 | control_timestep=_CONTROL_TIMESTEP, 173 | **environment_kwargs) 174 | 175 | 176 | @SUITE.add() 177 | def fetch(time_limit=_DEFAULT_TIME_LIMIT, 
random=None, environment_kwargs=None): 178 | """Returns the Fetch task.""" 179 | xml_string = make_model(walls_and_ball=True) 180 | physics = Physics.from_xml_string(xml_string, common.ASSETS) 181 | task = Fetch(random=random) 182 | environment_kwargs = environment_kwargs or {} 183 | return control.Environment(physics, task, time_limit=time_limit, 184 | control_timestep=_CONTROL_TIMESTEP, 185 | **environment_kwargs) 186 | 187 | 188 | class Physics(mujoco.Physics): 189 | """Physics simulation with additional features for the Quadruped domain.""" 190 | 191 | def _reload_from_data(self, data): 192 | super()._reload_from_data(data) 193 | # Clear cached sensor names when the physics is reloaded. 194 | self._sensor_types_to_names = {} 195 | self._hinge_names = [] 196 | 197 | def _get_sensor_names(self, *sensor_types): 198 | try: 199 | sensor_names = self._sensor_types_to_names[sensor_types] 200 | except KeyError: 201 | [sensor_ids] = np.where(np.in1d(self.model.sensor_type, sensor_types)) 202 | sensor_names = [self.model.id2name(s_id, 'sensor') for s_id in sensor_ids] 203 | self._sensor_types_to_names[sensor_types] = sensor_names 204 | return sensor_names 205 | 206 | def torso_upright(self): 207 | """Returns the dot-product of the torso z-axis and the global z-axis.""" 208 | return np.asarray(self.named.data.xmat['torso', 'zz']) 209 | 210 | def torso_velocity(self): 211 | """Returns the velocity of the torso, in the local frame.""" 212 | return self.named.data.sensordata['velocimeter'].copy() 213 | 214 | def com_height(self): 215 | return self.named.data.sensordata['center_of_mass'].copy()[2] 216 | 217 | def egocentric_state(self): 218 | """Returns the state without global orientation or position.""" 219 | if not self._hinge_names: 220 | [hinge_ids] = np.nonzero(self.model.jnt_type == 221 | enums.mjtJoint.mjJNT_HINGE) 222 | self._hinge_names = [self.model.id2name(j_id, 'joint') 223 | for j_id in hinge_ids] 224 | return np.hstack((self.named.data.qpos[self._hinge_names], 225 | self.named.data.qvel[self._hinge_names], 226 | self.data.act)) 227 | 228 | def toe_positions(self): 229 | """Returns toe positions in egocentric frame.""" 230 | torso_frame = self.named.data.xmat['torso'].reshape(3, 3) 231 | torso_pos = self.named.data.xpos['torso'] 232 | torso_to_toe = self.named.data.xpos[_TOES] - torso_pos 233 | return torso_to_toe.dot(torso_frame) 234 | 235 | def force_torque(self): 236 | """Returns scaled force/torque sensor readings at the toes.""" 237 | force_torque_sensors = self._get_sensor_names(enums.mjtSensor.mjSENS_FORCE, 238 | enums.mjtSensor.mjSENS_TORQUE) 239 | return np.arcsinh(self.named.data.sensordata[force_torque_sensors]) 240 | 241 | def imu(self): 242 | """Returns IMU-like sensor readings.""" 243 | imu_sensors = self._get_sensor_names(enums.mjtSensor.mjSENS_GYRO, 244 | enums.mjtSensor.mjSENS_ACCELEROMETER) 245 | return self.named.data.sensordata[imu_sensors] 246 | 247 | def rangefinder(self): 248 | """Returns scaled rangefinder sensor readings.""" 249 | rf_sensors = self._get_sensor_names(enums.mjtSensor.mjSENS_RANGEFINDER) 250 | rf_readings = self.named.data.sensordata[rf_sensors] 251 | no_intersection = -1.0 252 | return np.where(rf_readings == no_intersection, 1.0, np.tanh(rf_readings)) 253 | 254 | def origin_distance(self): 255 | """Returns the distance from the origin to the workspace.""" 256 | return np.asarray(np.linalg.norm(self.named.data.site_xpos['workspace'])) 257 | 258 | def origin(self): 259 | """Returns origin position in the torso frame.""" 260 | torso_frame = 
self.named.data.xmat['torso'].reshape(3, 3) 261 | torso_pos = self.named.data.xpos['torso'] 262 | return -torso_pos.dot(torso_frame) 263 | 264 | def ball_state(self): 265 | """Returns ball position and velocity relative to the torso frame.""" 266 | data = self.named.data 267 | torso_frame = data.xmat['torso'].reshape(3, 3) 268 | ball_rel_pos = data.xpos['ball'] - data.xpos['torso'] 269 | ball_rel_vel = data.qvel['ball_root'][:3] - data.qvel['root'][:3] 270 | ball_rot_vel = data.qvel['ball_root'][3:] 271 | ball_state = np.vstack((ball_rel_pos, ball_rel_vel, ball_rot_vel)) 272 | return ball_state.dot(torso_frame).ravel() 273 | 274 | def target_position(self): 275 | """Returns target position in torso frame.""" 276 | torso_frame = self.named.data.xmat['torso'].reshape(3, 3) 277 | torso_pos = self.named.data.xpos['torso'] 278 | torso_to_target = self.named.data.site_xpos['target'] - torso_pos 279 | return torso_to_target.dot(torso_frame) 280 | 281 | def ball_to_target_distance(self): 282 | """Returns horizontal distance from the ball to the target.""" 283 | ball_to_target = (self.named.data.site_xpos['target'] - 284 | self.named.data.xpos['ball']) 285 | return np.linalg.norm(ball_to_target[:2]) 286 | 287 | def self_to_ball_distance(self): 288 | """Returns horizontal distance from the quadruped workspace to the ball.""" 289 | self_to_ball = (self.named.data.site_xpos['workspace'] 290 | -self.named.data.xpos['ball']) 291 | return np.linalg.norm(self_to_ball[:2]) 292 | 293 | 294 | def _find_non_contacting_height(physics, orientation, x_pos=0.0, y_pos=0.0): 295 | """Find a height with no contacts given a body orientation. 296 | Args: 297 | physics: An instance of `Physics`. 298 | orientation: A quaternion. 299 | x_pos: A float. Position along global x-axis. 300 | y_pos: A float. Position along global y-axis. 301 | Raises: 302 | RuntimeError: If a non-contacting configuration has not been found after 303 | 10,000 attempts. 304 | """ 305 | z_pos = 0.0 # Start embedded in the floor. 306 | num_contacts = 1 307 | num_attempts = 0 308 | # Move up in 1cm increments until no contacts. 309 | while num_contacts > 0: 310 | try: 311 | with physics.reset_context(): 312 | physics.named.data.qpos['root'][:3] = x_pos, y_pos, z_pos 313 | physics.named.data.qpos['root'][3:] = orientation 314 | except control.PhysicsError: 315 | # We may encounter a PhysicsError here due to filling the contact 316 | # buffer, in which case we simply increment the height and continue. 317 | pass 318 | num_contacts = physics.data.ncon 319 | z_pos += 0.01 320 | num_attempts += 1 321 | if num_attempts > 10000: 322 | raise RuntimeError('Failed to find a non-contacting configuration.') 323 | 324 | 325 | def _common_observations(physics): 326 | """Returns the observations common to all tasks.""" 327 | obs = collections.OrderedDict() 328 | obs['egocentric_state'] = physics.egocentric_state() 329 | obs['torso_velocity'] = physics.torso_velocity() 330 | obs['torso_upright'] = physics.torso_upright() 331 | obs['imu'] = physics.imu() 332 | obs['force_torque'] = physics.force_torque() 333 | return obs 334 | 335 | 336 | def _upright_reward(physics, deviation_angle=0): 337 | """Returns a reward proportional to how upright the torso is. 338 | Args: 339 | physics: an instance of `Physics`. 340 | deviation_angle: A float, in degrees. The reward is 0 when the torso is 341 | exactly upside-down and 1 when the torso's z-axis is less than 342 | `deviation_angle` away from the global z-axis. 
343 | """ 344 | deviation = np.cos(np.deg2rad(deviation_angle)) 345 | return rewards.tolerance( 346 | physics.torso_upright(), 347 | bounds=(deviation, float('inf')), 348 | sigmoid='linear', 349 | margin=1 + deviation, 350 | value_at_margin=0) 351 | 352 | 353 | class Move(base.Task): 354 | """A quadruped task solved by moving forward at a designated speed.""" 355 | 356 | def __init__(self, desired_speed, random=None): 357 | """Initializes an instance of `Move`. 358 | Args: 359 | desired_speed: A float. If this value is zero, reward is given simply 360 | for standing upright. Otherwise this specifies the horizontal velocity 361 | at which the velocity-dependent reward component is maximized. 362 | random: Optional, either a `numpy.random.RandomState` instance, an 363 | integer seed for creating a new `RandomState`, or None to select a seed 364 | automatically (default). 365 | """ 366 | self._desired_speed = desired_speed 367 | super().__init__(random=random) 368 | 369 | def initialize_episode(self, physics): 370 | """Sets the state of the environment at the start of each episode. 371 | Args: 372 | physics: An instance of `Physics`. 373 | """ 374 | # Initial configuration. 375 | orientation = self.random.randn(4) 376 | orientation /= np.linalg.norm(orientation) 377 | _find_non_contacting_height(physics, orientation) 378 | super().initialize_episode(physics) 379 | 380 | def get_observation(self, physics): 381 | """Returns an observation to the agent.""" 382 | return _common_observations(physics) 383 | 384 | def get_reward(self, physics): 385 | """Returns a reward to the agent.""" 386 | 387 | # Move reward term. 388 | move_reward = rewards.tolerance( 389 | physics.torso_velocity()[0], 390 | bounds=(self._desired_speed, float('inf')), 391 | margin=self._desired_speed, 392 | value_at_margin=0.5, 393 | sigmoid='linear') 394 | 395 | return _upright_reward(physics) * move_reward 396 | 397 | 398 | class Stand(base.Task): 399 | """A quadruped task solved by standing upright.""" 400 | 401 | def __init__(self, random=None): 402 | """Initializes an instance of `Stand`. 403 | Unlike `Move`, this task takes no `desired_speed` argument: reward is 404 | given simply for keeping the torso upright, with no velocity-dependent 405 | component. 406 | Args: 407 | random: Optional, either a `numpy.random.RandomState` instance, an 408 | integer seed for creating a new `RandomState`, or None to select a seed 409 | automatically (default). 410 | """ 411 | super().__init__(random=random) 412 | 413 | def initialize_episode(self, physics): 414 | """Sets the state of the environment at the start of each episode. 415 | Args: 416 | physics: An instance of `Physics`. 417 | """ 418 | # Initial configuration. 419 | orientation = self.random.randn(4) 420 | orientation /= np.linalg.norm(orientation) 421 | _find_non_contacting_height(physics, orientation) 422 | super().initialize_episode(physics) 423 | 424 | def get_observation(self, physics): 425 | """Returns an observation to the agent.""" 426 | return _common_observations(physics) 427 | 428 | def get_reward(self, physics): 429 | """Returns a reward to the agent.""" 430 | 431 | return _upright_reward(physics) 432 | 433 | class Jump(base.Task): 434 | """A quadruped task solved by jumping above a desired height.""" 435 | 436 | def __init__(self, desired_height, random=None): 437 | """Initializes an instance of `Jump`. 438 | Args: 439 | desired_height: A float.
This specifies the center-of-mass 440 | height above which the height-dependent component of the jump reward 441 | is maximized. 442 | random: Optional, either a `numpy.random.RandomState` instance, an 443 | integer seed for creating a new `RandomState`, or None to select a seed 444 | automatically (default). 445 | """ 446 | self._desired_height = desired_height 447 | super().__init__(random=random) 448 | 449 | def initialize_episode(self, physics): 450 | """Sets the state of the environment at the start of each episode. 451 | Args: 452 | physics: An instance of `Physics`. 453 | """ 454 | # Initial configuration. 455 | orientation = self.random.randn(4) 456 | orientation /= np.linalg.norm(orientation) 457 | _find_non_contacting_height(physics, orientation) 458 | super().initialize_episode(physics) 459 | 460 | def get_observation(self, physics): 461 | """Returns an observation to the agent.""" 462 | return _common_observations(physics) 463 | 464 | def get_reward(self, physics): 465 | """Returns a reward to the agent.""" 466 | 467 | # Jump reward term. 468 | jump_up = rewards.tolerance( 469 | physics.com_height(), 470 | bounds=(self._desired_height, float('inf')), 471 | margin=self._desired_height, 472 | value_at_margin=0.5, 473 | sigmoid='linear') 474 | 475 | return _upright_reward(physics) * jump_up 476 | 477 | 478 | class Roll(base.Task): 479 | """A quadruped task solved by moving at a designated speed in any direction.""" 480 | 481 | def __init__(self, desired_speed, random=None): 482 | """Initializes an instance of `Roll`. 483 | Args: 484 | desired_speed: A float. If this value is zero, reward is given simply 485 | for standing upright. Otherwise this specifies the speed (the norm of 486 | the torso velocity) at which the velocity-dependent reward component is maximized. 487 | random: Optional, either a `numpy.random.RandomState` instance, an 488 | integer seed for creating a new `RandomState`, or None to select a seed 489 | automatically (default). 490 | """ 491 | self._desired_speed = desired_speed 492 | super().__init__(random=random) 493 | 494 | def initialize_episode(self, physics): 495 | """Sets the state of the environment at the start of each episode. 496 | Args: 497 | physics: An instance of `Physics`. 498 | """ 499 | # Initial configuration. 500 | orientation = self.random.randn(4) 501 | orientation /= np.linalg.norm(orientation) 502 | _find_non_contacting_height(physics, orientation) 503 | super().initialize_episode(physics) 504 | 505 | def get_observation(self, physics): 506 | """Returns an observation to the agent.""" 507 | return _common_observations(physics) 508 | 509 | def get_reward(self, physics): 510 | """Returns a reward to the agent.""" 511 | # Move reward term. 512 | move_reward = rewards.tolerance( 513 | np.linalg.norm(physics.torso_velocity()), 514 | bounds=(self._desired_speed, float('inf')), 515 | margin=self._desired_speed, 516 | value_at_margin=0.5, 517 | sigmoid='linear') 518 | 519 | return _upright_reward(physics) * move_reward 520 | 521 | 522 | class Escape(base.Task): 523 | """A quadruped task solved by escaping a bowl-shaped terrain.""" 524 | 525 | def initialize_episode(self, physics): 526 | """Sets the state of the environment at the start of each episode. 527 | Args: 528 | physics: An instance of `Physics`. 529 | """ 530 | # Get heightfield resolution, assert that it is square.
531 | res = physics.model.hfield_nrow[_HEIGHTFIELD_ID] 532 | assert res == physics.model.hfield_ncol[_HEIGHTFIELD_ID] 533 | # Sinusoidal bowl shape. 534 | row_grid, col_grid = np.ogrid[-1:1:res*1j, -1:1:res*1j] 535 | radius = np.clip(np.sqrt(col_grid**2 + row_grid**2), .04, 1) 536 | bowl_shape = .5 - np.cos(2*np.pi*radius)/2 537 | # Random smooth bumps. 538 | terrain_size = 2 * physics.model.hfield_size[_HEIGHTFIELD_ID, 0] 539 | bump_res = int(terrain_size / _TERRAIN_BUMP_SCALE) 540 | bumps = self.random.uniform(_TERRAIN_SMOOTHNESS, 1, (bump_res, bump_res)) 541 | smooth_bumps = ndimage.zoom(bumps, res / float(bump_res)) 542 | # Terrain is elementwise product. 543 | terrain = bowl_shape * smooth_bumps 544 | start_idx = physics.model.hfield_adr[_HEIGHTFIELD_ID] 545 | physics.model.hfield_data[start_idx:start_idx+res**2] = terrain.ravel() 546 | super().initialize_episode(physics) 547 | 548 | # If we have a rendering context, we need to re-upload the modified 549 | # heightfield data. 550 | if physics.contexts: 551 | with physics.contexts.gl.make_current() as ctx: 552 | ctx.call(mjlib.mjr_uploadHField, 553 | physics.model.ptr, 554 | physics.contexts.mujoco.ptr, 555 | _HEIGHTFIELD_ID) 556 | 557 | # Initial configuration. 558 | orientation = self.random.randn(4) 559 | orientation /= np.linalg.norm(orientation) 560 | _find_non_contacting_height(physics, orientation) 561 | 562 | def get_observation(self, physics): 563 | """Returns an observation to the agent.""" 564 | obs = _common_observations(physics) 565 | obs['origin'] = physics.origin() 566 | obs['rangefinder'] = physics.rangefinder() 567 | return obs 568 | 569 | def get_reward(self, physics): 570 | """Returns a reward to the agent.""" 571 | 572 | # Escape reward term. 573 | terrain_size = physics.model.hfield_size[_HEIGHTFIELD_ID, 0] 574 | escape_reward = rewards.tolerance( 575 | physics.origin_distance(), 576 | bounds=(terrain_size, float('inf')), 577 | margin=terrain_size, 578 | value_at_margin=0, 579 | sigmoid='linear') 580 | 581 | return _upright_reward(physics, deviation_angle=20) * escape_reward 582 | 583 | 584 | class Fetch(base.Task): 585 | """A quadruped task solved by bringing a ball to the origin.""" 586 | 587 | def initialize_episode(self, physics): 588 | """Sets the state of the environment at the start of each episode. 589 | Args: 590 | physics: An instance of `Physics`. 591 | """ 592 | # Initial configuration, random azimuth and horizontal position. 593 | azimuth = self.random.uniform(0, 2*np.pi) 594 | orientation = np.array((np.cos(azimuth/2), 0, 0, np.sin(azimuth/2))) 595 | spawn_radius = 0.9 * physics.named.model.geom_size['floor', 0] 596 | x_pos, y_pos = self.random.uniform(-spawn_radius, spawn_radius, size=(2,)) 597 | _find_non_contacting_height(physics, orientation, x_pos, y_pos) 598 | 599 | # Initial ball state. 600 | physics.named.data.qpos['ball_root'][:2] = self.random.uniform( 601 | -spawn_radius, spawn_radius, size=(2,)) 602 | physics.named.data.qpos['ball_root'][2] = 2 603 | physics.named.data.qvel['ball_root'][:2] = 5*self.random.randn(2) 604 | super().initialize_episode(physics) 605 | 606 | def get_observation(self, physics): 607 | """Returns an observation to the agent.""" 608 | obs = _common_observations(physics) 609 | obs['ball_state'] = physics.ball_state() 610 | obs['target_position'] = physics.target_position() 611 | return obs 612 | 613 | def get_reward(self, physics): 614 | """Returns a reward to the agent.""" 615 | 616 | # Reward for moving close to the ball. 
617 | arena_radius = physics.named.model.geom_size['floor', 0] * np.sqrt(2) 618 | workspace_radius = physics.named.model.site_size['workspace', 0] 619 | ball_radius = physics.named.model.geom_size['ball', 0] 620 | reach_reward = rewards.tolerance( 621 | physics.self_to_ball_distance(), 622 | bounds=(0, workspace_radius+ball_radius), 623 | sigmoid='linear', 624 | margin=arena_radius, value_at_margin=0) 625 | 626 | # Reward for bringing the ball to the target. 627 | target_radius = physics.named.model.site_size['target', 0] 628 | fetch_reward = rewards.tolerance( 629 | physics.ball_to_target_distance(), 630 | bounds=(0, target_radius), 631 | sigmoid='linear', 632 | margin=arena_radius, value_at_margin=0) 633 | 634 | reach_then_fetch = reach_reward * (0.5 + 0.5*fetch_reward) 635 | 636 | return _upright_reward(physics) * reach_then_fetch --------------------------------------------------------------------------------
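
Usage note (editor's sketch, not part of the repository): the snippet below shows one way to instantiate and roll out the quadruped tasks registered in SUITE above, using this module's own make() helper with any registered name ('stand', 'jump', 'roll', 'roll_fast', 'escape', 'fetch'). It assumes the repository root is on PYTHONPATH so that custom_dmc_tasks.quadruped is importable and that MuJoCo/dm_control are installed as in conda_env.yml; the variable names are illustrative only.

# Editor's sketch: random-action rollout of one custom quadruped task.
import numpy as np
from custom_dmc_tasks import quadruped

env = quadruped.make('jump')   # any key registered in quadruped.SUITE
spec = env.action_spec()       # bounded action spec from the MuJoCo model

time_step = env.reset()
episode_return = 0.0
while not time_step.last():
    # Sample uniform random actions within the action bounds.
    action = np.random.uniform(spec.minimum, spec.maximum, size=spec.shape)
    time_step = env.step(action)
    episode_return += time_step.reward
print('episode return:', episode_return)

In the repository these factories are presumably consumed through dmc.py and the pretrain/finetune entry points rather than called directly; the direct call above is only meant to illustrate the control.Environment interface (reset(), step(), action_spec()) that every task in this file returns.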