├── .gitignore ├── LICENSE ├── README.md ├── TransformLayer.py ├── agents ├── agent_sac.py ├── agent_sac_SPR.py ├── agent_sac_base.py ├── agent_sac_contrastive.py ├── agent_sac_noreward.py ├── agent_sac_norewtotransition.py ├── agent_sac_notransition.py ├── agent_sac_reconstruction.py ├── agent_sac_value.py └── auxiliary_funcs.py ├── curl_sac.py ├── data_augs.py ├── distractors └── driving │ └── 1.mp4 ├── dmc2gym ├── __init__.py ├── natural_imgsource.py └── wrappers.py ├── encoder.py ├── local_dm_control_suite ├── README.md ├── __init__.py ├── acrobot.py ├── acrobot.xml ├── ball_in_cup.py ├── ball_in_cup.xml ├── base.py ├── cartpole.py ├── cartpole.xml ├── cheetah.py ├── cheetah.xml ├── common │ ├── __init__.py │ ├── materials.xml │ ├── materials_white_floor.xml │ ├── skybox.xml │ └── visual.xml ├── demos │ ├── mocap_demo.py │ └── zeros.amc ├── explore.py ├── finger.py ├── finger.xml ├── fish.py ├── fish.xml ├── hopper.py ├── hopper.xml ├── humanoid.py ├── humanoid.xml ├── humanoid_CMU.py ├── humanoid_CMU.xml ├── lqr.py ├── lqr.xml ├── lqr_solver.py ├── manipulator.py ├── manipulator.xml ├── pendulum.py ├── pendulum.xml ├── point_mass.py ├── point_mass.xml ├── quadruped.py ├── quadruped.xml ├── reacher.py ├── reacher.xml ├── stacker.py ├── stacker.xml ├── swimmer.py ├── swimmer.xml ├── tests │ ├── domains_test.py │ ├── loader_test.py │ └── lqr_test.py ├── utils │ ├── __init__.py │ ├── parse_amc.py │ ├── parse_amc_test.py │ ├── randomizers.py │ └── randomizers_test.py ├── walker.py ├── walker.xml └── wrappers │ ├── __init__.py │ ├── action_noise.py │ ├── action_noise_test.py │ ├── pixels.py │ └── pixels_test.py ├── local_dm_control_suite_off_center ├── README.md ├── __init__.py ├── acrobot.py ├── acrobot.xml ├── ball_in_cup.py ├── ball_in_cup.xml ├── base.py ├── cartpole.py ├── cartpole.xml ├── cheetah.py ├── cheetah.xml ├── common │ ├── __init__.py │ ├── materials.xml │ ├── materials_white_floor.xml │ ├── skybox.xml │ └── visual.xml ├── demos │ ├── mocap_demo.py │ └── zeros.amc ├── explore.py ├── finger.py ├── finger.xml ├── fish.py ├── fish.xml ├── hopper.py ├── hopper.xml ├── humanoid.py ├── humanoid.xml ├── humanoid_CMU.py ├── humanoid_CMU.xml ├── lqr.py ├── lqr.xml ├── lqr_solver.py ├── manipulator.py ├── manipulator.xml ├── pendulum.py ├── pendulum.xml ├── point_mass.py ├── point_mass.xml ├── quadruped.py ├── quadruped.xml ├── reacher.py ├── reacher.xml ├── stacker.py ├── stacker.xml ├── swimmer.py ├── swimmer.xml ├── tests │ ├── domains_test.py │ ├── loader_test.py │ └── lqr_test.py ├── utils │ ├── __init__.py │ ├── parse_amc.py │ ├── parse_amc_test.py │ ├── randomizers.py │ └── randomizers_test.py ├── walker.py ├── walker.xml └── wrappers │ ├── __init__.py │ ├── action_noise.py │ ├── action_noise_test.py │ ├── pixels.py │ └── pixels_test.py ├── logger.py ├── multistep_dynamics.py ├── multistep_replay.py ├── multistep_utils.py ├── requirements.txt ├── scripts └── run.sh ├── train.py ├── utils.py └── video.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | **/tb/** 3 | **/__pycache__/ 4 | **/tmp/** -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Utkarsh Mishra 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without 
restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /agents/agent_sac.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.core.numeric import tensordot 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | import utils 8 | import data_augs as rad 9 | from agents.auxiliary_funcs import BaseSacAgent 10 | from multistep_dynamics import MultiStepDynamicsModel 11 | import multistep_utils as mutils 12 | 13 | class PixelSacAgent(BaseSacAgent): 14 | """Learning Representations of Pixel Observations with SAC + Self-Supervised Techniques..""" 15 | def __init__( 16 | self, 17 | obs_shape, 18 | action_shape, 19 | horizon, 20 | device, 21 | hidden_dim=256, 22 | discount=0.99, 23 | init_temperature=0.01, 24 | alpha_lr=1e-3, 25 | alpha_beta=0.9, 26 | actor_lr=1e-3, 27 | actor_beta=0.9, 28 | actor_log_std_min=-10, 29 | actor_log_std_max=2, 30 | actor_update_freq=2, 31 | critic_lr=1e-3, 32 | critic_beta=0.9, 33 | critic_tau=0.005, 34 | critic_target_update_freq=2, 35 | encoder_type='pixel', 36 | encoder_feature_dim=50, 37 | encoder_lr=1e-3, 38 | encoder_tau=0.005, 39 | decoder_lr=1e-3, 40 | decoder_weight_lambda=0.0, 41 | num_layers=4, 42 | num_filters=32, 43 | cpc_update_freq=1, 44 | log_interval=100, 45 | detach_encoder=False, 46 | latent_dim=128, 47 | data_augs = '', 48 | use_metric_loss=False 49 | ): 50 | 51 | print('########################################################################') 52 | print('################### Starting Case 0: Baseline Agent ####################') 53 | print('###################### Reward: No; Transition: No ######################') 54 | print('########################################################################') 55 | 56 | super().__init__( 57 | obs_shape, 58 | action_shape, 59 | horizon, 60 | device, 61 | hidden_dim, 62 | discount, 63 | init_temperature, 64 | alpha_lr, 65 | alpha_beta, 66 | actor_lr, 67 | actor_beta, 68 | actor_log_std_min, 69 | actor_log_std_max, 70 | actor_update_freq, 71 | critic_lr, 72 | critic_beta, 73 | critic_tau, 74 | critic_target_update_freq, 75 | encoder_type, 76 | encoder_feature_dim, 77 | encoder_lr, 78 | encoder_tau, 79 | decoder_lr, 80 | decoder_weight_lambda, 81 | num_layers, 82 | num_filters, 83 | cpc_update_freq, 84 | log_interval, 85 | detach_encoder, 86 | latent_dim, 87 | data_augs, 88 | use_metric_loss 89 | ) 90 | 91 | self.train() 92 | self.critic_target.train() 93 | 94 | def update(self, replay_buffer, L, step): 95 | if self.encoder_type == 'pixel': 96 | 97 | batch_obs, batch_action, batch_reward, 
batch_not_done = replay_buffer.sample_multistep() 98 | 99 | obs = batch_obs[0] 100 | action = batch_action[0] 101 | next_obs = batch_obs[1] 102 | reward = batch_reward[0].unsqueeze(-1) 103 | not_done = batch_not_done[0].unsqueeze(-1) 104 | 105 | self.update_critic(obs, action, reward, next_obs, not_done, L, step) 106 | 107 | encoded_batch_obs = [] 108 | 109 | for en_iter in range(batch_obs.size(0)): 110 | encoded_obs = self.critic.encoder(batch_obs[en_iter]) 111 | encoded_batch_obs.append(encoded_obs) 112 | 113 | encoded_batch_obs = torch.stack(encoded_batch_obs) 114 | 115 | else: 116 | obs, action, reward, next_obs, not_done = replay_buffer.sample_proprio() 117 | self.update_critic(obs, action, reward, next_obs, not_done, L, step) 118 | 119 | if step % self.log_interval == 0: 120 | L.log('train/batch_reward', reward.mean(), step) 121 | 122 | if step % self.actor_update_freq == 0: 123 | self.update_actor_and_alpha(obs, L, step) 124 | 125 | if step % self.critic_target_update_freq == 0: 126 | utils.soft_update_params( 127 | self.critic.Q1, self.critic_target.Q1, self.critic_tau 128 | ) 129 | utils.soft_update_params( 130 | self.critic.Q2, self.critic_target.Q2, self.critic_tau 131 | ) 132 | utils.soft_update_params( 133 | self.critic.encoder, self.critic_target.encoder, 134 | self.encoder_tau 135 | ) -------------------------------------------------------------------------------- /distractors/driving/1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UtkarshMishra04/pixel-representations-RL/8f457adcf41eb3b8975eaa5a752736c07b908c63/distractors/driving/1.mp4 -------------------------------------------------------------------------------- /dmc2gym/__init__.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym.envs.registration import register 3 | 4 | 5 | def make( 6 | domain_name, 7 | task_name, 8 | resource_files, 9 | img_source, 10 | total_frames, 11 | seed=1, 12 | visualize_reward=True, 13 | from_pixels=False, 14 | height=84, 15 | width=84, 16 | camera_id=0, 17 | frame_skip=1, 18 | episode_length=1000, 19 | off_center=False, 20 | environment_kwargs=None 21 | ): 22 | env_id = 'dmc_%s_%s_%s-v1' % (domain_name, task_name, seed) 23 | 24 | if from_pixels: 25 | assert not visualize_reward, 'cannot use visualize reward when learning from pixels' 26 | 27 | # shorten episode length 28 | max_episode_steps = (episode_length + frame_skip - 1) // frame_skip 29 | 30 | if not env_id in gym.envs.registry.env_specs: 31 | register( 32 | id=env_id, 33 | entry_point='dmc2gym.wrappers:DMCWrapper', 34 | kwargs={ 35 | 'domain_name': domain_name, 36 | 'task_name': task_name, 37 | 'resource_files': resource_files, 38 | 'img_source': img_source, 39 | 'total_frames': total_frames, 40 | 'task_kwargs': { 41 | 'random': seed 42 | }, 43 | 'environment_kwargs': environment_kwargs, 44 | 'visualize_reward': visualize_reward, 45 | 'from_pixels': from_pixels, 46 | 'height': height, 47 | 'width': width, 48 | 'camera_id': camera_id, 49 | 'frame_skip': frame_skip, 50 | 'off_center': off_center, 51 | }, 52 | max_episode_steps=max_episode_steps 53 | ) 54 | return gym.make(env_id) 55 | -------------------------------------------------------------------------------- /encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def tie_weights(src, trg): 6 | assert type(src) == type(trg) 7 | trg.weight = src.weight 8 | 
trg.bias = src.bias 9 | 10 | 11 | ### Works with Finism SAC 12 | # OUT_DIM = {2: 39, 4: 35, 6: 31} 13 | # OUT_DIM_84 = {2: 39, 4: 35, 6: 31} 14 | # OUT_DIM_100 = {2: 39, 4: 43, 6: 31} 15 | 16 | ### Works with RAD SAC 17 | OUT_DIM = {2: 39, 4: 35, 6: 31} 18 | OUT_DIM_84 = {2: 29, 4: 35, 6: 21} 19 | OUT_DIM_100 = {4: 47} 20 | 21 | 22 | class PixelEncoder(nn.Module): 23 | """Convolutional encoder of pixels observations.""" 24 | def __init__(self, obs_shape, feature_dim, num_layers=2, num_filters=32,output_logits=False): 25 | super().__init__() 26 | 27 | assert len(obs_shape) == 3 28 | self.obs_shape = obs_shape 29 | self.feature_dim = feature_dim 30 | self.num_layers = num_layers 31 | # try 2 5x5s with strides 2x2. with samep adding, it should reduce 84 to 21, so with valid, it should be even smaller than 21. 32 | self.convs = nn.ModuleList( 33 | [nn.Conv2d(obs_shape[0], num_filters, 3, stride=2)] 34 | ) 35 | for i in range(num_layers - 1): 36 | self.convs.append(nn.Conv2d(num_filters, num_filters, 3, stride=1)) 37 | 38 | if obs_shape[-1] == 100: 39 | assert num_layers in OUT_DIM_100 40 | out_dim = OUT_DIM_100[num_layers] 41 | elif obs_shape[-1] == 84: 42 | out_dim = OUT_DIM_84[num_layers] 43 | else: 44 | out_dim = OUT_DIM[num_layers] 45 | 46 | self.fc = nn.Linear(num_filters * out_dim * out_dim, self.feature_dim) 47 | self.ln = nn.LayerNorm(self.feature_dim) 48 | 49 | self.outputs = dict() 50 | self.output_logits = output_logits 51 | 52 | def reparameterize(self, mu, logstd): 53 | std = torch.exp(logstd) 54 | eps = torch.randn_like(std) 55 | return mu + eps * std 56 | 57 | def forward_conv(self, obs): 58 | if obs.max() > 1.: 59 | obs = obs / 255. 60 | 61 | self.outputs['obs'] = obs 62 | 63 | conv = torch.relu(self.convs[0](obs)) 64 | self.outputs['conv1'] = conv 65 | 66 | for i in range(1, self.num_layers): 67 | conv = torch.relu(self.convs[i](conv)) 68 | self.outputs['conv%s' % (i + 1)] = conv 69 | 70 | h = conv.view(conv.size(0), -1) 71 | 72 | return h 73 | 74 | def forward(self, obs, detach=False): 75 | h = self.forward_conv(obs) 76 | 77 | if detach: 78 | h = h.detach() 79 | 80 | h_fc = self.fc(h) 81 | self.outputs['fc'] = h_fc 82 | 83 | h_norm = self.ln(h_fc) 84 | self.outputs['ln'] = h_norm 85 | 86 | if self.output_logits: 87 | out = h_norm 88 | else: 89 | out = torch.tanh(h_norm) 90 | self.outputs['tanh'] = out 91 | 92 | return out 93 | 94 | def copy_conv_weights_from(self, source): 95 | """Tie convolutional layers""" 96 | # only tie conv layers 97 | for i in range(self.num_layers): 98 | tie_weights(src=source.convs[i], trg=self.convs[i]) 99 | 100 | def log(self, L, step, log_freq): 101 | if step % log_freq != 0: 102 | return 103 | 104 | for k, v in self.outputs.items(): 105 | L.log_histogram('train_encoder/%s_hist' % k, v, step) 106 | if len(v.shape) > 2: 107 | L.log_image('train_encoder/%s_img' % k, v[0], step) 108 | 109 | for i in range(self.num_layers): 110 | L.log_param('train_encoder/conv%s' % (i + 1), self.convs[i], step) 111 | L.log_param('train_encoder/fc', self.fc, step) 112 | L.log_param('train_encoder/ln', self.ln, step) 113 | 114 | 115 | class IdentityEncoder(nn.Module): 116 | def __init__(self, obs_shape, feature_dim, num_layers, num_filters,*args): 117 | super().__init__() 118 | 119 | assert len(obs_shape) == 1 120 | self.feature_dim = obs_shape[0] 121 | 122 | def forward(self, obs, detach=False): 123 | return obs 124 | 125 | def copy_conv_weights_from(self, source): 126 | pass 127 | 128 | def log(self, L, step, log_freq): 129 | pass 130 | 131 | 132 | _AVAILABLE_ENCODERS 
= {'pixel': PixelEncoder, 'identity': IdentityEncoder} 133 | 134 | 135 | def make_encoder( 136 | encoder_type, obs_shape, feature_dim, num_layers, num_filters, output_logits=False 137 | ): 138 | assert encoder_type in _AVAILABLE_ENCODERS 139 | return _AVAILABLE_ENCODERS[encoder_type]( 140 | obs_shape, feature_dim, num_layers, num_filters, output_logits 141 | ) 142 | -------------------------------------------------------------------------------- /local_dm_control_suite/README.md: -------------------------------------------------------------------------------- 1 | # DeepMind Control Suite. 2 | 3 | This submodule contains the domains and tasks described in the 4 | [DeepMind Control Suite tech report](https://arxiv.org/abs/1801.00690). 5 | 6 | ## Quickstart 7 | 8 | ```python 9 | from dm_control import suite 10 | import numpy as np 11 | 12 | # Load one task: 13 | env = suite.load(domain_name="cartpole", task_name="swingup") 14 | 15 | # Iterate over a task set: 16 | for domain_name, task_name in suite.BENCHMARKING: 17 | env = suite.load(domain_name, task_name) 18 | 19 | # Step through an episode and print out reward, discount and observation. 20 | action_spec = env.action_spec() 21 | time_step = env.reset() 22 | while not time_step.last(): 23 | action = np.random.uniform(action_spec.minimum, 24 | action_spec.maximum, 25 | size=action_spec.shape) 26 | time_step = env.step(action) 27 | print(time_step.reward, time_step.discount, time_step.observation) 28 | ``` 29 | 30 | ## Illustration video 31 | 32 | Below is a video montage of solved Control Suite tasks, with reward 33 | visualisation enabled. 34 | 35 | [![Video montage](https://img.youtube.com/vi/rAai4QzcYbs/0.jpg)](https://www.youtube.com/watch?v=rAai4QzcYbs) 36 | 37 | 38 | ### Quadruped domain [April 2019] 39 | 40 | Roughly based on the 'ant' model introduced by [Schulman et al. 2015](https://arxiv.org/abs/1506.02438). Main modifications to the body are: 41 | 42 | - 4 DoFs per leg, 1 constraining tendon. 43 | - 3 actuators per leg: 'yaw', 'lift', 'extend'. 44 | - Filtered position actuators with timescale of 100ms. 45 | - Sensors include an IMU, force/torque sensors, and rangefinders. 46 | 47 | Four tasks: 48 | 49 | - `walk` and `run`: self-right the body then move forward at a desired speed. 50 | - `escape`: escape a bowl-shaped random terrain (uses rangefinders). 51 | - `fetch`, go to a moving ball and bring it to a target. 52 | 53 | All behaviors in the video below were trained with [Abdolmaleki et al's 54 | MPO](https://arxiv.org/abs/1806.06920). 55 | 56 | [![Video montage](https://img.youtube.com/vi/RhRLjbb7pBE/0.jpg)](https://www.youtube.com/watch?v=RhRLjbb7pBE) 57 | -------------------------------------------------------------------------------- /local_dm_control_suite/acrobot.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | 16 | """Acrobot domain.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import collections 23 | 24 | from dm_control import mujoco 25 | from dm_control.rl import control 26 | from local_dm_control_suite import base 27 | from local_dm_control_suite import common 28 | from dm_control.utils import containers 29 | from dm_control.utils import rewards 30 | import numpy as np 31 | 32 | _DEFAULT_TIME_LIMIT = 10 33 | SUITE = containers.TaggedTasks() 34 | 35 | 36 | def get_model_and_assets(): 37 | """Returns a tuple containing the model XML string and a dict of assets.""" 38 | return common.read_model('acrobot.xml'), common.ASSETS 39 | 40 | 41 | @SUITE.add('benchmarking') 42 | def swingup(time_limit=_DEFAULT_TIME_LIMIT, random=None, 43 | environment_kwargs=None): 44 | """Returns Acrobot balance task.""" 45 | physics = Physics.from_xml_string(*get_model_and_assets()) 46 | task = Balance(sparse=False, random=random) 47 | environment_kwargs = environment_kwargs or {} 48 | return control.Environment( 49 | physics, task, time_limit=time_limit, **environment_kwargs) 50 | 51 | 52 | @SUITE.add('benchmarking') 53 | def swingup_sparse(time_limit=_DEFAULT_TIME_LIMIT, random=None, 54 | environment_kwargs=None): 55 | """Returns Acrobot sparse balance.""" 56 | physics = Physics.from_xml_string(*get_model_and_assets()) 57 | task = Balance(sparse=True, random=random) 58 | environment_kwargs = environment_kwargs or {} 59 | return control.Environment( 60 | physics, task, time_limit=time_limit, **environment_kwargs) 61 | 62 | 63 | class Physics(mujoco.Physics): 64 | """Physics simulation with additional features for the Acrobot domain.""" 65 | 66 | def horizontal(self): 67 | """Returns horizontal (x) component of body frame z-axes.""" 68 | return self.named.data.xmat[['upper_arm', 'lower_arm'], 'xz'] 69 | 70 | def vertical(self): 71 | """Returns vertical (z) component of body frame z-axes.""" 72 | return self.named.data.xmat[['upper_arm', 'lower_arm'], 'zz'] 73 | 74 | def to_target(self): 75 | """Returns the distance from the tip to the target.""" 76 | tip_to_target = (self.named.data.site_xpos['target'] - 77 | self.named.data.site_xpos['tip']) 78 | return np.linalg.norm(tip_to_target) 79 | 80 | def orientations(self): 81 | """Returns the sines and cosines of the pole angles.""" 82 | return np.concatenate((self.horizontal(), self.vertical())) 83 | 84 | 85 | class Balance(base.Task): 86 | """An Acrobot `Task` to swing up and balance the pole.""" 87 | 88 | def __init__(self, sparse, random=None): 89 | """Initializes an instance of `Balance`. 90 | 91 | Args: 92 | sparse: A `bool` specifying whether to use a sparse (indicator) reward. 93 | random: Optional, either a `numpy.random.RandomState` instance, an 94 | integer seed for creating a new `RandomState`, or None to select a seed 95 | automatically (default). 96 | """ 97 | self._sparse = sparse 98 | super(Balance, self).__init__(random=random) 99 | 100 | def initialize_episode(self, physics): 101 | """Sets the state of the environment at the start of each episode. 102 | 103 | Shoulder and elbow are set to a random position between [-pi, pi). 104 | 105 | Args: 106 | physics: An instance of `Physics`. 
107 | """ 108 | physics.named.data.qpos[ 109 | ['shoulder', 'elbow']] = self.random.uniform(-np.pi, np.pi, 2) 110 | super(Balance, self).initialize_episode(physics) 111 | 112 | def get_observation(self, physics): 113 | """Returns an observation of pole orientation and angular velocities.""" 114 | obs = collections.OrderedDict() 115 | obs['orientations'] = physics.orientations() 116 | obs['velocity'] = physics.velocity() 117 | return obs 118 | 119 | def _get_reward(self, physics, sparse): 120 | target_radius = physics.named.model.site_size['target', 0] 121 | return rewards.tolerance(physics.to_target(), 122 | bounds=(0, target_radius), 123 | margin=0 if sparse else 1) 124 | 125 | def get_reward(self, physics): 126 | """Returns a sparse or a smooth reward, as specified in the constructor.""" 127 | return self._get_reward(physics, sparse=self._sparse) 128 | -------------------------------------------------------------------------------- /local_dm_control_suite/acrobot.xml: -------------------------------------------------------------------------------- 1 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /local_dm_control_suite/ball_in_cup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | 16 | """Ball-in-Cup Domain.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import collections 23 | 24 | from dm_control import mujoco 25 | from dm_control.rl import control 26 | from local_dm_control_suite import base 27 | from local_dm_control_suite import common 28 | from dm_control.utils import containers 29 | 30 | _DEFAULT_TIME_LIMIT = 20 # (seconds) 31 | _CONTROL_TIMESTEP = .02 # (seconds) 32 | 33 | 34 | SUITE = containers.TaggedTasks() 35 | 36 | 37 | def get_model_and_assets(): 38 | """Returns a tuple containing the model XML string and a dict of assets.""" 39 | return common.read_model('ball_in_cup.xml'), common.ASSETS 40 | 41 | 42 | @SUITE.add('benchmarking', 'easy') 43 | def catch(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): 44 | """Returns the Ball-in-Cup task.""" 45 | physics = Physics.from_xml_string(*get_model_and_assets()) 46 | task = BallInCup(random=random) 47 | environment_kwargs = environment_kwargs or {} 48 | return control.Environment( 49 | physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP, 50 | **environment_kwargs) 51 | 52 | 53 | class Physics(mujoco.Physics): 54 | """Physics with additional features for the Ball-in-Cup domain.""" 55 | 56 | def ball_to_target(self): 57 | """Returns the vector from the ball to the target.""" 58 | target = self.named.data.site_xpos['target', ['x', 'z']] 59 | ball = self.named.data.xpos['ball', ['x', 'z']] 60 | return target - ball 61 | 62 | def in_target(self): 63 | """Returns 1 if the ball is in the target, 0 otherwise.""" 64 | ball_to_target = abs(self.ball_to_target()) 65 | target_size = self.named.model.site_size['target', [0, 2]] 66 | ball_size = self.named.model.geom_size['ball', 0] 67 | return float(all(ball_to_target < target_size - ball_size)) 68 | 69 | 70 | class BallInCup(base.Task): 71 | """The Ball-in-Cup task. Put the ball in the cup.""" 72 | 73 | def initialize_episode(self, physics): 74 | """Sets the state of the environment at the start of each episode. 75 | 76 | Args: 77 | physics: An instance of `Physics`. 78 | 79 | """ 80 | # Find a collision-free random initial position of the ball. 81 | penetrating = True 82 | while penetrating: 83 | # Assign a random ball position. 84 | physics.named.data.qpos['ball_x'] = self.random.uniform(-.2, .2) 85 | physics.named.data.qpos['ball_z'] = self.random.uniform(.2, .5) 86 | # Check for collisions. 
87 | physics.after_reset() 88 | penetrating = physics.data.ncon > 0 89 | super(BallInCup, self).initialize_episode(physics) 90 | 91 | def get_observation(self, physics): 92 | """Returns an observation of the state.""" 93 | obs = collections.OrderedDict() 94 | obs['position'] = physics.position() 95 | obs['velocity'] = physics.velocity() 96 | return obs 97 | 98 | def get_reward(self, physics): 99 | """Returns a sparse reward.""" 100 | return physics.in_target() 101 | -------------------------------------------------------------------------------- /local_dm_control_suite/ball_in_cup.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /local_dm_control_suite/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Base class for tasks in the Control Suite.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from dm_control import mujoco 23 | from dm_control.rl import control 24 | 25 | import numpy as np 26 | 27 | 28 | class Task(control.Task): 29 | """Base class for tasks in the Control Suite. 30 | 31 | Actions are mapped directly to the states of MuJoCo actuators: each element of 32 | the action array is used to set the control input for a single actuator. The 33 | ordering of the actuators is the same as in the corresponding MJCF XML file. 34 | 35 | Attributes: 36 | random: A `numpy.random.RandomState` instance. This should be used to 37 | generate all random variables associated with the task, such as random 38 | starting states, observation noise* etc. 39 | 40 | *If sensor noise is enabled in the MuJoCo model then this will be generated 41 | using MuJoCo's internal RNG, which has its own independent state. 42 | """ 43 | 44 | def __init__(self, random=None): 45 | """Initializes a new continuous control task. 46 | 47 | Args: 48 | random: Optional, either a `numpy.random.RandomState` instance, an integer 49 | seed for creating a new `RandomState`, or None to select a seed 50 | automatically (default). 
51 | """ 52 | if not isinstance(random, np.random.RandomState): 53 | random = np.random.RandomState(random) 54 | self._random = random 55 | self._visualize_reward = False 56 | 57 | @property 58 | def random(self): 59 | """Task-specific `numpy.random.RandomState` instance.""" 60 | return self._random 61 | 62 | def action_spec(self, physics): 63 | """Returns a `BoundedArraySpec` matching the `physics` actuators.""" 64 | return mujoco.action_spec(physics) 65 | 66 | def initialize_episode(self, physics): 67 | """Resets geom colors to their defaults after starting a new episode. 68 | 69 | Subclasses of `base.Task` must delegate to this method after performing 70 | their own initialization. 71 | 72 | Args: 73 | physics: An instance of `mujoco.Physics`. 74 | """ 75 | self.after_step(physics) 76 | 77 | def before_step(self, action, physics): 78 | """Sets the control signal for the actuators to values in `action`.""" 79 | # Support legacy internal code. 80 | action = getattr(action, "continuous_actions", action) 81 | physics.set_control(action) 82 | 83 | def after_step(self, physics): 84 | """Modifies colors according to the reward.""" 85 | if self._visualize_reward: 86 | reward = np.clip(self.get_reward(physics), 0.0, 1.0) 87 | _set_reward_colors(physics, reward) 88 | 89 | @property 90 | def visualize_reward(self): 91 | return self._visualize_reward 92 | 93 | @visualize_reward.setter 94 | def visualize_reward(self, value): 95 | if not isinstance(value, bool): 96 | raise ValueError("Expected a boolean, got {}.".format(type(value))) 97 | self._visualize_reward = value 98 | 99 | 100 | _MATERIALS = ["self", "effector", "target"] 101 | _DEFAULT = [name + "_default" for name in _MATERIALS] 102 | _HIGHLIGHT = [name + "_highlight" for name in _MATERIALS] 103 | 104 | 105 | def _set_reward_colors(physics, reward): 106 | """Sets the highlight, effector and target colors according to the reward.""" 107 | assert 0.0 <= reward <= 1.0 108 | colors = physics.named.model.mat_rgba 109 | default = colors[_DEFAULT] 110 | highlight = colors[_HIGHLIGHT] 111 | blend_coef = reward ** 4 # Better color distinction near high rewards. 112 | colors[_MATERIALS] = blend_coef * highlight + (1.0 - blend_coef) * default 113 | -------------------------------------------------------------------------------- /local_dm_control_suite/cartpole.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /local_dm_control_suite/cheetah.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | 16 | """Cheetah Domain.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import collections 23 | 24 | from dm_control import mujoco 25 | from dm_control.rl import control 26 | from local_dm_control_suite import base 27 | from local_dm_control_suite import common 28 | from dm_control.utils import containers 29 | from dm_control.utils import rewards 30 | 31 | 32 | # How long the simulation will run, in seconds. 33 | _DEFAULT_TIME_LIMIT = 10 34 | 35 | # Running speed above which reward is 1. 36 | _RUN_SPEED = 10 37 | 38 | SUITE = containers.TaggedTasks() 39 | 40 | 41 | def get_model_and_assets(): 42 | """Returns a tuple containing the model XML string and a dict of assets.""" 43 | return common.read_model('cheetah.xml'), common.ASSETS 44 | 45 | 46 | @SUITE.add('benchmarking') 47 | def run(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): 48 | """Returns the run task.""" 49 | physics = Physics.from_xml_string(*get_model_and_assets()) 50 | task = Cheetah(random=random) 51 | environment_kwargs = environment_kwargs or {} 52 | return control.Environment(physics, task, time_limit=time_limit, 53 | **environment_kwargs) 54 | 55 | 56 | class Physics(mujoco.Physics): 57 | """Physics simulation with additional features for the Cheetah domain.""" 58 | 59 | def speed(self): 60 | """Returns the horizontal speed of the Cheetah.""" 61 | return self.named.data.sensordata['torso_subtreelinvel'][0] 62 | 63 | 64 | class Cheetah(base.Task): 65 | """A `Task` to train a running Cheetah.""" 66 | 67 | def initialize_episode(self, physics): 68 | """Sets the state of the environment at the start of each episode.""" 69 | # The indexing below assumes that all joints have a single DOF. 70 | assert physics.model.nq == physics.model.njnt 71 | is_limited = physics.model.jnt_limited == 1 72 | lower, upper = physics.model.jnt_range[is_limited].T 73 | physics.data.qpos[is_limited] = self.random.uniform(lower, upper) 74 | 75 | # Stabilize the model before the actual simulation. 76 | for _ in range(200): 77 | physics.step() 78 | 79 | physics.data.time = 0 80 | self._timeout_progress = 0 81 | super(Cheetah, self).initialize_episode(physics) 82 | 83 | def get_observation(self, physics): 84 | """Returns an observation of the state, ignoring horizontal position.""" 85 | obs = collections.OrderedDict() 86 | # Ignores horizontal position to maintain translational invariance. 87 | obs['position'] = physics.data.qpos[1:].copy() 88 | obs['velocity'] = physics.velocity() 89 | return obs 90 | 91 | def get_reward(self, physics): 92 | """Returns a reward to the agent.""" 93 | return rewards.tolerance(physics.speed(), 94 | bounds=(_RUN_SPEED, float('inf')), 95 | margin=_RUN_SPEED, 96 | value_at_margin=0, 97 | sigmoid='linear') 98 | -------------------------------------------------------------------------------- /local_dm_control_suite/cheetah.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 74 | -------------------------------------------------------------------------------- /local_dm_control_suite/common/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Functions to manage the common assets for domains.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | from dm_control.utils import io as resources 24 | 25 | _SUITE_DIR = os.path.dirname(os.path.dirname(__file__)) 26 | _FILENAMES = [ 27 | "./common/materials.xml", 28 | "./common/materials_white_floor.xml", 29 | "./common/skybox.xml", 30 | "./common/visual.xml", 31 | ] 32 | 33 | ASSETS = {filename: resources.GetResource(os.path.join(_SUITE_DIR, filename)) 34 | for filename in _FILENAMES} 35 | 36 | 37 | def read_model(model_filename): 38 | """Reads a model XML file and returns its contents as a string.""" 39 | return resources.GetResource(os.path.join(_SUITE_DIR, model_filename)) 40 | -------------------------------------------------------------------------------- /local_dm_control_suite/common/materials.xml: -------------------------------------------------------------------------------- 1 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /local_dm_control_suite/common/materials_white_floor.xml: -------------------------------------------------------------------------------- 1 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /local_dm_control_suite/common/skybox.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /local_dm_control_suite/common/visual.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /local_dm_control_suite/demos/mocap_demo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | 16 | """Demonstration of amc parsing for CMU mocap database. 17 | 18 | To run the demo, supply a path to a `.amc` file: 19 | 20 | python mocap_demo --filename='path/to/mocap.amc' 21 | 22 | CMU motion capture clips are available at mocap.cs.cmu.edu 23 | """ 24 | 25 | from __future__ import absolute_import 26 | from __future__ import division 27 | from __future__ import print_function 28 | 29 | import time 30 | # Internal dependencies. 31 | 32 | from absl import app 33 | from absl import flags 34 | 35 | from local_dm_control_suite import humanoid_CMU 36 | from dm_control.suite.utils import parse_amc 37 | 38 | import matplotlib.pyplot as plt 39 | import numpy as np 40 | 41 | FLAGS = flags.FLAGS 42 | flags.DEFINE_string('filename', None, 'amc file to be converted.') 43 | flags.DEFINE_integer('max_num_frames', 90, 44 | 'Maximum number of frames for plotting/playback') 45 | 46 | 47 | def main(unused_argv): 48 | env = humanoid_CMU.stand() 49 | 50 | # Parse and convert specified clip. 51 | converted = parse_amc.convert(FLAGS.filename, 52 | env.physics, env.control_timestep()) 53 | 54 | max_frame = min(FLAGS.max_num_frames, converted.qpos.shape[1] - 1) 55 | 56 | width = 480 57 | height = 480 58 | video = np.zeros((max_frame, height, 2 * width, 3), dtype=np.uint8) 59 | 60 | for i in range(max_frame): 61 | p_i = converted.qpos[:, i] 62 | with env.physics.reset_context(): 63 | env.physics.data.qpos[:] = p_i 64 | video[i] = np.hstack([env.physics.render(height, width, camera_id=0), 65 | env.physics.render(height, width, camera_id=1)]) 66 | 67 | tic = time.time() 68 | for i in range(max_frame): 69 | if i == 0: 70 | img = plt.imshow(video[i]) 71 | else: 72 | img.set_data(video[i]) 73 | toc = time.time() 74 | clock_dt = toc - tic 75 | tic = time.time() 76 | # Real-time playback not always possible as clock_dt > .03 77 | plt.pause(max(0.01, 0.03 - clock_dt)) # Need min display time > 0.0. 
78 | plt.draw() 79 | plt.waitforbuttonpress() 80 | 81 | 82 | if __name__ == '__main__': 83 | flags.mark_flag_as_required('filename') 84 | app.run(main) 85 | -------------------------------------------------------------------------------- /local_dm_control_suite/demos/zeros.amc: -------------------------------------------------------------------------------- 1 | #DUMMY AMC for testing 2 | :FULLY-SPECIFIED 3 | :DEGREES 4 | 1 5 | root 0 0 0 0 0 0 6 | lowerback 0 0 0 7 | upperback 0 0 0 8 | thorax 0 0 0 9 | lowerneck 0 0 0 10 | upperneck 0 0 0 11 | head 0 0 0 12 | rclavicle 0 0 13 | rhumerus 0 0 0 14 | rradius 0 15 | rwrist 0 16 | rhand 0 0 17 | rfingers 0 18 | rthumb 0 0 19 | lclavicle 0 0 20 | lhumerus 0 0 0 21 | lradius 0 22 | lwrist 0 23 | lhand 0 0 24 | lfingers 0 25 | lthumb 0 0 26 | rfemur 0 0 0 27 | rtibia 0 28 | rfoot 0 0 29 | rtoes 0 30 | lfemur 0 0 0 31 | ltibia 0 32 | lfoot 0 0 33 | ltoes 0 34 | 2 35 | root 0 0 0 0 0 0 36 | lowerback 0 0 0 37 | upperback 0 0 0 38 | thorax 0 0 0 39 | lowerneck 0 0 0 40 | upperneck 0 0 0 41 | head 0 0 0 42 | rclavicle 0 0 43 | rhumerus 0 0 0 44 | rradius 0 45 | rwrist 0 46 | rhand 0 0 47 | rfingers 0 48 | rthumb 0 0 49 | lclavicle 0 0 50 | lhumerus 0 0 0 51 | lradius 0 52 | lwrist 0 53 | lhand 0 0 54 | lfingers 0 55 | lthumb 0 0 56 | rfemur 0 0 0 57 | rtibia 0 58 | rfoot 0 0 59 | rtoes 0 60 | lfemur 0 0 0 61 | ltibia 0 62 | lfoot 0 0 63 | ltoes 0 64 | 3 65 | root 0 0 0 0 0 0 66 | lowerback 0 0 0 67 | upperback 0 0 0 68 | thorax 0 0 0 69 | lowerneck 0 0 0 70 | upperneck 0 0 0 71 | head 0 0 0 72 | rclavicle 0 0 73 | rhumerus 0 0 0 74 | rradius 0 75 | rwrist 0 76 | rhand 0 0 77 | rfingers 0 78 | rthumb 0 0 79 | lclavicle 0 0 80 | lhumerus 0 0 0 81 | lradius 0 82 | lwrist 0 83 | lhand 0 0 84 | lfingers 0 85 | lthumb 0 0 86 | rfemur 0 0 0 87 | rtibia 0 88 | rfoot 0 0 89 | rtoes 0 90 | lfemur 0 0 0 91 | ltibia 0 92 | lfoot 0 0 93 | ltoes 0 94 | 4 95 | root 0 0 0 0 0 0 96 | lowerback 0 0 0 97 | upperback 0 0 0 98 | thorax 0 0 0 99 | lowerneck 0 0 0 100 | upperneck 0 0 0 101 | head 0 0 0 102 | rclavicle 0 0 103 | rhumerus 0 0 0 104 | rradius 0 105 | rwrist 0 106 | rhand 0 0 107 | rfingers 0 108 | rthumb 0 0 109 | lclavicle 0 0 110 | lhumerus 0 0 0 111 | lradius 0 112 | lwrist 0 113 | lhand 0 0 114 | lfingers 0 115 | lthumb 0 0 116 | rfemur 0 0 0 117 | rtibia 0 118 | rfoot 0 0 119 | rtoes 0 120 | lfemur 0 0 0 121 | ltibia 0 122 | lfoot 0 0 123 | ltoes 0 124 | 5 125 | root 0 0 0 0 0 0 126 | lowerback 0 0 0 127 | upperback 0 0 0 128 | thorax 0 0 0 129 | lowerneck 0 0 0 130 | upperneck 0 0 0 131 | head 0 0 0 132 | rclavicle 0 0 133 | rhumerus 0 0 0 134 | rradius 0 135 | rwrist 0 136 | rhand 0 0 137 | rfingers 0 138 | rthumb 0 0 139 | lclavicle 0 0 140 | lhumerus 0 0 0 141 | lradius 0 142 | lwrist 0 143 | lhand 0 0 144 | lfingers 0 145 | lthumb 0 0 146 | rfemur 0 0 0 147 | rtibia 0 148 | rfoot 0 0 149 | rtoes 0 150 | lfemur 0 0 0 151 | ltibia 0 152 | lfoot 0 0 153 | ltoes 0 154 | 6 155 | root 0 0 0 0 0 0 156 | lowerback 0 0 0 157 | upperback 0 0 0 158 | thorax 0 0 0 159 | lowerneck 0 0 0 160 | upperneck 0 0 0 161 | head 0 0 0 162 | rclavicle 0 0 163 | rhumerus 0 0 0 164 | rradius 0 165 | rwrist 0 166 | rhand 0 0 167 | rfingers 0 168 | rthumb 0 0 169 | lclavicle 0 0 170 | lhumerus 0 0 0 171 | lradius 0 172 | lwrist 0 173 | lhand 0 0 174 | lfingers 0 175 | lthumb 0 0 176 | rfemur 0 0 0 177 | rtibia 0 178 | rfoot 0 0 179 | rtoes 0 180 | lfemur 0 0 0 181 | ltibia 0 182 | lfoot 0 0 183 | ltoes 0 184 | 7 185 | root 0 0 0 0 0 0 186 | lowerback 0 0 0 187 | upperback 0 0 
0 188 | thorax 0 0 0 189 | lowerneck 0 0 0 190 | upperneck 0 0 0 191 | head 0 0 0 192 | rclavicle 0 0 193 | rhumerus 0 0 0 194 | rradius 0 195 | rwrist 0 196 | rhand 0 0 197 | rfingers 0 198 | rthumb 0 0 199 | lclavicle 0 0 200 | lhumerus 0 0 0 201 | lradius 0 202 | lwrist 0 203 | lhand 0 0 204 | lfingers 0 205 | lthumb 0 0 206 | rfemur 0 0 0 207 | rtibia 0 208 | rfoot 0 0 209 | rtoes 0 210 | lfemur 0 0 0 211 | ltibia 0 212 | lfoot 0 0 213 | ltoes 0 214 | -------------------------------------------------------------------------------- /local_dm_control_suite/explore.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Control suite environments explorer.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl import app 22 | from absl import flags 23 | from dm_control import suite 24 | from dm_control.suite.wrappers import action_noise 25 | from six.moves import input 26 | 27 | from dm_control import viewer 28 | 29 | 30 | _ALL_NAMES = ['.'.join(domain_task) for domain_task in suite.ALL_TASKS] 31 | 32 | flags.DEFINE_enum('environment_name', None, _ALL_NAMES, 33 | 'Optional \'domain_name.task_name\' pair specifying the ' 34 | 'environment to load. If unspecified a prompt will appear to ' 35 | 'select one.') 36 | flags.DEFINE_bool('timeout', True, 'Whether episodes should have a time limit.') 37 | flags.DEFINE_bool('visualize_reward', True, 38 | 'Whether to vary the colors of geoms according to the ' 39 | 'current reward value.') 40 | flags.DEFINE_float('action_noise', 0., 41 | 'Standard deviation of Gaussian noise to apply to actions, ' 42 | 'expressed as a fraction of the max-min range for each ' 43 | 'action dimension. Defaults to 0, i.e. no noise.') 44 | FLAGS = flags.FLAGS 45 | 46 | 47 | def prompt_environment_name(prompt, values): 48 | environment_name = None 49 | while not environment_name: 50 | environment_name = input(prompt) 51 | if not environment_name or values.index(environment_name) < 0: 52 | print('"%s" is not a valid environment name.' 
% environment_name) 53 | environment_name = None 54 | return environment_name 55 | 56 | 57 | def main(argv): 58 | del argv 59 | environment_name = FLAGS.environment_name 60 | if environment_name is None: 61 | print('\n '.join(['Available environments:'] + _ALL_NAMES)) 62 | environment_name = prompt_environment_name( 63 | 'Please select an environment name: ', _ALL_NAMES) 64 | 65 | index = _ALL_NAMES.index(environment_name) 66 | domain_name, task_name = suite.ALL_TASKS[index] 67 | 68 | task_kwargs = {} 69 | if not FLAGS.timeout: 70 | task_kwargs['time_limit'] = float('inf') 71 | 72 | def loader(): 73 | env = suite.load( 74 | domain_name=domain_name, task_name=task_name, task_kwargs=task_kwargs) 75 | env.task.visualize_reward = FLAGS.visualize_reward 76 | if FLAGS.action_noise > 0: 77 | env = action_noise.Wrapper(env, scale=FLAGS.action_noise) 78 | return env 79 | 80 | viewer.launch(loader) 81 | 82 | 83 | if __name__ == '__main__': 84 | app.run(main) 85 | -------------------------------------------------------------------------------- /local_dm_control_suite/finger.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /local_dm_control_suite/fish.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | -------------------------------------------------------------------------------- /local_dm_control_suite/hopper.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 67 | -------------------------------------------------------------------------------- /local_dm_control_suite/lqr.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /local_dm_control_suite/lqr_solver.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | 16 | r"""Optimal policy for LQR levels. 17 | 18 | LQR control problem is described in 19 | https://en.wikipedia.org/wiki/Linear-quadratic_regulator#Infinite-horizon.2C_discrete-time_LQR 20 | """ 21 | 22 | from __future__ import absolute_import 23 | from __future__ import division 24 | from __future__ import print_function 25 | 26 | from absl import logging 27 | from dm_control.mujoco import wrapper 28 | import numpy as np 29 | from six.moves import range 30 | 31 | try: 32 | import scipy.linalg as sp # pylint: disable=g-import-not-at-top 33 | except ImportError: 34 | sp = None 35 | 36 | 37 | def _solve_dare(a, b, q, r): 38 | """Solves the Discrete-time Algebraic Riccati Equation (DARE) by iteration. 39 | 40 | Algebraic Riccati Equation: 41 | ```none 42 | P_{t-1} = Q + A' * P_{t} * A - 43 | A' * P_{t} * B * (R + B' * P_{t} * B)^{-1} * B' * P_{t} * A 44 | ``` 45 | 46 | Args: 47 | a: A 2 dimensional numpy array, transition matrix A. 48 | b: A 2 dimensional numpy array, control matrix B. 49 | q: A 2 dimensional numpy array, symmetric positive definite cost matrix. 50 | r: A 2 dimensional numpy array, symmetric positive definite cost matrix 51 | 52 | Returns: 53 | A numpy array, a real symmetric matrix P which is the solution to DARE. 54 | 55 | Raises: 56 | RuntimeError: If the computed P matrix is not symmetric and 57 | positive-definite. 58 | """ 59 | p = np.eye(len(a)) 60 | for _ in range(1000000): 61 | a_p = a.T.dot(p) # A' * P_t 62 | a_p_b = np.dot(a_p, b) # A' * P_t * B 63 | # Algebraic Riccati Equation. 64 | p_next = q + np.dot(a_p, a) - a_p_b.dot( 65 | np.linalg.solve(b.T.dot(p.dot(b)) + r, a_p_b.T)) 66 | p_next += p_next.T 67 | p_next *= .5 68 | if np.abs(p - p_next).max() < 1e-12: 69 | break 70 | p = p_next 71 | else: 72 | logging.warning('DARE solver did not converge') 73 | try: 74 | # Check that the result is symmetric and positive-definite. 75 | np.linalg.cholesky(p_next) 76 | except np.linalg.LinAlgError: 77 | raise RuntimeError('ARE solver failed: P matrix is not symmetric and ' 78 | 'positive-definite.') 79 | return p_next 80 | 81 | 82 | def solve(env): 83 | """Returns the optimal value and policy for LQR problem. 84 | 85 | Args: 86 | env: An instance of `control.EnvironmentV2` with LQR level. 87 | 88 | Returns: 89 | p: A numpy array, the Hessian of the optimal total cost-to-go (value 90 | function at state x) is V(x) = .5 * x' * p * x. 91 | k: A numpy array which gives the optimal linear policy u = k * x. 92 | beta: The maximum eigenvalue of (a + b * k). Under optimal policy, at 93 | timestep n the state tends to 0 like beta^n. 94 | 95 | Raises: 96 | RuntimeError: If the controlled system is unstable. 97 | """ 98 | n = env.physics.model.nq # number of DoFs 99 | m = env.physics.model.nu # number of controls 100 | 101 | # Compute the mass matrix. 102 | mass = np.zeros((n, n)) 103 | wrapper.mjbindings.mjlib.mj_fullM(env.physics.model.ptr, mass, 104 | env.physics.data.qM) 105 | 106 | # Compute input matrices a, b, q and r to the DARE solvers. 107 | # State transition matrix a. 108 | stiffness = np.diag(env.physics.model.jnt_stiffness.ravel()) 109 | damping = np.diag(env.physics.model.dof_damping.ravel()) 110 | dt = env.physics.model.opt.timestep 111 | 112 | j = np.linalg.solve(-mass, np.hstack((stiffness, damping))) 113 | a = np.eye(2 * n) + dt * np.vstack( 114 | (dt * j + np.hstack((np.zeros((n, n)), np.eye(n))), j)) 115 | 116 | # Control transition matrix b. 
117 | b = env.physics.data.actuator_moment.T 118 | bc = np.linalg.solve(mass, b) 119 | b = dt * np.vstack((dt * bc, bc)) 120 | 121 | # State cost Hessian q. 122 | q = np.diag(np.hstack([np.ones(n), np.zeros(n)])) 123 | 124 | # Control cost Hessian r. 125 | r = env.task.control_cost_coef * np.eye(m) 126 | 127 | if sp: 128 | # Use scipy's faster DARE solver if available. 129 | solve_dare = sp.solve_discrete_are 130 | else: 131 | # Otherwise fall back on a slower internal implementation. 132 | solve_dare = _solve_dare 133 | 134 | # Solve the discrete algebraic Riccati equation. 135 | p = solve_dare(a, b, q, r) 136 | k = -np.linalg.solve(b.T.dot(p.dot(b)) + r, b.T.dot(p.dot(a))) 137 | 138 | # Under optimal policy, state tends to 0 like beta^n_timesteps 139 | beta = np.abs(np.linalg.eigvals(a + b.dot(k))).max() 140 | if beta >= 1.0: 141 | raise RuntimeError('Controlled system is unstable.') 142 | return p, k, beta 143 | -------------------------------------------------------------------------------- /local_dm_control_suite/pendulum.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Pendulum domain.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import collections 23 | 24 | from dm_control import mujoco 25 | from dm_control.rl import control 26 | from local_dm_control_suite import base 27 | from local_dm_control_suite import common 28 | from dm_control.utils import containers 29 | from dm_control.utils import rewards 30 | import numpy as np 31 | 32 | 33 | _DEFAULT_TIME_LIMIT = 20 34 | _ANGLE_BOUND = 8 35 | _COSINE_BOUND = np.cos(np.deg2rad(_ANGLE_BOUND)) 36 | SUITE = containers.TaggedTasks() 37 | 38 | 39 | def get_model_and_assets(): 40 | """Returns a tuple containing the model XML string and a dict of assets.""" 41 | return common.read_model('pendulum.xml'), common.ASSETS 42 | 43 | 44 | @SUITE.add('benchmarking') 45 | def swingup(time_limit=_DEFAULT_TIME_LIMIT, random=None, 46 | environment_kwargs=None): 47 | """Returns pendulum swingup task .""" 48 | physics = Physics.from_xml_string(*get_model_and_assets()) 49 | task = SwingUp(random=random) 50 | environment_kwargs = environment_kwargs or {} 51 | return control.Environment( 52 | physics, task, time_limit=time_limit, **environment_kwargs) 53 | 54 | 55 | class Physics(mujoco.Physics): 56 | """Physics simulation with additional features for the Pendulum domain.""" 57 | 58 | def pole_vertical(self): 59 | """Returns vertical (z) component of pole frame.""" 60 | return self.named.data.xmat['pole', 'zz'] 61 | 62 | def angular_velocity(self): 63 | """Returns the angular velocity of the pole.""" 64 | return self.named.data.qvel['hinge'].copy() 65 | 66 | def pole_orientation(self): 67 | """Returns both 
horizontal and vertical components of pole frame.""" 68 | return self.named.data.xmat['pole', ['zz', 'xz']] 69 | 70 | 71 | class SwingUp(base.Task): 72 | """A Pendulum `Task` to swing up and balance the pole.""" 73 | 74 | def __init__(self, random=None): 75 | """Initialize an instance of `Pendulum`. 76 | 77 | Args: 78 | random: Optional, either a `numpy.random.RandomState` instance, an 79 | integer seed for creating a new `RandomState`, or None to select a seed 80 | automatically (default). 81 | """ 82 | super(SwingUp, self).__init__(random=random) 83 | 84 | def initialize_episode(self, physics): 85 | """Sets the state of the environment at the start of each episode. 86 | 87 | Pole is set to a random angle between [-pi, pi). 88 | 89 | Args: 90 | physics: An instance of `Physics`. 91 | 92 | """ 93 | physics.named.data.qpos['hinge'] = self.random.uniform(-np.pi, np.pi) 94 | super(SwingUp, self).initialize_episode(physics) 95 | 96 | def get_observation(self, physics): 97 | """Returns an observation. 98 | 99 | Observations are states concatenating pole orientation and angular velocity 100 | and pixels from fixed camera. 101 | 102 | Args: 103 | physics: An instance of `physics`, Pendulum physics. 104 | 105 | Returns: 106 | A `dict` of observation. 107 | """ 108 | obs = collections.OrderedDict() 109 | obs['orientation'] = physics.pole_orientation() 110 | obs['velocity'] = physics.angular_velocity() 111 | return obs 112 | 113 | def get_reward(self, physics): 114 | return rewards.tolerance(physics.pole_vertical(), (_COSINE_BOUND, 1)) 115 | -------------------------------------------------------------------------------- /local_dm_control_suite/pendulum.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /local_dm_control_suite/point_mass.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /local_dm_control_suite/reacher.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | 16 | """Reacher domain.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import collections 23 | 24 | from dm_control import mujoco 25 | from dm_control.rl import control 26 | from local_dm_control_suite import base 27 | from local_dm_control_suite import common 28 | from dm_control.suite.utils import randomizers 29 | from dm_control.utils import containers 30 | from dm_control.utils import rewards 31 | import numpy as np 32 | 33 | SUITE = containers.TaggedTasks() 34 | _DEFAULT_TIME_LIMIT = 20 35 | _BIG_TARGET = .05 36 | _SMALL_TARGET = .015 37 | 38 | 39 | def get_model_and_assets(): 40 | """Returns a tuple containing the model XML string and a dict of assets.""" 41 | return common.read_model('reacher.xml'), common.ASSETS 42 | 43 | 44 | @SUITE.add('benchmarking', 'easy') 45 | def easy(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): 46 | """Returns reacher with sparse reward with 5e-2 tol and randomized target.""" 47 | physics = Physics.from_xml_string(*get_model_and_assets()) 48 | task = Reacher(target_size=_BIG_TARGET, random=random) 49 | environment_kwargs = environment_kwargs or {} 50 | return control.Environment( 51 | physics, task, time_limit=time_limit, **environment_kwargs) 52 | 53 | 54 | @SUITE.add('benchmarking') 55 | def hard(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): 56 | """Returns reacher with sparse reward with 1e-2 tol and randomized target.""" 57 | physics = Physics.from_xml_string(*get_model_and_assets()) 58 | task = Reacher(target_size=_SMALL_TARGET, random=random) 59 | environment_kwargs = environment_kwargs or {} 60 | return control.Environment( 61 | physics, task, time_limit=time_limit, **environment_kwargs) 62 | 63 | 64 | class Physics(mujoco.Physics): 65 | """Physics simulation with additional features for the Reacher domain.""" 66 | 67 | def finger_to_target(self): 68 | """Returns the vector from target to finger in global coordinates.""" 69 | return (self.named.data.geom_xpos['target', :2] - 70 | self.named.data.geom_xpos['finger', :2]) 71 | 72 | def finger_to_target_dist(self): 73 | """Returns the signed distance between the finger and target surface.""" 74 | return np.linalg.norm(self.finger_to_target()) 75 | 76 | 77 | class Reacher(base.Task): 78 | """A reacher `Task` to reach the target.""" 79 | 80 | def __init__(self, target_size, random=None): 81 | """Initialize an instance of `Reacher`. 82 | 83 | Args: 84 | target_size: A `float`, tolerance to determine whether finger reached the 85 | target. 86 | random: Optional, either a `numpy.random.RandomState` instance, an 87 | integer seed for creating a new `RandomState`, or None to select a seed 88 | automatically (default). 
89 | """ 90 | self._target_size = target_size 91 | super(Reacher, self).__init__(random=random) 92 | 93 | def initialize_episode(self, physics): 94 | """Sets the state of the environment at the start of each episode.""" 95 | physics.named.model.geom_size['target', 0] = self._target_size 96 | randomizers.randomize_limited_and_rotational_joints(physics, self.random) 97 | 98 | # Randomize target position 99 | angle = self.random.uniform(0, 2 * np.pi) 100 | radius = self.random.uniform(.05, .20) 101 | physics.named.model.geom_pos['target', 'x'] = radius * np.sin(angle) 102 | physics.named.model.geom_pos['target', 'y'] = radius * np.cos(angle) 103 | 104 | super(Reacher, self).initialize_episode(physics) 105 | 106 | def get_observation(self, physics): 107 | """Returns an observation of the state and the target position.""" 108 | obs = collections.OrderedDict() 109 | obs['position'] = physics.position() 110 | obs['to_target'] = physics.finger_to_target() 111 | obs['velocity'] = physics.velocity() 112 | return obs 113 | 114 | def get_reward(self, physics): 115 | radii = physics.named.model.geom_size[['target', 'finger'], 0].sum() 116 | return rewards.tolerance(physics.finger_to_target_dist(), (0, radii)) 117 | -------------------------------------------------------------------------------- /local_dm_control_suite/reacher.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /local_dm_control_suite/swimmer.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /local_dm_control_suite/tests/loader_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Tests for the dm_control.suite loader.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | # Internal dependencies. 
23 | 24 | from absl.testing import absltest 25 | 26 | from dm_control import suite 27 | from dm_control.rl import control 28 | 29 | 30 | class LoaderTest(absltest.TestCase): 31 | 32 | def test_load_without_kwargs(self): 33 | env = suite.load('cartpole', 'swingup') 34 | self.assertIsInstance(env, control.Environment) 35 | 36 | def test_load_with_kwargs(self): 37 | env = suite.load('cartpole', 'swingup', 38 | task_kwargs={'time_limit': 40, 'random': 99}) 39 | self.assertIsInstance(env, control.Environment) 40 | 41 | 42 | class LoaderConstantsTest(absltest.TestCase): 43 | 44 | def testSuiteConstants(self): 45 | self.assertNotEmpty(suite.BENCHMARKING) 46 | self.assertNotEmpty(suite.EASY) 47 | self.assertNotEmpty(suite.HARD) 48 | self.assertNotEmpty(suite.EXTRA) 49 | 50 | 51 | if __name__ == '__main__': 52 | absltest.main() 53 | -------------------------------------------------------------------------------- /local_dm_control_suite/tests/lqr_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Tests specific to the LQR domain.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import math 23 | import unittest 24 | 25 | # Internal dependencies. 26 | from absl import logging 27 | 28 | from absl.testing import absltest 29 | from absl.testing import parameterized 30 | 31 | from local_dm_control_suite import lqr 32 | from local_dm_control_suite import lqr_solver 33 | 34 | import numpy as np 35 | from six.moves import range 36 | 37 | 38 | class LqrTest(parameterized.TestCase): 39 | 40 | @parameterized.named_parameters( 41 | ('lqr_2_1', lqr.lqr_2_1), 42 | ('lqr_6_2', lqr.lqr_6_2)) 43 | def test_lqr_optimal_policy(self, make_env): 44 | env = make_env() 45 | p, k, beta = lqr_solver.solve(env) 46 | self.assertPolicyisOptimal(env, p, k, beta) 47 | 48 | @parameterized.named_parameters( 49 | ('lqr_2_1', lqr.lqr_2_1), 50 | ('lqr_6_2', lqr.lqr_6_2)) 51 | @unittest.skipUnless( 52 | condition=lqr_solver.sp, 53 | reason='scipy is not available, so non-scipy DARE solver is the default.') 54 | def test_lqr_optimal_policy_no_scipy(self, make_env): 55 | env = make_env() 56 | old_sp = lqr_solver.sp 57 | try: 58 | lqr_solver.sp = None # Force the solver to use the non-scipy code path. 
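# With `lqr_solver.sp` monkey-patched to None, `solve` takes the fallback
# branch and uses the iterative `_solve_dare` implementation instead of
# `scipy.linalg.solve_discrete_are`.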
59 | p, k, beta = lqr_solver.solve(env) 60 | finally: 61 | lqr_solver.sp = old_sp 62 | self.assertPolicyisOptimal(env, p, k, beta) 63 | 64 | def assertPolicyisOptimal(self, env, p, k, beta): 65 | tolerance = 1e-3 66 | n_steps = int(math.ceil(math.log10(tolerance) / math.log10(beta))) 67 | logging.info('%d timesteps for %g convergence.', n_steps, tolerance) 68 | total_loss = 0.0 69 | 70 | timestep = env.reset() 71 | initial_state = np.hstack((timestep.observation['position'], 72 | timestep.observation['velocity'])) 73 | logging.info('Measuring total cost over %d steps.', n_steps) 74 | for _ in range(n_steps): 75 | x = np.hstack((timestep.observation['position'], 76 | timestep.observation['velocity'])) 77 | # u = k*x is the optimal policy 78 | u = k.dot(x) 79 | total_loss += 1 - (timestep.reward or 0.0) 80 | timestep = env.step(u) 81 | 82 | logging.info('Analytical expected total cost is .5*x^T*p*x.') 83 | expected_loss = .5 * initial_state.T.dot(p).dot(initial_state) 84 | logging.info('Comparing measured and predicted costs.') 85 | np.testing.assert_allclose(expected_loss, total_loss, rtol=tolerance) 86 | 87 | if __name__ == '__main__': 88 | absltest.main() 89 | -------------------------------------------------------------------------------- /local_dm_control_suite/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Utility functions used in the control suite.""" 17 | -------------------------------------------------------------------------------- /local_dm_control_suite/utils/parse_amc_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Tests for parse_amc utility.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | 24 | # Internal dependencies. 
25 | 26 | from absl.testing import absltest 27 | from local_dm_control_suite import humanoid_CMU 28 | from dm_control.suite.utils import parse_amc 29 | 30 | from dm_control.utils import io as resources 31 | 32 | _TEST_AMC_PATH = resources.GetResourceFilename( 33 | os.path.join(os.path.dirname(__file__), '../demos/zeros.amc')) 34 | 35 | 36 | class ParseAMCTest(absltest.TestCase): 37 | 38 | def test_sizes_of_parsed_data(self): 39 | 40 | # Instantiate the humanoid environment. 41 | env = humanoid_CMU.stand() 42 | 43 | # Parse and convert specified clip. 44 | converted = parse_amc.convert( 45 | _TEST_AMC_PATH, env.physics, env.control_timestep()) 46 | 47 | self.assertEqual(converted.qpos.shape[0], 63) 48 | self.assertEqual(converted.qvel.shape[0], 62) 49 | self.assertEqual(converted.time.shape[0], converted.qpos.shape[1]) 50 | self.assertEqual(converted.qpos.shape[1], 51 | converted.qvel.shape[1] + 1) 52 | 53 | # Parse and convert specified clip -- WITH SMALLER TIMESTEP 54 | converted2 = parse_amc.convert( 55 | _TEST_AMC_PATH, env.physics, 0.5 * env.control_timestep()) 56 | 57 | self.assertEqual(converted2.qpos.shape[0], 63) 58 | self.assertEqual(converted2.qvel.shape[0], 62) 59 | self.assertEqual(converted2.time.shape[0], converted2.qpos.shape[1]) 60 | self.assertEqual(converted.qpos.shape[1], 61 | converted.qvel.shape[1] + 1) 62 | 63 | # Compare sizes of parsed objects for different timesteps 64 | self.assertEqual(converted.qpos.shape[1] * 2, converted2.qpos.shape[1]) 65 | 66 | 67 | if __name__ == '__main__': 68 | absltest.main() 69 | -------------------------------------------------------------------------------- /local_dm_control_suite/utils/randomizers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Randomization functions.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from dm_control.mujoco.wrapper import mjbindings 23 | import numpy as np 24 | from six.moves import range 25 | 26 | 27 | def random_limited_quaternion(random, limit): 28 | """Generates a random quaternion limited to the specified rotations.""" 29 | axis = random.randn(3) 30 | axis /= np.linalg.norm(axis) 31 | angle = random.rand() * limit 32 | 33 | quaternion = np.zeros(4) 34 | mjbindings.mjlib.mju_axisAngle2Quat(quaternion, axis, angle) 35 | 36 | return quaternion 37 | 38 | 39 | def randomize_limited_and_rotational_joints(physics, random=None): 40 | """Randomizes the positions of joints defined in the physics body. 41 | 42 | The following randomization rules apply: 43 | - Bounded joints (hinges or sliders) are sampled uniformly in the bounds. 
44 | - Unbounded hinges are samples uniformly in [-pi, pi] 45 | - Quaternions for unlimited free joints and ball joints are sampled 46 | uniformly on the unit 3-sphere. 47 | - Quaternions for limited ball joints are sampled uniformly on a sector 48 | of the unit 3-sphere. 49 | - The linear degrees of freedom of free joints are not randomized. 50 | 51 | Args: 52 | physics: Instance of 'Physics' class that holds a loaded model. 53 | random: Optional instance of 'np.random.RandomState'. Defaults to the global 54 | NumPy random state. 55 | """ 56 | random = random or np.random 57 | 58 | hinge = mjbindings.enums.mjtJoint.mjJNT_HINGE 59 | slide = mjbindings.enums.mjtJoint.mjJNT_SLIDE 60 | ball = mjbindings.enums.mjtJoint.mjJNT_BALL 61 | free = mjbindings.enums.mjtJoint.mjJNT_FREE 62 | 63 | qpos = physics.named.data.qpos 64 | 65 | for joint_id in range(physics.model.njnt): 66 | joint_name = physics.model.id2name(joint_id, 'joint') 67 | joint_type = physics.model.jnt_type[joint_id] 68 | is_limited = physics.model.jnt_limited[joint_id] 69 | range_min, range_max = physics.model.jnt_range[joint_id] 70 | 71 | if is_limited: 72 | if joint_type == hinge or joint_type == slide: 73 | qpos[joint_name] = random.uniform(range_min, range_max) 74 | 75 | elif joint_type == ball: 76 | qpos[joint_name] = random_limited_quaternion(random, range_max) 77 | 78 | else: 79 | if joint_type == hinge: 80 | qpos[joint_name] = random.uniform(-np.pi, np.pi) 81 | 82 | elif joint_type == ball: 83 | quat = random.randn(4) 84 | quat /= np.linalg.norm(quat) 85 | qpos[joint_name] = quat 86 | 87 | elif joint_type == free: 88 | quat = random.rand(4) 89 | quat /= np.linalg.norm(quat) 90 | qpos[joint_name][3:] = quat 91 | 92 | -------------------------------------------------------------------------------- /local_dm_control_suite/walker.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 71 | -------------------------------------------------------------------------------- /local_dm_control_suite/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Environment wrappers used to extend or modify environment behaviour.""" 17 | -------------------------------------------------------------------------------- /local_dm_control_suite/wrappers/action_noise.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Wrapper control suite environments that adds Gaussian noise to actions.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import dm_env 23 | import numpy as np 24 | 25 | 26 | _BOUNDS_MUST_BE_FINITE = ( 27 | 'All bounds in `env.action_spec()` must be finite, got: {action_spec}') 28 | 29 | 30 | class Wrapper(dm_env.Environment): 31 | """Wraps a control environment and adds Gaussian noise to actions.""" 32 | 33 | def __init__(self, env, scale=0.01): 34 | """Initializes a new action noise Wrapper. 35 | 36 | Args: 37 | env: The control suite environment to wrap. 38 | scale: The standard deviation of the noise, expressed as a fraction 39 | of the max-min range for each action dimension. 40 | 41 | Raises: 42 | ValueError: If any of the action dimensions of the wrapped environment are 43 | unbounded. 44 | """ 45 | action_spec = env.action_spec() 46 | if not (np.all(np.isfinite(action_spec.minimum)) and 47 | np.all(np.isfinite(action_spec.maximum))): 48 | raise ValueError(_BOUNDS_MUST_BE_FINITE.format(action_spec=action_spec)) 49 | self._minimum = action_spec.minimum 50 | self._maximum = action_spec.maximum 51 | self._noise_std = scale * (action_spec.maximum - action_spec.minimum) 52 | self._env = env 53 | 54 | def step(self, action): 55 | noisy_action = action + self._env.task.random.normal(scale=self._noise_std) 56 | # Clip the noisy actions in place so that they fall within the bounds 57 | # specified by the `action_spec`. Note that MuJoCo implicitly clips out-of- 58 | # bounds control inputs, but we also clip here in case the actions do not 59 | # correspond directly to MuJoCo actuators, or if there are other wrapper 60 | # layers that expect the actions to be within bounds. 61 | np.clip(noisy_action, self._minimum, self._maximum, out=noisy_action) 62 | return self._env.step(noisy_action) 63 | 64 | def reset(self): 65 | return self._env.reset() 66 | 67 | def observation_spec(self): 68 | return self._env.observation_spec() 69 | 70 | def action_spec(self): 71 | return self._env.action_spec() 72 | 73 | def __getattr__(self, name): 74 | return getattr(self._env, name) 75 | -------------------------------------------------------------------------------- /local_dm_control_suite/wrappers/pixels.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Wrapper that adds pixel observations to a control environment.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import collections 23 | 24 | import dm_env 25 | from dm_env import specs 26 | 27 | STATE_KEY = 'state' 28 | 29 | 30 | class Wrapper(dm_env.Environment): 31 | """Wraps a control environment and adds a rendered pixel observation.""" 32 | 33 | def __init__(self, env, pixels_only=True, render_kwargs=None, 34 | observation_key='pixels'): 35 | """Initializes a new pixel Wrapper. 36 | 37 | Args: 38 | env: The environment to wrap. 39 | pixels_only: If True (default), the original set of 'state' observations 40 | returned by the wrapped environment will be discarded, and the 41 | `OrderedDict` of observations will only contain pixels. If False, the 42 | `OrderedDict` will contain the original observations as well as the 43 | pixel observations. 44 | render_kwargs: Optional `dict` containing keyword arguments passed to the 45 | `mujoco.Physics.render` method. 46 | observation_key: Optional custom string specifying the pixel observation's 47 | key in the `OrderedDict` of observations. Defaults to 'pixels'. 48 | 49 | Raises: 50 | ValueError: If `env`'s observation spec is not compatible with the 51 | wrapper. Supported formats are a single array, or a dict of arrays. 52 | ValueError: If `env`'s observation already contains the specified 53 | `observation_key`. 54 | """ 55 | if render_kwargs is None: 56 | render_kwargs = {} 57 | 58 | wrapped_observation_spec = env.observation_spec() 59 | 60 | if isinstance(wrapped_observation_spec, specs.Array): 61 | self._observation_is_dict = False 62 | invalid_keys = set([STATE_KEY]) 63 | elif isinstance(wrapped_observation_spec, collections.MutableMapping): 64 | self._observation_is_dict = True 65 | invalid_keys = set(wrapped_observation_spec.keys()) 66 | else: 67 | raise ValueError('Unsupported observation spec structure.') 68 | 69 | if not pixels_only and observation_key in invalid_keys: 70 | raise ValueError('Duplicate or reserved observation key {!r}.' 71 | .format(observation_key)) 72 | 73 | if pixels_only: 74 | self._observation_spec = collections.OrderedDict() 75 | elif self._observation_is_dict: 76 | self._observation_spec = wrapped_observation_spec.copy() 77 | else: 78 | self._observation_spec = collections.OrderedDict() 79 | self._observation_spec[STATE_KEY] = wrapped_observation_spec 80 | 81 | # Extend observation spec. 
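# A single render is performed here only to discover the shape and dtype of
# the pixel observation so that its spec can be added; the per-step pixels are
# rendered again in `_add_pixel_observation` below.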
82 | pixels = env.physics.render(**render_kwargs) 83 | pixels_spec = specs.Array( 84 | shape=pixels.shape, dtype=pixels.dtype, name=observation_key) 85 | self._observation_spec[observation_key] = pixels_spec 86 | 87 | self._env = env 88 | self._pixels_only = pixels_only 89 | self._render_kwargs = render_kwargs 90 | self._observation_key = observation_key 91 | 92 | def reset(self): 93 | time_step = self._env.reset() 94 | return self._add_pixel_observation(time_step) 95 | 96 | def step(self, action): 97 | time_step = self._env.step(action) 98 | return self._add_pixel_observation(time_step) 99 | 100 | def observation_spec(self): 101 | return self._observation_spec 102 | 103 | def action_spec(self): 104 | return self._env.action_spec() 105 | 106 | def _add_pixel_observation(self, time_step): 107 | if self._pixels_only: 108 | observation = collections.OrderedDict() 109 | elif self._observation_is_dict: 110 | observation = type(time_step.observation)(time_step.observation) 111 | else: 112 | observation = collections.OrderedDict() 113 | observation[STATE_KEY] = time_step.observation 114 | 115 | pixels = self._env.physics.render(**self._render_kwargs) 116 | observation[self._observation_key] = pixels 117 | return time_step._replace(observation=observation) 118 | 119 | def __getattr__(self, name): 120 | return getattr(self._env, name) 121 | -------------------------------------------------------------------------------- /local_dm_control_suite/wrappers/pixels_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Tests for the pixel wrapper.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import collections 23 | 24 | # Internal dependencies. 
25 | from absl.testing import absltest 26 | from absl.testing import parameterized 27 | from local_dm_control_suite import cartpole 28 | from dm_control.suite.wrappers import pixels 29 | import dm_env 30 | from dm_env import specs 31 | 32 | import numpy as np 33 | 34 | 35 | class FakePhysics(object): 36 | 37 | def render(self, *args, **kwargs): 38 | del args 39 | del kwargs 40 | return np.zeros((4, 5, 3), dtype=np.uint8) 41 | 42 | 43 | class FakeArrayObservationEnvironment(dm_env.Environment): 44 | 45 | def __init__(self): 46 | self.physics = FakePhysics() 47 | 48 | def reset(self): 49 | return dm_env.restart(np.zeros((2,))) 50 | 51 | def step(self, action): 52 | del action 53 | return dm_env.transition(0.0, np.zeros((2,))) 54 | 55 | def action_spec(self): 56 | pass 57 | 58 | def observation_spec(self): 59 | return specs.Array(shape=(2,), dtype=np.float) 60 | 61 | 62 | class PixelsTest(parameterized.TestCase): 63 | 64 | @parameterized.parameters(True, False) 65 | def test_dict_observation(self, pixels_only): 66 | pixel_key = 'rgb' 67 | 68 | env = cartpole.swingup() 69 | 70 | # Make sure we are testing the right environment for the test. 71 | observation_spec = env.observation_spec() 72 | self.assertIsInstance(observation_spec, collections.OrderedDict) 73 | 74 | width = 320 75 | height = 240 76 | 77 | # The wrapper should only add one observation. 78 | wrapped = pixels.Wrapper(env, 79 | observation_key=pixel_key, 80 | pixels_only=pixels_only, 81 | render_kwargs={'width': width, 'height': height}) 82 | 83 | wrapped_observation_spec = wrapped.observation_spec() 84 | self.assertIsInstance(wrapped_observation_spec, collections.OrderedDict) 85 | 86 | if pixels_only: 87 | self.assertLen(wrapped_observation_spec, 1) 88 | self.assertEqual([pixel_key], list(wrapped_observation_spec.keys())) 89 | else: 90 | expected_length = len(observation_spec) + 1 91 | self.assertLen(wrapped_observation_spec, expected_length) 92 | expected_keys = list(observation_spec.keys()) + [pixel_key] 93 | self.assertEqual(expected_keys, list(wrapped_observation_spec.keys())) 94 | 95 | # Check that the added spec item is consistent with the added observation. 
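# `validate` raises if the rendered array's shape or dtype does not match the
# spec that the wrapper constructed from its initial render.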
96 | time_step = wrapped.reset() 97 | rgb_observation = time_step.observation[pixel_key] 98 | wrapped_observation_spec[pixel_key].validate(rgb_observation) 99 | 100 | self.assertEqual(rgb_observation.shape, (height, width, 3)) 101 | self.assertEqual(rgb_observation.dtype, np.uint8) 102 | 103 | @parameterized.parameters(True, False) 104 | def test_single_array_observation(self, pixels_only): 105 | pixel_key = 'depth' 106 | 107 | env = FakeArrayObservationEnvironment() 108 | observation_spec = env.observation_spec() 109 | self.assertIsInstance(observation_spec, specs.Array) 110 | 111 | wrapped = pixels.Wrapper(env, observation_key=pixel_key, 112 | pixels_only=pixels_only) 113 | wrapped_observation_spec = wrapped.observation_spec() 114 | self.assertIsInstance(wrapped_observation_spec, collections.OrderedDict) 115 | 116 | if pixels_only: 117 | self.assertLen(wrapped_observation_spec, 1) 118 | self.assertEqual([pixel_key], list(wrapped_observation_spec.keys())) 119 | else: 120 | self.assertLen(wrapped_observation_spec, 2) 121 | self.assertEqual([pixels.STATE_KEY, pixel_key], 122 | list(wrapped_observation_spec.keys())) 123 | 124 | time_step = wrapped.reset() 125 | 126 | depth_observation = time_step.observation[pixel_key] 127 | wrapped_observation_spec[pixel_key].validate(depth_observation) 128 | 129 | self.assertEqual(depth_observation.shape, (4, 5, 3)) 130 | self.assertEqual(depth_observation.dtype, np.uint8) 131 | 132 | if __name__ == '__main__': 133 | absltest.main() 134 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/README.md: -------------------------------------------------------------------------------- 1 | # DeepMind Control Suite. 2 | 3 | This submodule contains the domains and tasks described in the 4 | [DeepMind Control Suite tech report](https://arxiv.org/abs/1801.00690). 5 | 6 | ## Quickstart 7 | 8 | ```python 9 | from dm_control import suite 10 | import numpy as np 11 | 12 | # Load one task: 13 | env = suite.load(domain_name="cartpole", task_name="swingup") 14 | 15 | # Iterate over a task set: 16 | for domain_name, task_name in suite.BENCHMARKING: 17 | env = suite.load(domain_name, task_name) 18 | 19 | # Step through an episode and print out reward, discount and observation. 20 | action_spec = env.action_spec() 21 | time_step = env.reset() 22 | while not time_step.last(): 23 | action = np.random.uniform(action_spec.minimum, 24 | action_spec.maximum, 25 | size=action_spec.shape) 26 | time_step = env.step(action) 27 | print(time_step.reward, time_step.discount, time_step.observation) 28 | ``` 29 | 30 | ## Illustration video 31 | 32 | Below is a video montage of solved Control Suite tasks, with reward 33 | visualisation enabled. 34 | 35 | [![Video montage](https://img.youtube.com/vi/rAai4QzcYbs/0.jpg)](https://www.youtube.com/watch?v=rAai4QzcYbs) 36 | 37 | 38 | ### Quadruped domain [April 2019] 39 | 40 | Roughly based on the 'ant' model introduced by [Schulman et al. 2015](https://arxiv.org/abs/1506.02438). Main modifications to the body are: 41 | 42 | - 4 DoFs per leg, 1 constraining tendon. 43 | - 3 actuators per leg: 'yaw', 'lift', 'extend'. 44 | - Filtered position actuators with timescale of 100ms. 45 | - Sensors include an IMU, force/torque sensors, and rangefinders. 46 | 47 | Four tasks: 48 | 49 | - `walk` and `run`: self-right the body then move forward at a desired speed. 50 | - `escape`: escape a bowl-shaped random terrain (uses rangefinders). 
51 | - `fetch`, go to a moving ball and bring it to a target. 52 | 53 | All behaviors in the video below were trained with [Abdolmaleki et al's 54 | MPO](https://arxiv.org/abs/1806.06920). 55 | 56 | [![Video montage](https://img.youtube.com/vi/RhRLjbb7pBE/0.jpg)](https://www.youtube.com/watch?v=RhRLjbb7pBE) 57 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/acrobot.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Acrobot domain.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import collections 23 | 24 | from dm_control import mujoco 25 | from dm_control.rl import control 26 | from local_dm_control_suite import base 27 | from local_dm_control_suite import common 28 | from dm_control.utils import containers 29 | from dm_control.utils import rewards 30 | import numpy as np 31 | 32 | _DEFAULT_TIME_LIMIT = 10 33 | SUITE = containers.TaggedTasks() 34 | 35 | 36 | def get_model_and_assets(): 37 | """Returns a tuple containing the model XML string and a dict of assets.""" 38 | return common.read_model('acrobot.xml'), common.ASSETS 39 | 40 | 41 | @SUITE.add('benchmarking') 42 | def swingup(time_limit=_DEFAULT_TIME_LIMIT, random=None, 43 | environment_kwargs=None): 44 | """Returns Acrobot balance task.""" 45 | physics = Physics.from_xml_string(*get_model_and_assets()) 46 | task = Balance(sparse=False, random=random) 47 | environment_kwargs = environment_kwargs or {} 48 | return control.Environment( 49 | physics, task, time_limit=time_limit, **environment_kwargs) 50 | 51 | 52 | @SUITE.add('benchmarking') 53 | def swingup_sparse(time_limit=_DEFAULT_TIME_LIMIT, random=None, 54 | environment_kwargs=None): 55 | """Returns Acrobot sparse balance.""" 56 | physics = Physics.from_xml_string(*get_model_and_assets()) 57 | task = Balance(sparse=True, random=random) 58 | environment_kwargs = environment_kwargs or {} 59 | return control.Environment( 60 | physics, task, time_limit=time_limit, **environment_kwargs) 61 | 62 | 63 | class Physics(mujoco.Physics): 64 | """Physics simulation with additional features for the Acrobot domain.""" 65 | 66 | def horizontal(self): 67 | """Returns horizontal (x) component of body frame z-axes.""" 68 | return self.named.data.xmat[['upper_arm', 'lower_arm'], 'xz'] 69 | 70 | def vertical(self): 71 | """Returns vertical (z) component of body frame z-axes.""" 72 | return self.named.data.xmat[['upper_arm', 'lower_arm'], 'zz'] 73 | 74 | def to_target(self): 75 | """Returns the distance from the tip to the target.""" 76 | tip_to_target = (self.named.data.site_xpos['target'] - 77 | self.named.data.site_xpos['tip']) 78 | return np.linalg.norm(tip_to_target) 79 | 80 | 
def orientations(self): 81 | """Returns the sines and cosines of the pole angles.""" 82 | return np.concatenate((self.horizontal(), self.vertical())) 83 | 84 | 85 | class Balance(base.Task): 86 | """An Acrobot `Task` to swing up and balance the pole.""" 87 | 88 | def __init__(self, sparse, random=None): 89 | """Initializes an instance of `Balance`. 90 | 91 | Args: 92 | sparse: A `bool` specifying whether to use a sparse (indicator) reward. 93 | random: Optional, either a `numpy.random.RandomState` instance, an 94 | integer seed for creating a new `RandomState`, or None to select a seed 95 | automatically (default). 96 | """ 97 | self._sparse = sparse 98 | super(Balance, self).__init__(random=random) 99 | 100 | def initialize_episode(self, physics): 101 | """Sets the state of the environment at the start of each episode. 102 | 103 | Shoulder and elbow are set to a random position between [-pi, pi). 104 | 105 | Args: 106 | physics: An instance of `Physics`. 107 | """ 108 | physics.named.data.qpos[ 109 | ['shoulder', 'elbow']] = self.random.uniform(-np.pi, np.pi, 2) 110 | super(Balance, self).initialize_episode(physics) 111 | 112 | def get_observation(self, physics): 113 | """Returns an observation of pole orientation and angular velocities.""" 114 | obs = collections.OrderedDict() 115 | obs['orientations'] = physics.orientations() 116 | obs['velocity'] = physics.velocity() 117 | return obs 118 | 119 | def _get_reward(self, physics, sparse): 120 | target_radius = physics.named.model.site_size['target', 0] 121 | return rewards.tolerance(physics.to_target(), 122 | bounds=(0, target_radius), 123 | margin=0 if sparse else 1) 124 | 125 | def get_reward(self, physics): 126 | """Returns a sparse or a smooth reward, as specified in the constructor.""" 127 | return self._get_reward(physics, sparse=self._sparse) 128 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/acrobot.xml: -------------------------------------------------------------------------------- 1 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/ball_in_cup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | 16 | """Ball-in-Cup Domain.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import collections 23 | 24 | from dm_control import mujoco 25 | from dm_control.rl import control 26 | from local_dm_control_suite import base 27 | from local_dm_control_suite import common 28 | from dm_control.utils import containers 29 | 30 | _DEFAULT_TIME_LIMIT = 20 # (seconds) 31 | _CONTROL_TIMESTEP = .02 # (seconds) 32 | 33 | 34 | SUITE = containers.TaggedTasks() 35 | 36 | 37 | def get_model_and_assets(): 38 | """Returns a tuple containing the model XML string and a dict of assets.""" 39 | return common.read_model('ball_in_cup.xml'), common.ASSETS 40 | 41 | 42 | @SUITE.add('benchmarking', 'easy') 43 | def catch(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): 44 | """Returns the Ball-in-Cup task.""" 45 | physics = Physics.from_xml_string(*get_model_and_assets()) 46 | task = BallInCup(random=random) 47 | environment_kwargs = environment_kwargs or {} 48 | return control.Environment( 49 | physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP, 50 | **environment_kwargs) 51 | 52 | 53 | class Physics(mujoco.Physics): 54 | """Physics with additional features for the Ball-in-Cup domain.""" 55 | 56 | def ball_to_target(self): 57 | """Returns the vector from the ball to the target.""" 58 | target = self.named.data.site_xpos['target', ['x', 'z']] 59 | ball = self.named.data.xpos['ball', ['x', 'z']] 60 | return target - ball 61 | 62 | def in_target(self): 63 | """Returns 1 if the ball is in the target, 0 otherwise.""" 64 | ball_to_target = abs(self.ball_to_target()) 65 | target_size = self.named.model.site_size['target', [0, 2]] 66 | ball_size = self.named.model.geom_size['ball', 0] 67 | return float(all(ball_to_target < target_size - ball_size)) 68 | 69 | 70 | class BallInCup(base.Task): 71 | """The Ball-in-Cup task. Put the ball in the cup.""" 72 | 73 | def initialize_episode(self, physics): 74 | """Sets the state of the environment at the start of each episode. 75 | 76 | Args: 77 | physics: An instance of `Physics`. 78 | 79 | """ 80 | # Find a collision-free random initial position of the ball. 81 | penetrating = True 82 | while penetrating: 83 | # Assign a random ball position. 84 | physics.named.data.qpos['ball_x'] = self.random.uniform(-.2, .2) 85 | physics.named.data.qpos['ball_z'] = self.random.uniform(.2, .5) 86 | # Check for collisions. 
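# `after_reset` re-runs the forward dynamics so that `physics.data.ncon`
# reflects any contacts at the newly sampled ball position before the
# penetration check below.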
87 | physics.after_reset() 88 | penetrating = physics.data.ncon > 0 89 | super(BallInCup, self).initialize_episode(physics) 90 | 91 | def get_observation(self, physics): 92 | """Returns an observation of the state.""" 93 | obs = collections.OrderedDict() 94 | obs['position'] = physics.position() 95 | obs['velocity'] = physics.velocity() 96 | return obs 97 | 98 | def get_reward(self, physics): 99 | """Returns a sparse reward.""" 100 | return physics.in_target() 101 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/ball_in_cup.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Base class for tasks in the Control Suite.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from dm_control import mujoco 23 | from dm_control.rl import control 24 | 25 | import numpy as np 26 | 27 | 28 | class Task(control.Task): 29 | """Base class for tasks in the Control Suite. 30 | 31 | Actions are mapped directly to the states of MuJoCo actuators: each element of 32 | the action array is used to set the control input for a single actuator. The 33 | ordering of the actuators is the same as in the corresponding MJCF XML file. 34 | 35 | Attributes: 36 | random: A `numpy.random.RandomState` instance. This should be used to 37 | generate all random variables associated with the task, such as random 38 | starting states, observation noise* etc. 39 | 40 | *If sensor noise is enabled in the MuJoCo model then this will be generated 41 | using MuJoCo's internal RNG, which has its own independent state. 42 | """ 43 | 44 | def __init__(self, random=None): 45 | """Initializes a new continuous control task. 46 | 47 | Args: 48 | random: Optional, either a `numpy.random.RandomState` instance, an integer 49 | seed for creating a new `RandomState`, or None to select a seed 50 | automatically (default). 
51 | """ 52 | if not isinstance(random, np.random.RandomState): 53 | random = np.random.RandomState(random) 54 | self._random = random 55 | self._visualize_reward = False 56 | 57 | @property 58 | def random(self): 59 | """Task-specific `numpy.random.RandomState` instance.""" 60 | return self._random 61 | 62 | def action_spec(self, physics): 63 | """Returns a `BoundedArraySpec` matching the `physics` actuators.""" 64 | return mujoco.action_spec(physics) 65 | 66 | def initialize_episode(self, physics): 67 | """Resets geom colors to their defaults after starting a new episode. 68 | 69 | Subclasses of `base.Task` must delegate to this method after performing 70 | their own initialization. 71 | 72 | Args: 73 | physics: An instance of `mujoco.Physics`. 74 | """ 75 | self.after_step(physics) 76 | 77 | def before_step(self, action, physics): 78 | """Sets the control signal for the actuators to values in `action`.""" 79 | # Support legacy internal code. 80 | action = getattr(action, "continuous_actions", action) 81 | physics.set_control(action) 82 | 83 | def after_step(self, physics): 84 | """Modifies colors according to the reward.""" 85 | if self._visualize_reward: 86 | reward = np.clip(self.get_reward(physics), 0.0, 1.0) 87 | _set_reward_colors(physics, reward) 88 | 89 | @property 90 | def visualize_reward(self): 91 | return self._visualize_reward 92 | 93 | @visualize_reward.setter 94 | def visualize_reward(self, value): 95 | if not isinstance(value, bool): 96 | raise ValueError("Expected a boolean, got {}.".format(type(value))) 97 | self._visualize_reward = value 98 | 99 | 100 | _MATERIALS = ["self", "effector", "target"] 101 | _DEFAULT = [name + "_default" for name in _MATERIALS] 102 | _HIGHLIGHT = [name + "_highlight" for name in _MATERIALS] 103 | 104 | 105 | def _set_reward_colors(physics, reward): 106 | """Sets the highlight, effector and target colors according to the reward.""" 107 | assert 0.0 <= reward <= 1.0 108 | colors = physics.named.model.mat_rgba 109 | default = colors[_DEFAULT] 110 | highlight = colors[_HIGHLIGHT] 111 | blend_coef = reward ** 4 # Better color distinction near high rewards. 112 | colors[_MATERIALS] = blend_coef * highlight + (1.0 - blend_coef) * default 113 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/cartpole.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/cheetah.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | 16 | """Cheetah Domain.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import collections 23 | 24 | from dm_control import mujoco 25 | from dm_control.rl import control 26 | from local_dm_control_suite import base 27 | from local_dm_control_suite import common 28 | from dm_control.utils import containers 29 | from dm_control.utils import rewards 30 | 31 | 32 | # How long the simulation will run, in seconds. 33 | _DEFAULT_TIME_LIMIT = 10 34 | 35 | # Running speed above which reward is 1. 36 | _RUN_SPEED = 10 37 | 38 | SUITE = containers.TaggedTasks() 39 | 40 | 41 | def get_model_and_assets(): 42 | """Returns a tuple containing the model XML string and a dict of assets.""" 43 | return common.read_model('cheetah.xml'), common.ASSETS 44 | 45 | 46 | @SUITE.add('benchmarking') 47 | def run(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): 48 | """Returns the run task.""" 49 | physics = Physics.from_xml_string(*get_model_and_assets()) 50 | task = Cheetah(random=random) 51 | environment_kwargs = environment_kwargs or {} 52 | return control.Environment(physics, task, time_limit=time_limit, 53 | **environment_kwargs) 54 | 55 | 56 | class Physics(mujoco.Physics): 57 | """Physics simulation with additional features for the Cheetah domain.""" 58 | 59 | def speed(self): 60 | """Returns the horizontal speed of the Cheetah.""" 61 | return self.named.data.sensordata['torso_subtreelinvel'][0] 62 | 63 | 64 | class Cheetah(base.Task): 65 | """A `Task` to train a running Cheetah.""" 66 | 67 | def initialize_episode(self, physics): 68 | """Sets the state of the environment at the start of each episode.""" 69 | # The indexing below assumes that all joints have a single DOF. 70 | assert physics.model.nq == physics.model.njnt 71 | is_limited = physics.model.jnt_limited == 1 72 | lower, upper = physics.model.jnt_range[is_limited].T 73 | physics.data.qpos[is_limited] = self.random.uniform(lower, upper) 74 | 75 | # Stabilize the model before the actual simulation. 76 | for _ in range(200): 77 | physics.step() 78 | 79 | physics.data.time = 0 80 | self._timeout_progress = 0 81 | super(Cheetah, self).initialize_episode(physics) 82 | 83 | def get_observation(self, physics): 84 | """Returns an observation of the state, ignoring horizontal position.""" 85 | obs = collections.OrderedDict() 86 | # Ignores horizontal position to maintain translational invariance. 87 | obs['position'] = physics.data.qpos[1:].copy() 88 | obs['velocity'] = physics.velocity() 89 | return obs 90 | 91 | def get_reward(self, physics): 92 | """Returns a reward to the agent.""" 93 | return rewards.tolerance(physics.speed(), 94 | bounds=(_RUN_SPEED, float('inf')), 95 | margin=_RUN_SPEED, 96 | value_at_margin=0, 97 | sigmoid='linear') 98 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/cheetah.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 74 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/common/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Functions to manage the common assets for domains.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | from dm_control.utils import io as resources 24 | 25 | _SUITE_DIR = os.path.dirname(os.path.dirname(__file__)) 26 | _FILENAMES = [ 27 | "./common/materials.xml", 28 | "./common/materials_white_floor.xml", 29 | "./common/skybox.xml", 30 | "./common/visual.xml", 31 | ] 32 | 33 | ASSETS = {filename: resources.GetResource(os.path.join(_SUITE_DIR, filename)) 34 | for filename in _FILENAMES} 35 | 36 | 37 | def read_model(model_filename): 38 | """Reads a model XML file and returns its contents as a string.""" 39 | return resources.GetResource(os.path.join(_SUITE_DIR, model_filename)) 40 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/common/materials.xml: -------------------------------------------------------------------------------- 1 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/common/materials_white_floor.xml: -------------------------------------------------------------------------------- 1 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/common/skybox.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/common/visual.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/demos/mocap_demo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | 16 | """Demonstration of amc parsing for CMU mocap database. 17 | 18 | To run the demo, supply a path to a `.amc` file: 19 | 20 | python mocap_demo --filename='path/to/mocap.amc' 21 | 22 | CMU motion capture clips are available at mocap.cs.cmu.edu 23 | """ 24 | 25 | from __future__ import absolute_import 26 | from __future__ import division 27 | from __future__ import print_function 28 | 29 | import time 30 | # Internal dependencies. 31 | 32 | from absl import app 33 | from absl import flags 34 | 35 | from local_dm_control_suite import humanoid_CMU 36 | from dm_control.suite.utils import parse_amc 37 | 38 | import matplotlib.pyplot as plt 39 | import numpy as np 40 | 41 | FLAGS = flags.FLAGS 42 | flags.DEFINE_string('filename', None, 'amc file to be converted.') 43 | flags.DEFINE_integer('max_num_frames', 90, 44 | 'Maximum number of frames for plotting/playback') 45 | 46 | 47 | def main(unused_argv): 48 | env = humanoid_CMU.stand() 49 | 50 | # Parse and convert specified clip. 51 | converted = parse_amc.convert(FLAGS.filename, 52 | env.physics, env.control_timestep()) 53 | 54 | max_frame = min(FLAGS.max_num_frames, converted.qpos.shape[1] - 1) 55 | 56 | width = 480 57 | height = 480 58 | video = np.zeros((max_frame, height, 2 * width, 3), dtype=np.uint8) 59 | 60 | for i in range(max_frame): 61 | p_i = converted.qpos[:, i] 62 | with env.physics.reset_context(): 63 | env.physics.data.qpos[:] = p_i 64 | video[i] = np.hstack([env.physics.render(height, width, camera_id=0), 65 | env.physics.render(height, width, camera_id=1)]) 66 | 67 | tic = time.time() 68 | for i in range(max_frame): 69 | if i == 0: 70 | img = plt.imshow(video[i]) 71 | else: 72 | img.set_data(video[i]) 73 | toc = time.time() 74 | clock_dt = toc - tic 75 | tic = time.time() 76 | # Real-time playback not always possible as clock_dt > .03 77 | plt.pause(max(0.01, 0.03 - clock_dt)) # Need min display time > 0.0. 
78 | plt.draw() 79 | plt.waitforbuttonpress() 80 | 81 | 82 | if __name__ == '__main__': 83 | flags.mark_flag_as_required('filename') 84 | app.run(main) 85 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/demos/zeros.amc: -------------------------------------------------------------------------------- 1 | #DUMMY AMC for testing 2 | :FULLY-SPECIFIED 3 | :DEGREES 4 | 1 5 | root 0 0 0 0 0 0 6 | lowerback 0 0 0 7 | upperback 0 0 0 8 | thorax 0 0 0 9 | lowerneck 0 0 0 10 | upperneck 0 0 0 11 | head 0 0 0 12 | rclavicle 0 0 13 | rhumerus 0 0 0 14 | rradius 0 15 | rwrist 0 16 | rhand 0 0 17 | rfingers 0 18 | rthumb 0 0 19 | lclavicle 0 0 20 | lhumerus 0 0 0 21 | lradius 0 22 | lwrist 0 23 | lhand 0 0 24 | lfingers 0 25 | lthumb 0 0 26 | rfemur 0 0 0 27 | rtibia 0 28 | rfoot 0 0 29 | rtoes 0 30 | lfemur 0 0 0 31 | ltibia 0 32 | lfoot 0 0 33 | ltoes 0 34 | 2 35 | root 0 0 0 0 0 0 36 | lowerback 0 0 0 37 | upperback 0 0 0 38 | thorax 0 0 0 39 | lowerneck 0 0 0 40 | upperneck 0 0 0 41 | head 0 0 0 42 | rclavicle 0 0 43 | rhumerus 0 0 0 44 | rradius 0 45 | rwrist 0 46 | rhand 0 0 47 | rfingers 0 48 | rthumb 0 0 49 | lclavicle 0 0 50 | lhumerus 0 0 0 51 | lradius 0 52 | lwrist 0 53 | lhand 0 0 54 | lfingers 0 55 | lthumb 0 0 56 | rfemur 0 0 0 57 | rtibia 0 58 | rfoot 0 0 59 | rtoes 0 60 | lfemur 0 0 0 61 | ltibia 0 62 | lfoot 0 0 63 | ltoes 0 64 | 3 65 | root 0 0 0 0 0 0 66 | lowerback 0 0 0 67 | upperback 0 0 0 68 | thorax 0 0 0 69 | lowerneck 0 0 0 70 | upperneck 0 0 0 71 | head 0 0 0 72 | rclavicle 0 0 73 | rhumerus 0 0 0 74 | rradius 0 75 | rwrist 0 76 | rhand 0 0 77 | rfingers 0 78 | rthumb 0 0 79 | lclavicle 0 0 80 | lhumerus 0 0 0 81 | lradius 0 82 | lwrist 0 83 | lhand 0 0 84 | lfingers 0 85 | lthumb 0 0 86 | rfemur 0 0 0 87 | rtibia 0 88 | rfoot 0 0 89 | rtoes 0 90 | lfemur 0 0 0 91 | ltibia 0 92 | lfoot 0 0 93 | ltoes 0 94 | 4 95 | root 0 0 0 0 0 0 96 | lowerback 0 0 0 97 | upperback 0 0 0 98 | thorax 0 0 0 99 | lowerneck 0 0 0 100 | upperneck 0 0 0 101 | head 0 0 0 102 | rclavicle 0 0 103 | rhumerus 0 0 0 104 | rradius 0 105 | rwrist 0 106 | rhand 0 0 107 | rfingers 0 108 | rthumb 0 0 109 | lclavicle 0 0 110 | lhumerus 0 0 0 111 | lradius 0 112 | lwrist 0 113 | lhand 0 0 114 | lfingers 0 115 | lthumb 0 0 116 | rfemur 0 0 0 117 | rtibia 0 118 | rfoot 0 0 119 | rtoes 0 120 | lfemur 0 0 0 121 | ltibia 0 122 | lfoot 0 0 123 | ltoes 0 124 | 5 125 | root 0 0 0 0 0 0 126 | lowerback 0 0 0 127 | upperback 0 0 0 128 | thorax 0 0 0 129 | lowerneck 0 0 0 130 | upperneck 0 0 0 131 | head 0 0 0 132 | rclavicle 0 0 133 | rhumerus 0 0 0 134 | rradius 0 135 | rwrist 0 136 | rhand 0 0 137 | rfingers 0 138 | rthumb 0 0 139 | lclavicle 0 0 140 | lhumerus 0 0 0 141 | lradius 0 142 | lwrist 0 143 | lhand 0 0 144 | lfingers 0 145 | lthumb 0 0 146 | rfemur 0 0 0 147 | rtibia 0 148 | rfoot 0 0 149 | rtoes 0 150 | lfemur 0 0 0 151 | ltibia 0 152 | lfoot 0 0 153 | ltoes 0 154 | 6 155 | root 0 0 0 0 0 0 156 | lowerback 0 0 0 157 | upperback 0 0 0 158 | thorax 0 0 0 159 | lowerneck 0 0 0 160 | upperneck 0 0 0 161 | head 0 0 0 162 | rclavicle 0 0 163 | rhumerus 0 0 0 164 | rradius 0 165 | rwrist 0 166 | rhand 0 0 167 | rfingers 0 168 | rthumb 0 0 169 | lclavicle 0 0 170 | lhumerus 0 0 0 171 | lradius 0 172 | lwrist 0 173 | lhand 0 0 174 | lfingers 0 175 | lthumb 0 0 176 | rfemur 0 0 0 177 | rtibia 0 178 | rfoot 0 0 179 | rtoes 0 180 | lfemur 0 0 0 181 | ltibia 0 182 | lfoot 0 0 183 | ltoes 0 184 | 7 185 | root 0 0 0 0 0 0 186 | lowerback 0 0 0 187 | 
upperback 0 0 0 188 | thorax 0 0 0 189 | lowerneck 0 0 0 190 | upperneck 0 0 0 191 | head 0 0 0 192 | rclavicle 0 0 193 | rhumerus 0 0 0 194 | rradius 0 195 | rwrist 0 196 | rhand 0 0 197 | rfingers 0 198 | rthumb 0 0 199 | lclavicle 0 0 200 | lhumerus 0 0 0 201 | lradius 0 202 | lwrist 0 203 | lhand 0 0 204 | lfingers 0 205 | lthumb 0 0 206 | rfemur 0 0 0 207 | rtibia 0 208 | rfoot 0 0 209 | rtoes 0 210 | lfemur 0 0 0 211 | ltibia 0 212 | lfoot 0 0 213 | ltoes 0 214 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/explore.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Control suite environments explorer.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl import app 22 | from absl import flags 23 | from dm_control import suite 24 | from dm_control.suite.wrappers import action_noise 25 | from six.moves import input 26 | 27 | from dm_control import viewer 28 | 29 | 30 | _ALL_NAMES = ['.'.join(domain_task) for domain_task in suite.ALL_TASKS] 31 | 32 | flags.DEFINE_enum('environment_name', None, _ALL_NAMES, 33 | 'Optional \'domain_name.task_name\' pair specifying the ' 34 | 'environment to load. If unspecified a prompt will appear to ' 35 | 'select one.') 36 | flags.DEFINE_bool('timeout', True, 'Whether episodes should have a time limit.') 37 | flags.DEFINE_bool('visualize_reward', True, 38 | 'Whether to vary the colors of geoms according to the ' 39 | 'current reward value.') 40 | flags.DEFINE_float('action_noise', 0., 41 | 'Standard deviation of Gaussian noise to apply to actions, ' 42 | 'expressed as a fraction of the max-min range for each ' 43 | 'action dimension. Defaults to 0, i.e. no noise.') 44 | FLAGS = flags.FLAGS 45 | 46 | 47 | def prompt_environment_name(prompt, values): 48 | environment_name = None 49 | while not environment_name: 50 | environment_name = input(prompt) 51 | if not environment_name or values.index(environment_name) < 0: 52 | print('"%s" is not a valid environment name.' 
% environment_name) 53 | environment_name = None 54 | return environment_name 55 | 56 | 57 | def main(argv): 58 | del argv 59 | environment_name = FLAGS.environment_name 60 | if environment_name is None: 61 | print('\n '.join(['Available environments:'] + _ALL_NAMES)) 62 | environment_name = prompt_environment_name( 63 | 'Please select an environment name: ', _ALL_NAMES) 64 | 65 | index = _ALL_NAMES.index(environment_name) 66 | domain_name, task_name = suite.ALL_TASKS[index] 67 | 68 | task_kwargs = {} 69 | if not FLAGS.timeout: 70 | task_kwargs['time_limit'] = float('inf') 71 | 72 | def loader(): 73 | env = suite.load( 74 | domain_name=domain_name, task_name=task_name, task_kwargs=task_kwargs) 75 | env.task.visualize_reward = FLAGS.visualize_reward 76 | if FLAGS.action_noise > 0: 77 | env = action_noise.Wrapper(env, scale=FLAGS.action_noise) 78 | return env 79 | 80 | viewer.launch(loader) 81 | 82 | 83 | if __name__ == '__main__': 84 | app.run(main) 85 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/finger.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/fish.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/hopper.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 67 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/lqr.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/pendulum.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Pendulum domain.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import collections 23 | 24 | from dm_control import mujoco 25 | from dm_control.rl import control 26 | from local_dm_control_suite import base 27 | from local_dm_control_suite import common 28 | from dm_control.utils import containers 29 | from dm_control.utils import rewards 30 | import numpy as np 31 | 32 | 33 | _DEFAULT_TIME_LIMIT = 20 34 | _ANGLE_BOUND = 8 35 | _COSINE_BOUND = np.cos(np.deg2rad(_ANGLE_BOUND)) 36 | SUITE = containers.TaggedTasks() 37 | 38 | 39 | def get_model_and_assets(): 40 | """Returns a tuple containing the model XML string and a dict of assets.""" 41 | return common.read_model('pendulum.xml'), common.ASSETS 42 | 43 | 44 | @SUITE.add('benchmarking') 45 | def swingup(time_limit=_DEFAULT_TIME_LIMIT, random=None, 46 | environment_kwargs=None): 47 | """Returns pendulum swingup task .""" 48 | physics = Physics.from_xml_string(*get_model_and_assets()) 49 | task = SwingUp(random=random) 50 | environment_kwargs = environment_kwargs or {} 51 | return control.Environment( 52 | physics, task, time_limit=time_limit, **environment_kwargs) 53 | 54 | 55 | class Physics(mujoco.Physics): 56 | """Physics simulation with additional features for the Pendulum domain.""" 57 | 58 | def pole_vertical(self): 59 | """Returns vertical (z) component of pole frame.""" 60 | return self.named.data.xmat['pole', 'zz'] 61 | 62 | def angular_velocity(self): 63 | """Returns the angular velocity of the pole.""" 64 | return self.named.data.qvel['hinge'].copy() 65 | 66 | def pole_orientation(self): 67 | """Returns both horizontal and vertical components of pole frame.""" 68 | return self.named.data.xmat['pole', ['zz', 'xz']] 69 | 70 | 71 | class SwingUp(base.Task): 72 | """A Pendulum `Task` to swing up and balance the pole.""" 73 | 74 | def __init__(self, random=None): 75 | """Initialize an instance of `Pendulum`. 76 | 77 | Args: 78 | random: Optional, either a `numpy.random.RandomState` instance, an 79 | integer seed for creating a new `RandomState`, or None to select a seed 80 | automatically (default). 81 | """ 82 | super(SwingUp, self).__init__(random=random) 83 | 84 | def initialize_episode(self, physics): 85 | """Sets the state of the environment at the start of each episode. 86 | 87 | Pole is set to a random angle between [-pi, pi). 88 | 89 | Args: 90 | physics: An instance of `Physics`. 91 | 92 | """ 93 | physics.named.data.qpos['hinge'] = self.random.uniform(-np.pi, np.pi) 94 | super(SwingUp, self).initialize_episode(physics) 95 | 96 | def get_observation(self, physics): 97 | """Returns an observation. 98 | 99 | Observations are states concatenating pole orientation and angular velocity 100 | and pixels from fixed camera. 101 | 102 | Args: 103 | physics: An instance of `physics`, Pendulum physics. 104 | 105 | Returns: 106 | A `dict` of observation. 
107 | """ 108 | obs = collections.OrderedDict() 109 | obs['orientation'] = physics.pole_orientation() 110 | obs['velocity'] = physics.angular_velocity() 111 | return obs 112 | 113 | def get_reward(self, physics): 114 | return rewards.tolerance(physics.pole_vertical(), (_COSINE_BOUND, 1)) 115 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/pendulum.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/point_mass.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/reacher.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | 16 | """Reacher domain.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import collections 23 | 24 | from dm_control import mujoco 25 | from dm_control.rl import control 26 | from local_dm_control_suite import base 27 | from local_dm_control_suite import common 28 | from dm_control.suite.utils import randomizers 29 | from dm_control.utils import containers 30 | from dm_control.utils import rewards 31 | import numpy as np 32 | 33 | SUITE = containers.TaggedTasks() 34 | _DEFAULT_TIME_LIMIT = 20 35 | _BIG_TARGET = .05 36 | _SMALL_TARGET = .015 37 | 38 | 39 | def get_model_and_assets(): 40 | """Returns a tuple containing the model XML string and a dict of assets.""" 41 | return common.read_model('reacher.xml'), common.ASSETS 42 | 43 | 44 | @SUITE.add('benchmarking', 'easy') 45 | def easy(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): 46 | """Returns reacher with sparse reward with 5e-2 tol and randomized target.""" 47 | physics = Physics.from_xml_string(*get_model_and_assets()) 48 | task = Reacher(target_size=_BIG_TARGET, random=random) 49 | environment_kwargs = environment_kwargs or {} 50 | return control.Environment( 51 | physics, task, time_limit=time_limit, **environment_kwargs) 52 | 53 | 54 | @SUITE.add('benchmarking') 55 | def hard(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): 56 | """Returns reacher with sparse reward with 1e-2 tol and randomized target.""" 57 | physics = Physics.from_xml_string(*get_model_and_assets()) 58 | task = Reacher(target_size=_SMALL_TARGET, random=random) 59 | environment_kwargs = environment_kwargs or {} 60 | return control.Environment( 61 | physics, task, time_limit=time_limit, **environment_kwargs) 62 | 63 | 64 | class Physics(mujoco.Physics): 65 | """Physics simulation with additional features for the Reacher domain.""" 66 | 67 | def finger_to_target(self): 68 | """Returns the vector from target to finger in global coordinates.""" 69 | return (self.named.data.geom_xpos['target', :2] - 70 | self.named.data.geom_xpos['finger', :2]) 71 | 72 | def finger_to_target_dist(self): 73 | """Returns the signed distance between the finger and target surface.""" 74 | return np.linalg.norm(self.finger_to_target()) 75 | 76 | 77 | class Reacher(base.Task): 78 | """A reacher `Task` to reach the target.""" 79 | 80 | def __init__(self, target_size, random=None): 81 | """Initialize an instance of `Reacher`. 82 | 83 | Args: 84 | target_size: A `float`, tolerance to determine whether finger reached the 85 | target. 86 | random: Optional, either a `numpy.random.RandomState` instance, an 87 | integer seed for creating a new `RandomState`, or None to select a seed 88 | automatically (default). 
89 | """ 90 | self._target_size = target_size 91 | super(Reacher, self).__init__(random=random) 92 | 93 | def initialize_episode(self, physics): 94 | """Sets the state of the environment at the start of each episode.""" 95 | physics.named.model.geom_size['target', 0] = self._target_size 96 | randomizers.randomize_limited_and_rotational_joints(physics, self.random) 97 | 98 | # Randomize target position 99 | angle = self.random.uniform(0, 2 * np.pi) 100 | radius = self.random.uniform(.05, .20) 101 | physics.named.model.geom_pos['target', 'x'] = radius * np.sin(angle) 102 | physics.named.model.geom_pos['target', 'y'] = radius * np.cos(angle) 103 | 104 | super(Reacher, self).initialize_episode(physics) 105 | 106 | def get_observation(self, physics): 107 | """Returns an observation of the state and the target position.""" 108 | obs = collections.OrderedDict() 109 | obs['position'] = physics.position() 110 | obs['to_target'] = physics.finger_to_target() 111 | obs['velocity'] = physics.velocity() 112 | return obs 113 | 114 | def get_reward(self, physics): 115 | radii = physics.named.model.geom_size[['target', 'finger'], 0].sum() 116 | return rewards.tolerance(physics.finger_to_target_dist(), (0, radii)) 117 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/reacher.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/swimmer.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/tests/loader_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Tests for the dm_control.suite loader.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | # Internal dependencies. 
23 | 24 | from absl.testing import absltest 25 | 26 | from dm_control import suite 27 | from dm_control.rl import control 28 | 29 | 30 | class LoaderTest(absltest.TestCase): 31 | 32 | def test_load_without_kwargs(self): 33 | env = suite.load('cartpole', 'swingup') 34 | self.assertIsInstance(env, control.Environment) 35 | 36 | def test_load_with_kwargs(self): 37 | env = suite.load('cartpole', 'swingup', 38 | task_kwargs={'time_limit': 40, 'random': 99}) 39 | self.assertIsInstance(env, control.Environment) 40 | 41 | 42 | class LoaderConstantsTest(absltest.TestCase): 43 | 44 | def testSuiteConstants(self): 45 | self.assertNotEmpty(suite.BENCHMARKING) 46 | self.assertNotEmpty(suite.EASY) 47 | self.assertNotEmpty(suite.HARD) 48 | self.assertNotEmpty(suite.EXTRA) 49 | 50 | 51 | if __name__ == '__main__': 52 | absltest.main() 53 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/tests/lqr_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Tests specific to the LQR domain.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import math 23 | import unittest 24 | 25 | # Internal dependencies. 26 | from absl import logging 27 | 28 | from absl.testing import absltest 29 | from absl.testing import parameterized 30 | 31 | from local_dm_control_suite import lqr 32 | from local_dm_control_suite import lqr_solver 33 | 34 | import numpy as np 35 | from six.moves import range 36 | 37 | 38 | class LqrTest(parameterized.TestCase): 39 | 40 | @parameterized.named_parameters( 41 | ('lqr_2_1', lqr.lqr_2_1), 42 | ('lqr_6_2', lqr.lqr_6_2)) 43 | def test_lqr_optimal_policy(self, make_env): 44 | env = make_env() 45 | p, k, beta = lqr_solver.solve(env) 46 | self.assertPolicyisOptimal(env, p, k, beta) 47 | 48 | @parameterized.named_parameters( 49 | ('lqr_2_1', lqr.lqr_2_1), 50 | ('lqr_6_2', lqr.lqr_6_2)) 51 | @unittest.skipUnless( 52 | condition=lqr_solver.sp, 53 | reason='scipy is not available, so non-scipy DARE solver is the default.') 54 | def test_lqr_optimal_policy_no_scipy(self, make_env): 55 | env = make_env() 56 | old_sp = lqr_solver.sp 57 | try: 58 | lqr_solver.sp = None # Force the solver to use the non-scipy code path. 
59 | p, k, beta = lqr_solver.solve(env) 60 | finally: 61 | lqr_solver.sp = old_sp 62 | self.assertPolicyisOptimal(env, p, k, beta) 63 | 64 | def assertPolicyisOptimal(self, env, p, k, beta): 65 | tolerance = 1e-3 66 | n_steps = int(math.ceil(math.log10(tolerance) / math.log10(beta))) 67 | logging.info('%d timesteps for %g convergence.', n_steps, tolerance) 68 | total_loss = 0.0 69 | 70 | timestep = env.reset() 71 | initial_state = np.hstack((timestep.observation['position'], 72 | timestep.observation['velocity'])) 73 | logging.info('Measuring total cost over %d steps.', n_steps) 74 | for _ in range(n_steps): 75 | x = np.hstack((timestep.observation['position'], 76 | timestep.observation['velocity'])) 77 | # u = k*x is the optimal policy 78 | u = k.dot(x) 79 | total_loss += 1 - (timestep.reward or 0.0) 80 | timestep = env.step(u) 81 | 82 | logging.info('Analytical expected total cost is .5*x^T*p*x.') 83 | expected_loss = .5 * initial_state.T.dot(p).dot(initial_state) 84 | logging.info('Comparing measured and predicted costs.') 85 | np.testing.assert_allclose(expected_loss, total_loss, rtol=tolerance) 86 | 87 | if __name__ == '__main__': 88 | absltest.main() 89 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Utility functions used in the control suite.""" 17 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/utils/parse_amc_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Tests for parse_amc utility.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | 24 | # Internal dependencies. 
25 | 26 | from absl.testing import absltest 27 | from local_dm_control_suite import humanoid_CMU 28 | from dm_control.suite.utils import parse_amc 29 | 30 | from dm_control.utils import io as resources 31 | 32 | _TEST_AMC_PATH = resources.GetResourceFilename( 33 | os.path.join(os.path.dirname(__file__), '../demos/zeros.amc')) 34 | 35 | 36 | class ParseAMCTest(absltest.TestCase): 37 | 38 | def test_sizes_of_parsed_data(self): 39 | 40 | # Instantiate the humanoid environment. 41 | env = humanoid_CMU.stand() 42 | 43 | # Parse and convert specified clip. 44 | converted = parse_amc.convert( 45 | _TEST_AMC_PATH, env.physics, env.control_timestep()) 46 | 47 | self.assertEqual(converted.qpos.shape[0], 63) 48 | self.assertEqual(converted.qvel.shape[0], 62) 49 | self.assertEqual(converted.time.shape[0], converted.qpos.shape[1]) 50 | self.assertEqual(converted.qpos.shape[1], 51 | converted.qvel.shape[1] + 1) 52 | 53 | # Parse and convert specified clip -- WITH SMALLER TIMESTEP 54 | converted2 = parse_amc.convert( 55 | _TEST_AMC_PATH, env.physics, 0.5 * env.control_timestep()) 56 | 57 | self.assertEqual(converted2.qpos.shape[0], 63) 58 | self.assertEqual(converted2.qvel.shape[0], 62) 59 | self.assertEqual(converted2.time.shape[0], converted2.qpos.shape[1]) 60 | self.assertEqual(converted.qpos.shape[1], 61 | converted.qvel.shape[1] + 1) 62 | 63 | # Compare sizes of parsed objects for different timesteps 64 | self.assertEqual(converted.qpos.shape[1] * 2, converted2.qpos.shape[1]) 65 | 66 | 67 | if __name__ == '__main__': 68 | absltest.main() 69 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/utils/randomizers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Randomization functions.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from dm_control.mujoco.wrapper import mjbindings 23 | import numpy as np 24 | from six.moves import range 25 | 26 | 27 | def random_limited_quaternion(random, limit): 28 | """Generates a random quaternion limited to the specified rotations.""" 29 | axis = random.randn(3) 30 | axis /= np.linalg.norm(axis) 31 | angle = random.rand() * limit 32 | 33 | quaternion = np.zeros(4) 34 | mjbindings.mjlib.mju_axisAngle2Quat(quaternion, axis, angle) 35 | 36 | return quaternion 37 | 38 | 39 | def randomize_limited_and_rotational_joints(physics, random=None): 40 | """Randomizes the positions of joints defined in the physics body. 41 | 42 | The following randomization rules apply: 43 | - Bounded joints (hinges or sliders) are sampled uniformly in the bounds. 
44 | - Unbounded hinges are samples uniformly in [-pi, pi] 45 | - Quaternions for unlimited free joints and ball joints are sampled 46 | uniformly on the unit 3-sphere. 47 | - Quaternions for limited ball joints are sampled uniformly on a sector 48 | of the unit 3-sphere. 49 | - The linear degrees of freedom of free joints are not randomized. 50 | 51 | Args: 52 | physics: Instance of 'Physics' class that holds a loaded model. 53 | random: Optional instance of 'np.random.RandomState'. Defaults to the global 54 | NumPy random state. 55 | """ 56 | random = random or np.random 57 | 58 | hinge = mjbindings.enums.mjtJoint.mjJNT_HINGE 59 | slide = mjbindings.enums.mjtJoint.mjJNT_SLIDE 60 | ball = mjbindings.enums.mjtJoint.mjJNT_BALL 61 | free = mjbindings.enums.mjtJoint.mjJNT_FREE 62 | 63 | qpos = physics.named.data.qpos 64 | 65 | for joint_id in range(physics.model.njnt): 66 | joint_name = physics.model.id2name(joint_id, 'joint') 67 | joint_type = physics.model.jnt_type[joint_id] 68 | is_limited = physics.model.jnt_limited[joint_id] 69 | range_min, range_max = physics.model.jnt_range[joint_id] 70 | 71 | if is_limited: 72 | if joint_type == hinge or joint_type == slide: 73 | qpos[joint_name] = random.uniform(range_min, range_max) 74 | 75 | elif joint_type == ball: 76 | qpos[joint_name] = random_limited_quaternion(random, range_max) 77 | 78 | else: 79 | if joint_type == hinge: 80 | qpos[joint_name] = random.uniform(-np.pi, np.pi) 81 | 82 | elif joint_type == ball: 83 | quat = random.randn(4) 84 | quat /= np.linalg.norm(quat) 85 | qpos[joint_name] = quat 86 | 87 | elif joint_type == free: 88 | quat = random.rand(4) 89 | quat /= np.linalg.norm(quat) 90 | qpos[joint_name][3:] = quat 91 | 92 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/walker.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 72 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Environment wrappers used to extend or modify environment behaviour.""" 17 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/wrappers/action_noise.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Wrapper control suite environments that adds Gaussian noise to actions.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import dm_env 23 | import numpy as np 24 | 25 | 26 | _BOUNDS_MUST_BE_FINITE = ( 27 | 'All bounds in `env.action_spec()` must be finite, got: {action_spec}') 28 | 29 | 30 | class Wrapper(dm_env.Environment): 31 | """Wraps a control environment and adds Gaussian noise to actions.""" 32 | 33 | def __init__(self, env, scale=0.01): 34 | """Initializes a new action noise Wrapper. 35 | 36 | Args: 37 | env: The control suite environment to wrap. 38 | scale: The standard deviation of the noise, expressed as a fraction 39 | of the max-min range for each action dimension. 40 | 41 | Raises: 42 | ValueError: If any of the action dimensions of the wrapped environment are 43 | unbounded. 44 | """ 45 | action_spec = env.action_spec() 46 | if not (np.all(np.isfinite(action_spec.minimum)) and 47 | np.all(np.isfinite(action_spec.maximum))): 48 | raise ValueError(_BOUNDS_MUST_BE_FINITE.format(action_spec=action_spec)) 49 | self._minimum = action_spec.minimum 50 | self._maximum = action_spec.maximum 51 | self._noise_std = scale * (action_spec.maximum - action_spec.minimum) 52 | self._env = env 53 | 54 | def step(self, action): 55 | noisy_action = action + self._env.task.random.normal(scale=self._noise_std) 56 | # Clip the noisy actions in place so that they fall within the bounds 57 | # specified by the `action_spec`. Note that MuJoCo implicitly clips out-of- 58 | # bounds control inputs, but we also clip here in case the actions do not 59 | # correspond directly to MuJoCo actuators, or if there are other wrapper 60 | # layers that expect the actions to be within bounds. 61 | np.clip(noisy_action, self._minimum, self._maximum, out=noisy_action) 62 | return self._env.step(noisy_action) 63 | 64 | def reset(self): 65 | return self._env.reset() 66 | 67 | def observation_spec(self): 68 | return self._env.observation_spec() 69 | 70 | def action_spec(self): 71 | return self._env.action_spec() 72 | 73 | def __getattr__(self, name): 74 | return getattr(self._env, name) 75 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/wrappers/pixels.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Wrapper that adds pixel observations to a control environment.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import collections 23 | 24 | import dm_env 25 | from dm_env import specs 26 | 27 | STATE_KEY = 'state' 28 | 29 | 30 | class Wrapper(dm_env.Environment): 31 | """Wraps a control environment and adds a rendered pixel observation.""" 32 | 33 | def __init__(self, env, pixels_only=True, render_kwargs=None, 34 | observation_key='pixels'): 35 | """Initializes a new pixel Wrapper. 36 | 37 | Args: 38 | env: The environment to wrap. 39 | pixels_only: If True (default), the original set of 'state' observations 40 | returned by the wrapped environment will be discarded, and the 41 | `OrderedDict` of observations will only contain pixels. If False, the 42 | `OrderedDict` will contain the original observations as well as the 43 | pixel observations. 44 | render_kwargs: Optional `dict` containing keyword arguments passed to the 45 | `mujoco.Physics.render` method. 46 | observation_key: Optional custom string specifying the pixel observation's 47 | key in the `OrderedDict` of observations. Defaults to 'pixels'. 48 | 49 | Raises: 50 | ValueError: If `env`'s observation spec is not compatible with the 51 | wrapper. Supported formats are a single array, or a dict of arrays. 52 | ValueError: If `env`'s observation already contains the specified 53 | `observation_key`. 54 | """ 55 | if render_kwargs is None: 56 | render_kwargs = {} 57 | 58 | wrapped_observation_spec = env.observation_spec() 59 | 60 | if isinstance(wrapped_observation_spec, specs.Array): 61 | self._observation_is_dict = False 62 | invalid_keys = set([STATE_KEY]) 63 | elif isinstance(wrapped_observation_spec, collections.MutableMapping): 64 | self._observation_is_dict = True 65 | invalid_keys = set(wrapped_observation_spec.keys()) 66 | else: 67 | raise ValueError('Unsupported observation spec structure.') 68 | 69 | if not pixels_only and observation_key in invalid_keys: 70 | raise ValueError('Duplicate or reserved observation key {!r}.' 71 | .format(observation_key)) 72 | 73 | if pixels_only: 74 | self._observation_spec = collections.OrderedDict() 75 | elif self._observation_is_dict: 76 | self._observation_spec = wrapped_observation_spec.copy() 77 | else: 78 | self._observation_spec = collections.OrderedDict() 79 | self._observation_spec[STATE_KEY] = wrapped_observation_spec 80 | 81 | # Extend observation spec. 
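# The environment is rendered once up front so that the shape and dtype of the
# pixel observation spec added below are taken from an actual render call made
# with the same render_kwargs.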
82 | pixels = env.physics.render(**render_kwargs) 83 | pixels_spec = specs.Array( 84 | shape=pixels.shape, dtype=pixels.dtype, name=observation_key) 85 | self._observation_spec[observation_key] = pixels_spec 86 | 87 | self._env = env 88 | self._pixels_only = pixels_only 89 | self._render_kwargs = render_kwargs 90 | self._observation_key = observation_key 91 | 92 | def reset(self): 93 | time_step = self._env.reset() 94 | return self._add_pixel_observation(time_step) 95 | 96 | def step(self, action): 97 | time_step = self._env.step(action) 98 | return self._add_pixel_observation(time_step) 99 | 100 | def observation_spec(self): 101 | return self._observation_spec 102 | 103 | def action_spec(self): 104 | return self._env.action_spec() 105 | 106 | def _add_pixel_observation(self, time_step): 107 | if self._pixels_only: 108 | observation = collections.OrderedDict() 109 | elif self._observation_is_dict: 110 | observation = type(time_step.observation)(time_step.observation) 111 | else: 112 | observation = collections.OrderedDict() 113 | observation[STATE_KEY] = time_step.observation 114 | 115 | pixels = self._env.physics.render(**self._render_kwargs) 116 | observation[self._observation_key] = pixels 117 | return time_step._replace(observation=observation) 118 | 119 | def __getattr__(self, name): 120 | return getattr(self._env, name) 121 | -------------------------------------------------------------------------------- /local_dm_control_suite_off_center/wrappers/pixels_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Tests for the pixel wrapper.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import collections 23 | 24 | # Internal dependencies. 
25 | from absl.testing import absltest 26 | from absl.testing import parameterized 27 | from local_dm_control_suite import cartpole 28 | from dm_control.suite.wrappers import pixels 29 | import dm_env 30 | from dm_env import specs 31 | 32 | import numpy as np 33 | 34 | 35 | class FakePhysics(object): 36 | 37 | def render(self, *args, **kwargs): 38 | del args 39 | del kwargs 40 | return np.zeros((4, 5, 3), dtype=np.uint8) 41 | 42 | 43 | class FakeArrayObservationEnvironment(dm_env.Environment): 44 | 45 | def __init__(self): 46 | self.physics = FakePhysics() 47 | 48 | def reset(self): 49 | return dm_env.restart(np.zeros((2,))) 50 | 51 | def step(self, action): 52 | del action 53 | return dm_env.transition(0.0, np.zeros((2,))) 54 | 55 | def action_spec(self): 56 | pass 57 | 58 | def observation_spec(self): 59 | return specs.Array(shape=(2,), dtype=np.float) 60 | 61 | 62 | class PixelsTest(parameterized.TestCase): 63 | 64 | @parameterized.parameters(True, False) 65 | def test_dict_observation(self, pixels_only): 66 | pixel_key = 'rgb' 67 | 68 | env = cartpole.swingup() 69 | 70 | # Make sure we are testing the right environment for the test. 71 | observation_spec = env.observation_spec() 72 | self.assertIsInstance(observation_spec, collections.OrderedDict) 73 | 74 | width = 320 75 | height = 240 76 | 77 | # The wrapper should only add one observation. 78 | wrapped = pixels.Wrapper(env, 79 | observation_key=pixel_key, 80 | pixels_only=pixels_only, 81 | render_kwargs={'width': width, 'height': height}) 82 | 83 | wrapped_observation_spec = wrapped.observation_spec() 84 | self.assertIsInstance(wrapped_observation_spec, collections.OrderedDict) 85 | 86 | if pixels_only: 87 | self.assertLen(wrapped_observation_spec, 1) 88 | self.assertEqual([pixel_key], list(wrapped_observation_spec.keys())) 89 | else: 90 | expected_length = len(observation_spec) + 1 91 | self.assertLen(wrapped_observation_spec, expected_length) 92 | expected_keys = list(observation_spec.keys()) + [pixel_key] 93 | self.assertEqual(expected_keys, list(wrapped_observation_spec.keys())) 94 | 95 | # Check that the added spec item is consistent with the added observation. 
96 | time_step = wrapped.reset() 97 | rgb_observation = time_step.observation[pixel_key] 98 | wrapped_observation_spec[pixel_key].validate(rgb_observation) 99 | 100 | self.assertEqual(rgb_observation.shape, (height, width, 3)) 101 | self.assertEqual(rgb_observation.dtype, np.uint8) 102 | 103 | @parameterized.parameters(True, False) 104 | def test_single_array_observation(self, pixels_only): 105 | pixel_key = 'depth' 106 | 107 | env = FakeArrayObservationEnvironment() 108 | observation_spec = env.observation_spec() 109 | self.assertIsInstance(observation_spec, specs.Array) 110 | 111 | wrapped = pixels.Wrapper(env, observation_key=pixel_key, 112 | pixels_only=pixels_only) 113 | wrapped_observation_spec = wrapped.observation_spec() 114 | self.assertIsInstance(wrapped_observation_spec, collections.OrderedDict) 115 | 116 | if pixels_only: 117 | self.assertLen(wrapped_observation_spec, 1) 118 | self.assertEqual([pixel_key], list(wrapped_observation_spec.keys())) 119 | else: 120 | self.assertLen(wrapped_observation_spec, 2) 121 | self.assertEqual([pixels.STATE_KEY, pixel_key], 122 | list(wrapped_observation_spec.keys())) 123 | 124 | time_step = wrapped.reset() 125 | 126 | depth_observation = time_step.observation[pixel_key] 127 | wrapped_observation_spec[pixel_key].validate(depth_observation) 128 | 129 | self.assertEqual(depth_observation.shape, (4, 5, 3)) 130 | self.assertEqual(depth_observation.dtype, np.uint8) 131 | 132 | if __name__ == '__main__': 133 | absltest.main() 134 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.12.0 2 | cachetools==4.2.2 3 | certifi==2021.5.30 4 | cffi==1.15.0 5 | chardet==4.0.0 6 | cloudpickle==1.6.0 7 | cycler==0.10.0 8 | Cython==0.29.24 9 | dataclasses==0.8 10 | decorator==4.4.2 11 | fasteners==0.16.3 12 | flatbuffers==1.12 13 | future==0.18.2 14 | gast==0.2.2 15 | gin-config==0.4.0 16 | glfw==2.1.0 17 | google-auth==1.28.1 18 | google-auth-oauthlib==0.4.4 19 | google-pasta==0.2.0 20 | grpcio==1.32.0 21 | gym==0.18.0 22 | h5py==2.10.0 23 | idna==2.10 24 | imageio==2.9.0 25 | imageio-ffmpeg==0.4.2 26 | importlib-metadata==3.10.0 27 | kiwisolver==1.3.1 28 | labmaze==1.0.5 29 | lockfile==0.12.2 30 | Markdown==3.3.4 31 | matplotlib==3.3.4 32 | mujoco-py==2.0.2.13 33 | networkx==2.5.1 34 | numpy==1.19.1 35 | oauthlib==3.1.0 36 | opencv-python==4.5.1.48 37 | opt-einsum==3.3.0 38 | pandas==1.1.5 39 | Pillow>=8.3.2 40 | Pillow-SIMD==7.0.0.post3 41 | protobuf==3.17.1 42 | pyasn1==0.4.8 43 | pyasn1-modules==0.2.8 44 | pybullet==3.2.0 45 | pycparser==2.20 46 | pyglet==1.5.0 47 | PyOpenGL==3.1.5 48 | pyparsing==2.4.7 49 | python-dateutil==2.8.1 50 | pytz==2021.1 51 | PyWavelets==1.1.1 52 | PyYAML==6.0 53 | requests==2.25.1 54 | requests-oauthlib==1.3.0 55 | rsa==4.7.2 56 | scikit-image==0.17.2 57 | scipy==1.5.4 58 | seaborn==0.11.1 59 | six==1.16.0 60 | sortedcontainers==2.4.0 61 | tensorboard==2.0.2 62 | tensorboard-data-server==0.6.0 63 | tensorboard-plugin-wit==1.8.0 64 | termcolor==1.1.0 65 | tifffile==2020.9.3 66 | torch==1.8.1 67 | tqdm==4.60.0 68 | typing-extensions==3.7.4.3 69 | urllib3>=1.26.5 70 | Werkzeug==1.0.1 71 | wrapt==1.12.1 72 | zipp==3.4.1 73 | -------------------------------------------------------------------------------- /scripts/run.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python train.py \ 2 | --domain_name $3 \ 3 | --task_name $4 
--case $1 \ 4 | --encoder_type pixel --work_dir ./tmp/test \ 5 | --action_repeat 8 --num_eval_episodes 10 \ 6 | --pre_transform_image_size 100 --image_size 84 --replay_buffer_capacity 100 \ 7 | --frame_stack 3 --data_augs no_aug \ 8 | --seed 23 --critic_lr 1e-3 --actor_lr 1e-3 --eval_freq 10000 \ 9 | --batch_size 16 --num_train_steps 200000 --metric_loss \ 10 | --init_steps 1000 \ 11 | --resource_files './distractors/driving/*.mp4' --img_source 'video' --total_frames 50 \ 12 | --horizon $2 --save_model 13 | -------------------------------------------------------------------------------- /video.py: -------------------------------------------------------------------------------- 1 | import imageio 2 | import os 3 | import numpy as np 4 | 5 | 6 | class VideoRecorder(object): 7 | def __init__(self, dir_name, height=256, width=256, camera_id=0, fps=30): 8 | self.dir_name = dir_name 9 | self.height = height 10 | self.width = width 11 | self.camera_id = camera_id 12 | self.fps = fps 13 | self.frames = [] 14 | 15 | def init(self, enabled=True): 16 | self.frames = [] 17 | self.enabled = self.dir_name is not None and enabled 18 | 19 | def record(self, env): 20 | if self.enabled: 21 | try: 22 | frame = env.render( 23 | mode='rgb_array', 24 | height=self.height, 25 | width=self.width, 26 | camera_id=self.camera_id 27 | ) 28 | except Exception: # fall back for envs whose render() does not accept size/camera kwargs 29 | frame = env.render( 30 | mode='rgb_array', 31 | ) 32 | 33 | self.frames.append(frame) 34 | 35 | def save(self, file_name): 36 | if self.enabled: 37 | path = os.path.join(self.dir_name, file_name) 38 | imageio.mimsave(path, self.frames, fps=self.fps) 39 | --------------------------------------------------------------------------------
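A minimal usage sketch for the VideoRecorder class above. This is hedged: env is assumed to be a dmc2gym-style environment exposing the old gym reset/step/render interface (gym 0.18 is pinned in requirements.txt), the random action is a stand-in for a trained policy, and the output directory and file name are illustrative rather than taken from the repository:

    recorder = VideoRecorder(dir_name='./tmp/test/video')  # illustrative output directory
    recorder.init(enabled=True)
    obs = env.reset()
    done = False
    while not done:
        action = env.action_space.sample()          # stand-in for a trained policy's action
        obs, reward, done, info = env.step(action)
        recorder.record(env)                        # renders and buffers the current frame
    recorder.save('eval.mp4')                       # writes ./tmp/test/video/eval.mp4 via imageio

For reference, scripts/run.sh above maps its four positional arguments to --case, --horizon, --domain_name and --task_name, so it is invoked as bash scripts/run.sh <case> <horizon> <domain_name> <task_name> (for example, bash scripts/run.sh 0 3 cheetah run, with illustrative values).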