├── .gitignore
├── LICENSE
├── README.md
├── TransformLayer.py
├── agents
    ├── agent_sac.py
    ├── agent_sac_SPR.py
    ├── agent_sac_base.py
    ├── agent_sac_contrastive.py
    ├── agent_sac_noreward.py
    ├── agent_sac_norewtotransition.py
    ├── agent_sac_notransition.py
    ├── agent_sac_reconstruction.py
    ├── agent_sac_value.py
    └── auxiliary_funcs.py
├── curl_sac.py
├── data_augs.py
├── distractors
    └── driving
    │   └── 1.mp4
├── dmc2gym
    ├── __init__.py
    ├── natural_imgsource.py
    └── wrappers.py
├── encoder.py
├── local_dm_control_suite
    ├── README.md
    ├── __init__.py
    ├── acrobot.py
    ├── acrobot.xml
    ├── ball_in_cup.py
    ├── ball_in_cup.xml
    ├── base.py
    ├── cartpole.py
    ├── cartpole.xml
    ├── cheetah.py
    ├── cheetah.xml
    ├── common
    │   ├── __init__.py
    │   ├── materials.xml
    │   ├── materials_white_floor.xml
    │   ├── skybox.xml
    │   └── visual.xml
    ├── demos
    │   ├── mocap_demo.py
    │   └── zeros.amc
    ├── explore.py
    ├── finger.py
    ├── finger.xml
    ├── fish.py
    ├── fish.xml
    ├── hopper.py
    ├── hopper.xml
    ├── humanoid.py
    ├── humanoid.xml
    ├── humanoid_CMU.py
    ├── humanoid_CMU.xml
    ├── lqr.py
    ├── lqr.xml
    ├── lqr_solver.py
    ├── manipulator.py
    ├── manipulator.xml
    ├── pendulum.py
    ├── pendulum.xml
    ├── point_mass.py
    ├── point_mass.xml
    ├── quadruped.py
    ├── quadruped.xml
    ├── reacher.py
    ├── reacher.xml
    ├── stacker.py
    ├── stacker.xml
    ├── swimmer.py
    ├── swimmer.xml
    ├── tests
    │   ├── domains_test.py
    │   ├── loader_test.py
    │   └── lqr_test.py
    ├── utils
    │   ├── __init__.py
    │   ├── parse_amc.py
    │   ├── parse_amc_test.py
    │   ├── randomizers.py
    │   └── randomizers_test.py
    ├── walker.py
    ├── walker.xml
    └── wrappers
    │   ├── __init__.py
    │   ├── action_noise.py
    │   ├── action_noise_test.py
    │   ├── pixels.py
    │   └── pixels_test.py
├── local_dm_control_suite_off_center
    ├── README.md
    ├── __init__.py
    ├── acrobot.py
    ├── acrobot.xml
    ├── ball_in_cup.py
    ├── ball_in_cup.xml
    ├── base.py
    ├── cartpole.py
    ├── cartpole.xml
    ├── cheetah.py
    ├── cheetah.xml
    ├── common
    │   ├── __init__.py
    │   ├── materials.xml
    │   ├── materials_white_floor.xml
    │   ├── skybox.xml
    │   └── visual.xml
    ├── demos
    │   ├── mocap_demo.py
    │   └── zeros.amc
    ├── explore.py
    ├── finger.py
    ├── finger.xml
    ├── fish.py
    ├── fish.xml
    ├── hopper.py
    ├── hopper.xml
    ├── humanoid.py
    ├── humanoid.xml
    ├── humanoid_CMU.py
    ├── humanoid_CMU.xml
    ├── lqr.py
    ├── lqr.xml
    ├── lqr_solver.py
    ├── manipulator.py
    ├── manipulator.xml
    ├── pendulum.py
    ├── pendulum.xml
    ├── point_mass.py
    ├── point_mass.xml
    ├── quadruped.py
    ├── quadruped.xml
    ├── reacher.py
    ├── reacher.xml
    ├── stacker.py
    ├── stacker.xml
    ├── swimmer.py
    ├── swimmer.xml
    ├── tests
    │   ├── domains_test.py
    │   ├── loader_test.py
    │   └── lqr_test.py
    ├── utils
    │   ├── __init__.py
    │   ├── parse_amc.py
    │   ├── parse_amc_test.py
    │   ├── randomizers.py
    │   └── randomizers_test.py
    ├── walker.py
    ├── walker.xml
    └── wrappers
    │   ├── __init__.py
    │   ├── action_noise.py
    │   ├── action_noise_test.py
    │   ├── pixels.py
    │   └── pixels_test.py
├── logger.py
├── multistep_dynamics.py
├── multistep_replay.py
├── multistep_utils.py
├── requirements.txt
├── scripts
    └── run.sh
├── train.py
├── utils.py
└── video.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | **/tb/**
3 | **/__pycache__/
4 | **/tmp/**
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2021 Utkarsh Mishra
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/agents/agent_sac.py:
--------------------------------------------------------------------------------
  1 | import numpy as np
  2 | from numpy.core.numeric import tensordot
  3 | import torch
  4 | import torch.nn as nn
  5 | import torch.nn.functional as F
  6 | 
  7 | import utils
  8 | import data_augs as rad
  9 | from agents.auxiliary_funcs import BaseSacAgent
 10 | from multistep_dynamics import MultiStepDynamicsModel
 11 | import multistep_utils as mutils
 12 | 
 13 | class PixelSacAgent(BaseSacAgent):
 14 |     """Learning Representations of Pixel Observations with SAC + Self-Supervised Techniques.."""
 15 |     def __init__(
 16 |         self,
 17 |         obs_shape,
 18 |         action_shape,
 19 |         horizon,
 20 |         device,
 21 |         hidden_dim=256,
 22 |         discount=0.99,
 23 |         init_temperature=0.01,
 24 |         alpha_lr=1e-3,
 25 |         alpha_beta=0.9,
 26 |         actor_lr=1e-3,
 27 |         actor_beta=0.9,
 28 |         actor_log_std_min=-10,
 29 |         actor_log_std_max=2,
 30 |         actor_update_freq=2,
 31 |         critic_lr=1e-3,
 32 |         critic_beta=0.9,
 33 |         critic_tau=0.005,
 34 |         critic_target_update_freq=2,
 35 |         encoder_type='pixel',
 36 |         encoder_feature_dim=50,
 37 |         encoder_lr=1e-3,
 38 |         encoder_tau=0.005,
 39 |         decoder_lr=1e-3,
 40 |         decoder_weight_lambda=0.0,
 41 |         num_layers=4,
 42 |         num_filters=32,
 43 |         cpc_update_freq=1,
 44 |         log_interval=100,
 45 |         detach_encoder=False,
 46 |         latent_dim=128,
 47 |         data_augs = '',
 48 |         use_metric_loss=False
 49 |     ):
 50 | 
 51 |         print('########################################################################')
 52 |         print('################### Starting Case 0: Baseline Agent ####################')
 53 |         print('###################### Reward: No; Transition: No ######################')
 54 |         print('########################################################################')
 55 |         
 56 |         super().__init__(
 57 |             obs_shape,
 58 |             action_shape,
 59 |             horizon,
 60 |             device,
 61 |             hidden_dim,
 62 |             discount,
 63 |             init_temperature,
 64 |             alpha_lr,
 65 |             alpha_beta,
 66 |             actor_lr,
 67 |             actor_beta,
 68 |             actor_log_std_min,
 69 |             actor_log_std_max,
 70 |             actor_update_freq,
 71 |             critic_lr,
 72 |             critic_beta,
 73 |             critic_tau,
 74 |             critic_target_update_freq,
 75 |             encoder_type,
 76 |             encoder_feature_dim,
 77 |             encoder_lr,
 78 |             encoder_tau,
 79 |             decoder_lr,
 80 |             decoder_weight_lambda,
 81 |             num_layers,
 82 |             num_filters,
 83 |             cpc_update_freq,
 84 |             log_interval,
 85 |             detach_encoder,
 86 |             latent_dim,
 87 |             data_augs,
 88 |             use_metric_loss
 89 |         )
 90 | 
 91 |         self.train()
 92 |         self.critic_target.train()
 93 | 
 94 |     def update(self, replay_buffer, L, step):
 95 |         if self.encoder_type == 'pixel':
 96 | 
 97 |             batch_obs, batch_action, batch_reward, batch_not_done = replay_buffer.sample_multistep()
 98 | 
 99 |             obs = batch_obs[0]
100 |             action = batch_action[0]
101 |             next_obs = batch_obs[1]
102 |             reward = batch_reward[0].unsqueeze(-1)
103 |             not_done = batch_not_done[0].unsqueeze(-1) 
104 | 
105 |             self.update_critic(obs, action, reward, next_obs, not_done, L, step)
106 | 
107 |             encoded_batch_obs = []
108 | 
109 |             for en_iter in range(batch_obs.size(0)):
110 |                 encoded_obs = self.critic.encoder(batch_obs[en_iter])
111 |                 encoded_batch_obs.append(encoded_obs)
112 |                 
113 |             encoded_batch_obs = torch.stack(encoded_batch_obs)   
114 | 
115 |         else:
116 |             obs, action, reward, next_obs, not_done = replay_buffer.sample_proprio()
117 |             self.update_critic(obs, action, reward, next_obs, not_done, L, step)
118 |     
119 |         if step % self.log_interval == 0:
120 |             L.log('train/batch_reward', reward.mean(), step)
121 | 
122 |         if step % self.actor_update_freq == 0:
123 |             self.update_actor_and_alpha(obs, L, step)
124 | 
125 |         if step % self.critic_target_update_freq == 0:
126 |             utils.soft_update_params(
127 |                 self.critic.Q1, self.critic_target.Q1, self.critic_tau
128 |             )
129 |             utils.soft_update_params(
130 |                 self.critic.Q2, self.critic_target.Q2, self.critic_tau
131 |             )
132 |             utils.soft_update_params(
133 |                 self.critic.encoder, self.critic_target.encoder,
134 |                 self.encoder_tau
135 |             )
--------------------------------------------------------------------------------
/distractors/driving/1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UtkarshMishra04/pixel-representations-RL/8f457adcf41eb3b8975eaa5a752736c07b908c63/distractors/driving/1.mp4
--------------------------------------------------------------------------------
/dmc2gym/__init__.py:
--------------------------------------------------------------------------------
 1 | import gym
 2 | from gym.envs.registration import register
 3 | 
 4 | 
 5 | def make(
 6 |     domain_name,
 7 |     task_name,
 8 |     resource_files,
 9 |     img_source,
10 |     total_frames,
11 |     seed=1,
12 |     visualize_reward=True,
13 |     from_pixels=False,
14 |     height=84,
15 |     width=84,
16 |     camera_id=0,
17 |     frame_skip=1,
18 |     episode_length=1000,
19 |     off_center=False,
20 |     environment_kwargs=None
21 | ):
22 |     env_id = 'dmc_%s_%s_%s-v1' % (domain_name, task_name, seed)
23 | 
24 |     if from_pixels:
25 |         assert not visualize_reward, 'cannot use visualize reward when learning from pixels'
26 | 
27 |     # shorten episode length
28 |     max_episode_steps = (episode_length + frame_skip - 1) // frame_skip
29 | 
30 |     if not env_id in gym.envs.registry.env_specs:
31 |         register(
32 |             id=env_id,
33 |             entry_point='dmc2gym.wrappers:DMCWrapper',
34 |             kwargs={
35 |                 'domain_name': domain_name,
36 |                 'task_name': task_name,
37 |                 'resource_files': resource_files,
38 |                 'img_source': img_source,
39 |                 'total_frames': total_frames,
40 |                 'task_kwargs': {
41 |                     'random': seed
42 |                 },
43 |                 'environment_kwargs': environment_kwargs,
44 |                 'visualize_reward': visualize_reward,
45 |                 'from_pixels': from_pixels,
46 |                 'height': height,
47 |                 'width': width,
48 |                 'camera_id': camera_id,
49 |                 'frame_skip': frame_skip,
50 |                 'off_center': off_center,
51 |             },
52 |             max_episode_steps=max_episode_steps
53 |         )
54 |     return gym.make(env_id)
55 | 
--------------------------------------------------------------------------------
/encoder.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | import torch.nn as nn
  3 | 
  4 | 
  5 | def tie_weights(src, trg):
  6 |     assert type(src) == type(trg)
  7 |     trg.weight = src.weight
  8 |     trg.bias = src.bias
  9 | 
 10 | 
### Works with Finism SAC
# OUT_DIM = {2: 39, 4: 35, 6: 31}
# OUT_DIM_84 = {2: 39, 4: 35, 6: 31}
# OUT_DIM_100 = {2: 39, 4: 43, 6: 31}

### Works with RAD SAC
# Lookup tables mapping the number of conv layers to the value PixelEncoder
# uses as the side length of the last conv feature map (its fc layer takes
# num_filters * out_dim * out_dim inputs). PixelEncoder picks OUT_DIM_100 when
# obs_shape[-1] == 100, OUT_DIM_84 when it is 84, and OUT_DIM otherwise.
# NOTE(review): PixelEncoder stacks one unpadded 3x3 stride-2 conv followed by
# unpadded 3x3 stride-1 convs. For an 84-pixel input that arithmetic gives
# 39/35/31 for 2/4/6 layers, and for a 100-pixel input 4 layers give 43
# (47 corresponds to a 108-pixel input; 29/21 to a 64-pixel input). The 2- and
# 6-layer entries of OUT_DIM_84 and the OUT_DIM_100 entry therefore look
# inconsistent unless the tensors actually fed to the encoder differ in size
# from obs_shape -- confirm before using those configurations.
OUT_DIM = {2: 39, 4: 35, 6: 31}
OUT_DIM_84 = {2: 29, 4: 35, 6: 21}
OUT_DIM_100 = {4: 47}
 20 | 
 21 |  
 22 | class PixelEncoder(nn.Module):
 23 |     """Convolutional encoder of pixels observations."""
 24 |     def __init__(self, obs_shape, feature_dim, num_layers=2, num_filters=32,output_logits=False):
 25 |         super().__init__()
 26 | 
 27 |         assert len(obs_shape) == 3
 28 |         self.obs_shape = obs_shape
 29 |         self.feature_dim = feature_dim
 30 |         self.num_layers = num_layers
 31 |         # try 2 5x5s with strides 2x2. with samep adding, it should reduce 84 to 21, so with valid, it should be even smaller than 21.
 32 |         self.convs = nn.ModuleList(
 33 |             [nn.Conv2d(obs_shape[0], num_filters, 3, stride=2)]
 34 |         )
 35 |         for i in range(num_layers - 1):
 36 |             self.convs.append(nn.Conv2d(num_filters, num_filters, 3, stride=1))
 37 | 
 38 |         if obs_shape[-1] == 100:
 39 |             assert num_layers in OUT_DIM_100
 40 |             out_dim = OUT_DIM_100[num_layers]
 41 |         elif obs_shape[-1] == 84:
 42 |             out_dim = OUT_DIM_84[num_layers]
 43 |         else:
 44 |             out_dim = OUT_DIM[num_layers]
 45 | 
 46 |         self.fc = nn.Linear(num_filters * out_dim * out_dim, self.feature_dim)
 47 |         self.ln = nn.LayerNorm(self.feature_dim)
 48 | 
 49 |         self.outputs = dict()
 50 |         self.output_logits = output_logits
 51 | 
 52 |     def reparameterize(self, mu, logstd):
 53 |         std = torch.exp(logstd)
 54 |         eps = torch.randn_like(std)
 55 |         return mu + eps * std
 56 | 
 57 |     def forward_conv(self, obs):
 58 |         if obs.max() > 1.:
 59 |             obs = obs / 255.
 60 | 
 61 |         self.outputs['obs'] = obs
 62 | 
 63 |         conv = torch.relu(self.convs[0](obs))
 64 |         self.outputs['conv1'] = conv
 65 | 
 66 |         for i in range(1, self.num_layers):
 67 |             conv = torch.relu(self.convs[i](conv))
 68 |             self.outputs['conv%s' % (i + 1)] = conv
 69 | 
 70 |         h = conv.view(conv.size(0), -1)
 71 | 
 72 |         return h
 73 | 
 74 |     def forward(self, obs, detach=False):
 75 |         h = self.forward_conv(obs)
 76 | 
 77 |         if detach:
 78 |             h = h.detach()
 79 | 
 80 |         h_fc = self.fc(h)
 81 |         self.outputs['fc'] = h_fc
 82 | 
 83 |         h_norm = self.ln(h_fc)
 84 |         self.outputs['ln'] = h_norm
 85 | 
 86 |         if self.output_logits:
 87 |             out = h_norm
 88 |         else:
 89 |             out = torch.tanh(h_norm)
 90 |             self.outputs['tanh'] = out
 91 | 
 92 |         return out
 93 | 
 94 |     def copy_conv_weights_from(self, source):
 95 |         """Tie convolutional layers"""
 96 |         # only tie conv layers
 97 |         for i in range(self.num_layers):
 98 |             tie_weights(src=source.convs[i], trg=self.convs[i])
 99 | 
100 |     def log(self, L, step, log_freq):
101 |         if step % log_freq != 0:
102 |             return
103 | 
104 |         for k, v in self.outputs.items():
105 |             L.log_histogram('train_encoder/%s_hist' % k, v, step)
106 |             if len(v.shape) > 2:
107 |                 L.log_image('train_encoder/%s_img' % k, v[0], step)
108 | 
109 |         for i in range(self.num_layers):
110 |             L.log_param('train_encoder/conv%s' % (i + 1), self.convs[i], step)
111 |         L.log_param('train_encoder/fc', self.fc, step)
112 |         L.log_param('train_encoder/ln', self.ln, step)
113 | 
114 | 
class IdentityEncoder(nn.Module):
    """Pass-through encoder for flat (1-D) observations.

    It has no parameters; ``feature_dim`` is the observation size, and the
    ``feature_dim``, ``num_layers``, ``num_filters`` and extra positional
    arguments are accepted only for signature compatibility.
    """
    def __init__(self, obs_shape, feature_dim, num_layers, num_filters,*args):
        super().__init__()

        assert len(obs_shape) == 1
        self.feature_dim = obs_shape[0]

    def forward(self, obs, detach=False):
        """Return ``obs`` unchanged; ``detach`` is ignored."""
        return obs

    def copy_conv_weights_from(self, source):
        """No-op: there are no conv layers to share."""
        return None

    def log(self, L, step, log_freq):
        """No-op: there is nothing to log."""
        return None
130 | 
131 | 
# Registry used by make_encoder: encoder type name -> encoder class.
_AVAILABLE_ENCODERS = {'pixel': PixelEncoder, 'identity': IdentityEncoder}
133 | 
134 | 
def make_encoder(
    encoder_type, obs_shape, feature_dim, num_layers, num_filters, output_logits=False
):
    """Instantiate the encoder class registered under ``encoder_type``.

    All remaining arguments are passed positionally to the encoder's
    constructor. An unknown ``encoder_type`` fails the assertion.
    """
    assert encoder_type in _AVAILABLE_ENCODERS
    encoder_cls = _AVAILABLE_ENCODERS[encoder_type]
    return encoder_cls(obs_shape, feature_dim, num_layers, num_filters, output_logits)
142 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/README.md:
--------------------------------------------------------------------------------
 1 | # DeepMind Control Suite.
 2 | 
 3 | This submodule contains the domains and tasks described in the
 4 | [DeepMind Control Suite tech report](https://arxiv.org/abs/1801.00690).
 5 | 
 6 | ## Quickstart
 7 | 
 8 | ```python
 9 | from dm_control import suite
10 | import numpy as np
11 | 
12 | # Load one task:
13 | env = suite.load(domain_name="cartpole", task_name="swingup")
14 | 
15 | # Iterate over a task set:
16 | for domain_name, task_name in suite.BENCHMARKING:
17 |   env = suite.load(domain_name, task_name)
18 | 
19 | # Step through an episode and print out reward, discount and observation.
20 | action_spec = env.action_spec()
21 | time_step = env.reset()
22 | while not time_step.last():
23 |   action = np.random.uniform(action_spec.minimum,
24 |                              action_spec.maximum,
25 |                              size=action_spec.shape)
26 |   time_step = env.step(action)
27 |   print(time_step.reward, time_step.discount, time_step.observation)
28 | ```
29 | 
30 | ## Illustration video
31 | 
32 | Below is a video montage of solved Control Suite tasks, with reward
33 | visualisation enabled.
34 | 
35 | [Video montage of solved Control Suite tasks](https://www.youtube.com/watch?v=rAai4QzcYbs)
36 | 
37 | 
38 | ### Quadruped domain [April 2019]
39 | 
40 | Roughly based on the 'ant' model introduced by [Schulman et al. 2015](https://arxiv.org/abs/1506.02438). Main modifications to the body are:
41 | 
42 | - 4 DoFs per leg, 1 constraining tendon.
43 | - 3 actuators per leg: 'yaw', 'lift', 'extend'.
44 | - Filtered position actuators with timescale of 100ms.
45 | - Sensors include an IMU, force/torque sensors, and rangefinders.
46 | 
47 | Four tasks:
48 | 
49 | - `walk` and `run`: self-right the body then move forward at a desired speed.
50 | - `escape`: escape a bowl-shaped random terrain (uses rangefinders).
51 | - `fetch`: go to a moving ball and bring it to a target.
52 | 
53 | All behaviors in the video below were trained with [Abdolmaleki et al's
54 | MPO](https://arxiv.org/abs/1806.06920).
55 | 
56 | [Video of the trained Quadruped behaviors](https://www.youtube.com/watch?v=RhRLjbb7pBE)
57 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/acrobot.py:
--------------------------------------------------------------------------------
  1 | # Copyright 2017 The dm_control Authors.
  2 | #
  3 | # Licensed under the Apache License, Version 2.0 (the "License");
  4 | # you may not use this file except in compliance with the License.
  5 | # You may obtain a copy of the License at
  6 | #
  7 | #    http://www.apache.org/licenses/LICENSE-2.0
  8 | #
  9 | # Unless required by applicable law or agreed to in writing, software
 10 | # distributed under the License is distributed on an "AS IS" BASIS,
 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
 12 | # See the License for the specific language governing permissions and
 13 | # limitations under the License.
 14 | # ============================================================================
 15 | 
 16 | """Acrobot domain."""
 17 | 
 18 | from __future__ import absolute_import
 19 | from __future__ import division
 20 | from __future__ import print_function
 21 | 
 22 | import collections
 23 | 
 24 | from dm_control import mujoco
 25 | from dm_control.rl import control
 26 | from local_dm_control_suite import base
 27 | from local_dm_control_suite import common
 28 | from dm_control.utils import containers
 29 | from dm_control.utils import rewards
 30 | import numpy as np
 31 | 
# Default `time_limit` handed to control.Environment by the task constructors
# below (presumably seconds -- confirm against dm_control).
_DEFAULT_TIME_LIMIT = 10
# Container of this domain's tasks; entries are added with @SUITE.add(...).
SUITE = containers.TaggedTasks()
 34 | 
 35 | 
 36 | def get_model_and_assets():
 37 |   """Returns a tuple containing the model XML string and a dict of assets."""
 38 |   return common.read_model('acrobot.xml'), common.ASSETS
 39 | 
 40 | 
 41 | @SUITE.add('benchmarking')
 42 | def swingup(time_limit=_DEFAULT_TIME_LIMIT, random=None,
 43 |             environment_kwargs=None):
 44 |   """Returns Acrobot balance task."""
 45 |   physics = Physics.from_xml_string(*get_model_and_assets())
 46 |   task = Balance(sparse=False, random=random)
 47 |   environment_kwargs = environment_kwargs or {}
 48 |   return control.Environment(
 49 |       physics, task, time_limit=time_limit, **environment_kwargs)
 50 | 
 51 | 
 52 | @SUITE.add('benchmarking')
 53 | def swingup_sparse(time_limit=_DEFAULT_TIME_LIMIT, random=None,
 54 |                    environment_kwargs=None):
 55 |   """Returns Acrobot sparse balance."""
 56 |   physics = Physics.from_xml_string(*get_model_and_assets())
 57 |   task = Balance(sparse=True, random=random)
 58 |   environment_kwargs = environment_kwargs or {}
 59 |   return control.Environment(
 60 |       physics, task, time_limit=time_limit, **environment_kwargs)
 61 | 
 62 | 
 63 | class Physics(mujoco.Physics):
 64 |   """Physics simulation with additional features for the Acrobot domain."""
 65 | 
 66 |   def horizontal(self):
 67 |     """Returns horizontal (x) component of body frame z-axes."""
 68 |     return self.named.data.xmat[['upper_arm', 'lower_arm'], 'xz']
 69 | 
 70 |   def vertical(self):
 71 |     """Returns vertical (z) component of body frame z-axes."""
 72 |     return self.named.data.xmat[['upper_arm', 'lower_arm'], 'zz']
 73 | 
 74 |   def to_target(self):
 75 |     """Returns the distance from the tip to the target."""
 76 |     tip_to_target = (self.named.data.site_xpos['target'] -
 77 |                      self.named.data.site_xpos['tip'])
 78 |     return np.linalg.norm(tip_to_target)
 79 | 
 80 |   def orientations(self):
 81 |     """Returns the sines and cosines of the pole angles."""
 82 |     return np.concatenate((self.horizontal(), self.vertical()))
 83 | 
 84 | 
 85 | class Balance(base.Task):
 86 |   """An Acrobot `Task` to swing up and balance the pole."""
 87 | 
 88 |   def __init__(self, sparse, random=None):
 89 |     """Initializes an instance of `Balance`.
 90 | 
 91 |     Args:
 92 |       sparse: A `bool` specifying whether to use a sparse (indicator) reward.
 93 |       random: Optional, either a `numpy.random.RandomState` instance, an
 94 |         integer seed for creating a new `RandomState`, or None to select a seed
 95 |         automatically (default).
 96 |     """
 97 |     self._sparse = sparse
 98 |     super(Balance, self).__init__(random=random)
 99 | 
100 |   def initialize_episode(self, physics):
101 |     """Sets the state of the environment at the start of each episode.
102 | 
103 |     Shoulder and elbow are set to a random position between [-pi, pi).
104 | 
105 |     Args:
106 |       physics: An instance of `Physics`.
107 |     """
108 |     physics.named.data.qpos[
109 |         ['shoulder', 'elbow']] = self.random.uniform(-np.pi, np.pi, 2)
110 |     super(Balance, self).initialize_episode(physics)
111 | 
112 |   def get_observation(self, physics):
113 |     """Returns an observation of pole orientation and angular velocities."""
114 |     obs = collections.OrderedDict()
115 |     obs['orientations'] = physics.orientations()
116 |     obs['velocity'] = physics.velocity()
117 |     return obs
118 | 
119 |   def _get_reward(self, physics, sparse):
120 |     target_radius = physics.named.model.site_size['target', 0]
121 |     return rewards.tolerance(physics.to_target(),
122 |                              bounds=(0, target_radius),
123 |                              margin=0 if sparse else 1)
124 | 
125 |   def get_reward(self, physics):
126 |     """Returns a sparse or a smooth reward, as specified in the constructor."""
127 |     return self._get_reward(physics, sparse=self._sparse)
128 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/acrobot.xml:
--------------------------------------------------------------------------------
 1 | 
 8 | 
 9 |   
10 |   
11 |   
12 | 
13 |   
14 |     
15 |     
16 |   
17 | 
18 |   
21 | 
22 |   
23 |     
24 |     
25 |     
26 |     
27 |     
28 |     
29 |       
30 |       
31 |       
32 |       
33 |         
34 |         
35 |         
36 |       
37 |     
38 |   
39 | 
40 |    
41 |     
42 |   
43 | 
44 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/ball_in_cup.py:
--------------------------------------------------------------------------------
  1 | # Copyright 2017 The dm_control Authors.
  2 | #
  3 | # Licensed under the Apache License, Version 2.0 (the "License");
  4 | # you may not use this file except in compliance with the License.
  5 | # You may obtain a copy of the License at
  6 | #
  7 | #    http://www.apache.org/licenses/LICENSE-2.0
  8 | #
  9 | # Unless required by applicable law or agreed to in writing, software
 10 | # distributed under the License is distributed on an "AS IS" BASIS,
 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
 12 | # See the License for the specific language governing permissions and
 13 | # limitations under the License.
 14 | # ============================================================================
 15 | 
 16 | """Ball-in-Cup Domain."""
 17 | 
 18 | from __future__ import absolute_import
 19 | from __future__ import division
 20 | from __future__ import print_function
 21 | 
 22 | import collections
 23 | 
 24 | from dm_control import mujoco
 25 | from dm_control.rl import control
 26 | from local_dm_control_suite import base
 27 | from local_dm_control_suite import common
 28 | from dm_control.utils import containers
 29 | 
# Defaults passed to control.Environment by `catch` below.
_DEFAULT_TIME_LIMIT = 20  # (seconds)
_CONTROL_TIMESTEP = .02   # (seconds)


# Container of this domain's tasks; entries are added with @SUITE.add(...).
SUITE = containers.TaggedTasks()
 35 | 
 36 | 
 37 | def get_model_and_assets():
 38 |   """Returns a tuple containing the model XML string and a dict of assets."""
 39 |   return common.read_model('ball_in_cup.xml'), common.ASSETS
 40 | 
 41 | 
 42 | @SUITE.add('benchmarking', 'easy')
 43 | def catch(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
 44 |   """Returns the Ball-in-Cup task."""
 45 |   physics = Physics.from_xml_string(*get_model_and_assets())
 46 |   task = BallInCup(random=random)
 47 |   environment_kwargs = environment_kwargs or {}
 48 |   return control.Environment(
 49 |       physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
 50 |       **environment_kwargs)
 51 | 
 52 | 
 53 | class Physics(mujoco.Physics):
 54 |   """Physics with additional features for the Ball-in-Cup domain."""
 55 | 
 56 |   def ball_to_target(self):
 57 |     """Returns the vector from the ball to the target."""
 58 |     target = self.named.data.site_xpos['target', ['x', 'z']]
 59 |     ball = self.named.data.xpos['ball', ['x', 'z']]
 60 |     return target - ball
 61 | 
 62 |   def in_target(self):
 63 |     """Returns 1 if the ball is in the target, 0 otherwise."""
 64 |     ball_to_target = abs(self.ball_to_target())
 65 |     target_size = self.named.model.site_size['target', [0, 2]]
 66 |     ball_size = self.named.model.geom_size['ball', 0]
 67 |     return float(all(ball_to_target < target_size - ball_size))
 68 | 
 69 | 
 70 | class BallInCup(base.Task):
 71 |   """The Ball-in-Cup task. Put the ball in the cup."""
 72 | 
 73 |   def initialize_episode(self, physics):
 74 |     """Sets the state of the environment at the start of each episode.
 75 | 
 76 |     Args:
 77 |       physics: An instance of `Physics`.
 78 | 
 79 |     """
 80 |     # Find a collision-free random initial position of the ball.
 81 |     penetrating = True
 82 |     while penetrating:
 83 |       # Assign a random ball position.
 84 |       physics.named.data.qpos['ball_x'] = self.random.uniform(-.2, .2)
 85 |       physics.named.data.qpos['ball_z'] = self.random.uniform(.2, .5)
 86 |       # Check for collisions.
 87 |       physics.after_reset()
 88 |       penetrating = physics.data.ncon > 0
 89 |     super(BallInCup, self).initialize_episode(physics)
 90 | 
 91 |   def get_observation(self, physics):
 92 |     """Returns an observation of the state."""
 93 |     obs = collections.OrderedDict()
 94 |     obs['position'] = physics.position()
 95 |     obs['velocity'] = physics.velocity()
 96 |     return obs
 97 | 
 98 |   def get_reward(self, physics):
 99 |     """Returns a sparse reward."""
100 |     return physics.in_target()
101 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/ball_in_cup.xml:
--------------------------------------------------------------------------------
 1 | 
 2 | 
 3 |   
 4 |   
 5 |   
 6 | 
 7 |   
 8 |     
 9 |     
10 |       
11 |       
12 |     
13 |   
14 | 
15 |   
16 |     
17 |     
18 |     
19 |     
20 | 
21 |     
22 |       
23 |       
24 |       
25 |       
26 |       
27 |       
28 |       
29 |       
30 |       
31 |     
32 | 
33 |     
34 |       
35 |       
36 |       
37 |       
38 |     
39 |   
40 | 
41 |   
42 |     
43 |     
44 |   
45 | 
46 |   
47 |     
48 |       
49 |       
50 |     
51 |   
52 | 
53 | 
54 | 
55 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/base.py:
--------------------------------------------------------------------------------
  1 | # Copyright 2017 The dm_control Authors.
  2 | #
  3 | # Licensed under the Apache License, Version 2.0 (the "License");
  4 | # you may not use this file except in compliance with the License.
  5 | # You may obtain a copy of the License at
  6 | #
  7 | #    http://www.apache.org/licenses/LICENSE-2.0
  8 | #
  9 | # Unless required by applicable law or agreed to in writing, software
 10 | # distributed under the License is distributed on an "AS IS" BASIS,
 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
 12 | # See the License for the specific language governing permissions and
 13 | # limitations under the License.
 14 | # ============================================================================
 15 | 
 16 | """Base class for tasks in the Control Suite."""
 17 | 
 18 | from __future__ import absolute_import
 19 | from __future__ import division
 20 | from __future__ import print_function
 21 | 
 22 | from dm_control import mujoco
 23 | from dm_control.rl import control
 24 | 
 25 | import numpy as np
 26 | 
 27 | 
 28 | class Task(control.Task):
 29 |   """Base class for tasks in the Control Suite.
 30 | 
 31 |   Actions are mapped directly to the states of MuJoCo actuators: each element of
 32 |   the action array is used to set the control input for a single actuator. The
 33 |   ordering of the actuators is the same as in the corresponding MJCF XML file.
 34 | 
 35 |   Attributes:
 36 |     random: A `numpy.random.RandomState` instance. This should be used to
 37 |       generate all random variables associated with the task, such as random
 38 |       starting states, observation noise* etc.
 39 | 
 40 |   *If sensor noise is enabled in the MuJoCo model then this will be generated
 41 |   using MuJoCo's internal RNG, which has its own independent state.
 42 |   """
 43 | 
 44 |   def __init__(self, random=None):
 45 |     """Initializes a new continuous control task.
 46 | 
 47 |     Args:
 48 |       random: Optional, either a `numpy.random.RandomState` instance, an integer
 49 |         seed for creating a new `RandomState`, or None to select a seed
 50 |         automatically (default).
 51 |     """
 52 |     if not isinstance(random, np.random.RandomState):
 53 |       random = np.random.RandomState(random)
 54 |     self._random = random
 55 |     self._visualize_reward = False
 56 | 
 57 |   @property
 58 |   def random(self):
 59 |     """Task-specific `numpy.random.RandomState` instance."""
 60 |     return self._random
 61 | 
 62 |   def action_spec(self, physics):
 63 |     """Returns a `BoundedArraySpec` matching the `physics` actuators."""
 64 |     return mujoco.action_spec(physics)
 65 | 
 66 |   def initialize_episode(self, physics):
 67 |     """Resets geom colors to their defaults after starting a new episode.
 68 | 
 69 |     Subclasses of `base.Task` must delegate to this method after performing
 70 |     their own initialization.
 71 | 
 72 |     Args:
 73 |       physics: An instance of `mujoco.Physics`.
 74 |     """
 75 |     self.after_step(physics)
 76 | 
 77 |   def before_step(self, action, physics):
 78 |     """Sets the control signal for the actuators to values in `action`."""
 79 |     # Support legacy internal code.
 80 |     action = getattr(action, "continuous_actions", action)
 81 |     physics.set_control(action)
 82 | 
 83 |   def after_step(self, physics):
 84 |     """Modifies colors according to the reward."""
 85 |     if self._visualize_reward:
 86 |       reward = np.clip(self.get_reward(physics), 0.0, 1.0)
 87 |       _set_reward_colors(physics, reward)
 88 | 
 89 |   @property
 90 |   def visualize_reward(self):
 91 |     return self._visualize_reward
 92 | 
 93 |   @visualize_reward.setter
 94 |   def visualize_reward(self, value):
 95 |     if not isinstance(value, bool):
 96 |       raise ValueError("Expected a boolean, got {}.".format(type(value)))
 97 |     self._visualize_reward = value
 98 | 
 99 | 
# Names of the materials that get recolored, plus the names of the materials
# holding their default and highlight colors (same order in all three lists).
_MATERIALS = ["self", "effector", "target"]
_DEFAULT = [material + "_default" for material in _MATERIALS]
_HIGHLIGHT = [material + "_highlight" for material in _MATERIALS]


def _set_reward_colors(physics, reward):
  """Blends the `_MATERIALS` colors between default and highlight.

  Args:
    physics: Object whose `named.model.mat_rgba` table can be read and written
      with lists of material names.
    reward: Scalar in [0, 1]; 0 selects the default colors, 1 the highlight
      colors.
  """
  assert 0.0 <= reward <= 1.0
  rgba_table = physics.named.model.mat_rgba
  default_rgba = rgba_table[_DEFAULT]
  highlight_rgba = rgba_table[_HIGHLIGHT]
  # Fourth power: better color distinction near high rewards.
  weight = reward ** 4
  blended = weight * highlight_rgba + (1.0 - weight) * default_rgba
  rgba_table[_MATERIALS] = blended
113 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/cartpole.xml:
--------------------------------------------------------------------------------
 1 | 
 2 |   
 3 |   
 4 |   
 5 | 
 6 |   
 9 | 
10 |   
11 |     
12 |       
13 |       
14 |     
15 |   
16 | 
17 |   
18 |     
19 |     
20 |     
21 |     
22 |     
23 |     
24 |     
25 |       
26 |       
27 |       
28 |         
29 |         
30 |       
31 |     
32 |   
33 | 
34 |   
35 |     
36 |   
37 | 
38 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/cheetah.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2017 The dm_control Authors.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #    http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | 
16 | """Cheetah Domain."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import collections
23 | 
24 | from dm_control import mujoco
25 | from dm_control.rl import control
26 | from local_dm_control_suite import base
27 | from local_dm_control_suite import common
28 | from dm_control.utils import containers
29 | from dm_control.utils import rewards
30 | 
31 | 
# How long the simulation will run, in seconds. Used as the default
# `time_limit` of `run`.
_DEFAULT_TIME_LIMIT = 10

# Running speed above which reward is 1. Also used as the margin of the
# reward tolerance in `Cheetah.get_reward`.
_RUN_SPEED = 10

# Container of this domain's tasks; `run` is registered in it through the
# `@SUITE.add('benchmarking')` decorator.
SUITE = containers.TaggedTasks()
39 | 
40 | 
def get_model_and_assets():
  """Returns the `(model, assets)` tuple for the cheetah domain.

  The first element is what `common.read_model` returns for 'cheetah.xml'
  (the model XML string); the second is the shared `common.ASSETS` dict.
  """
  model_xml = common.read_model('cheetah.xml')
  return model_xml, common.ASSETS
44 | 
45 | 
@SUITE.add('benchmarking')
def run(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
  """Returns the run task.

  Args:
    time_limit: Passed to `control.Environment` as `time_limit`.
    random: Passed to the `Cheetah` task as `random`.
    environment_kwargs: Optional dict of extra keyword arguments for
      `control.Environment`; a falsy value means no extra arguments.

  Returns:
    A `control.Environment` wrapping the cheetah physics and task.
  """
  model_xml, assets = get_model_and_assets()
  physics = Physics.from_xml_string(model_xml, assets)
  task = Cheetah(random=random)
  extra_kwargs = environment_kwargs if environment_kwargs else {}
  return control.Environment(
      physics, task, time_limit=time_limit, **extra_kwargs)
54 | 
55 | 
class Physics(mujoco.Physics):
  """`mujoco.Physics` extended with a speed accessor for the Cheetah domain."""

  def speed(self):
    """Returns the horizontal speed of the Cheetah.

    This is element 0 of the `torso_subtreelinvel` sensor reading.
    """
    torso_velocity = self.named.data.sensordata['torso_subtreelinvel']
    return torso_velocity[0]
62 | 
63 | 
class Cheetah(base.Task):
  """A `Task` to train a running Cheetah."""

  def initialize_episode(self, physics):
    """Randomizes the limited joints and lets the model settle.

    Every limited joint gets a position drawn uniformly from its range, the
    physics is stepped 200 times, and the simulation time is reset to zero
    before delegating to the base class.

    Args:
      physics: The `Physics` instance to initialize.
    """
    # The boolean-mask indexing of `qpos` below assumes that all joints have
    # a single DOF.
    assert physics.model.nq == physics.model.njnt
    limited_joints = physics.model.jnt_limited == 1
    joint_ranges = physics.model.jnt_range[limited_joints]
    lower, upper = joint_ranges.T
    physics.data.qpos[limited_joints] = self.random.uniform(lower, upper)

    # Stabilize the model before the actual simulation.
    settle_steps = 200
    for _ in range(settle_steps):
      physics.step()

    physics.data.time = 0
    self._timeout_progress = 0
    super(Cheetah, self).initialize_episode(physics)

  def get_observation(self, physics):
    """Returns position and velocity observations as an `OrderedDict`.

    The first `qpos` entry (horizontal position) is left out to maintain
    translational invariance.
    """
    return collections.OrderedDict([
        ('position', physics.data.qpos[1:].copy()),
        ('velocity', physics.velocity()),
    ])

  def get_reward(self, physics):
    """Returns a reward that grows linearly with speed up to `_RUN_SPEED`."""
    current_speed = physics.speed()
    return rewards.tolerance(
        current_speed,
        bounds=(_RUN_SPEED, float('inf')),
        margin=_RUN_SPEED,
        value_at_margin=0,
        sigmoid='linear')
98 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/cheetah.xml:
--------------------------------------------------------------------------------
 1 | 
 2 |   
 3 |   
 4 |   
 5 | 
 6 |   
 7 | 
 8 |   
 9 |     
10 |       
11 |       
12 |     
13 |     
14 |       
15 |     
16 |     
17 |   
18 | 
19 |   
20 | 
21 |   
22 | 
23 |   
24 |     
25 |     
26 |       
27 |       
28 |       
29 |       
30 |       
31 |       
32 |       
33 |       
34 |       
35 |         
36 |         
37 |         
38 |           
39 |           
40 |           
41 |             
42 |             
43 |           
44 |         
45 |       
46 |       
47 |         
48 |         
49 |         
50 |           
51 |           
52 |           
53 |             
54 |             
55 |           
56 |         
57 |       
58 |     
59 |   
60 | 
61 |   
62 |     
63 |   
64 | 
65 |   
66 |     
67 |     
68 |     
69 |     
70 |     
71 |     
72 |   
73 | 
74 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/common/__init__.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2017 The dm_control Authors.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #    http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | 
16 | """Functions to manage the common assets for domains."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import os
23 | from dm_control.utils import io as resources
24 | 
# Directory one level above the directory containing this file; model and
# asset paths are resolved relative to it.
_SUITE_DIR = os.path.dirname(os.path.dirname(__file__))
# Asset paths relative to `_SUITE_DIR`. The exact strings are also the keys of
# `ASSETS`.
_FILENAMES = [
    "./common/materials.xml",
    "./common/materials_white_floor.xml",
    "./common/skybox.xml",
    "./common/visual.xml",
]

# Maps each entry of `_FILENAMES` to what `resources.GetResource` returns for
# that path under `_SUITE_DIR`. Built once, when this module is imported.
ASSETS = {filename: resources.GetResource(os.path.join(_SUITE_DIR, filename))
          for filename in _FILENAMES}
35 | 
36 | 
def read_model(model_filename):
  """Reads a model XML file from the suite directory.

  Args:
    model_filename: Path of the model file relative to `_SUITE_DIR`.

  Returns:
    The file contents, as returned by `resources.GetResource`.
  """
  model_path = os.path.join(_SUITE_DIR, model_filename)
  return resources.GetResource(model_path)
40 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/common/materials.xml:
--------------------------------------------------------------------------------
 1 | 
 6 | 
 7 |   
 8 |     
 9 |     
10 |     
11 |     
12 |     
13 |     
14 |     
15 |     
16 |     
17 |     
18 |     
19 |     
20 |     
21 |     
22 |   
23 | 
24 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/common/materials_white_floor.xml:
--------------------------------------------------------------------------------
 1 | 
 6 | 
 7 |   
 8 |     
 9 |     
10 |     
11 |     
12 |     
13 |     
14 |     
15 |     
16 |     
17 |     
18 |     
19 |     
20 |     
21 |     
22 |   
23 | 
24 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/common/skybox.xml:
--------------------------------------------------------------------------------
1 | 
2 |   
3 |       
5 |   
6 | 
7 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/common/visual.xml:
--------------------------------------------------------------------------------
1 | 
2 |   
3 |     
4 |     
5 |     
6 |   
7 | 
8 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/demos/mocap_demo.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2017 The dm_control Authors.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #    http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | 
16 | """Demonstration of amc parsing for CMU mocap database.
17 | 
18 | To run the demo, supply a path to a `.amc` file:
19 | 
20 |     python mocap_demo --filename='path/to/mocap.amc'
21 | 
22 | CMU motion capture clips are available at mocap.cs.cmu.edu
23 | """
24 | 
25 | from __future__ import absolute_import
26 | from __future__ import division
27 | from __future__ import print_function
28 | 
29 | import time
30 | # Internal dependencies.
31 | 
32 | from absl import app
33 | from absl import flags
34 | 
35 | from local_dm_control_suite import humanoid_CMU
36 | from dm_control.suite.utils import parse_amc
37 | 
38 | import matplotlib.pyplot as plt
39 | import numpy as np
40 | 
# Command-line flags of the demo; `main` reads them through `FLAGS`.
FLAGS = flags.FLAGS
# Path of the clip to convert. The default is None; the script entry point
# marks this flag as required.
flags.DEFINE_string('filename', None, 'amc file to be converted.')
# Upper bound on the number of frames that `main` renders and plays back.
flags.DEFINE_integer('max_num_frames', 90,
                     'Maximum number of frames for plotting/playback')
45 | 
46 | 
def main(unused_argv):
  """Converts the clip given by `--filename` and plays it back.

  Each frame's joint positions are written into the humanoid's `qpos`, the
  scene is rendered from cameras 0 and 1 side by side, and the resulting
  frames are shown with matplotlib until a button is pressed.

  Args:
    unused_argv: Ignored.
  """
  env = humanoid_CMU.stand()

  # Parse and convert specified clip.
  converted = parse_amc.convert(
      FLAGS.filename, env.physics, env.control_timestep())

  num_frames = min(FLAGS.max_num_frames, converted.qpos.shape[1] - 1)

  height = 480
  width = 480
  frames = np.zeros((num_frames, height, 2 * width, 3), dtype=np.uint8)

  for index in range(num_frames):
    pose = converted.qpos[:, index]
    with env.physics.reset_context():
      env.physics.data.qpos[:] = pose
    views = [env.physics.render(height, width, camera_id=camera)
             for camera in (0, 1)]
    frames[index] = np.hstack(views)

  image = None
  last_tick = time.time()
  for frame in frames:
    # The first frame creates the image artist; later frames only swap data.
    if image is None:
      image = plt.imshow(frame)
    else:
      image.set_data(frame)
    now = time.time()
    elapsed = now - last_tick
    last_tick = time.time()
    # Real-time playback not always possible as elapsed > .03
    plt.pause(max(0.01, 0.03 - elapsed))  # Need min display time > 0.0.
    plt.draw()
  plt.waitforbuttonpress()
80 | 
81 | 
if __name__ == '__main__':
  # `--filename` defaults to None, so insist on it before `main` runs.
  flags.mark_flag_as_required('filename')
  app.run(main)
85 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/demos/zeros.amc:
--------------------------------------------------------------------------------
  1 | #DUMMY AMC for testing
  2 | :FULLY-SPECIFIED
  3 | :DEGREES
  4 | 1
  5 | root 0 0 0 0 0 0
  6 | lowerback 0 0 0
  7 | upperback 0 0 0
  8 | thorax 0 0 0
  9 | lowerneck 0 0 0
 10 | upperneck 0 0 0
 11 | head 0 0 0
 12 | rclavicle 0 0
 13 | rhumerus 0 0 0
 14 | rradius 0
 15 | rwrist 0
 16 | rhand 0 0
 17 | rfingers 0
 18 | rthumb 0 0
 19 | lclavicle 0 0
 20 | lhumerus 0 0 0
 21 | lradius 0
 22 | lwrist 0
 23 | lhand 0 0
 24 | lfingers 0
 25 | lthumb 0 0
 26 | rfemur 0 0 0
 27 | rtibia 0
 28 | rfoot 0 0
 29 | rtoes 0
 30 | lfemur 0 0 0
 31 | ltibia 0
 32 | lfoot 0 0
 33 | ltoes 0
 34 | 2
 35 | root 0 0 0 0 0 0
 36 | lowerback 0 0 0
 37 | upperback 0 0 0
 38 | thorax 0 0 0
 39 | lowerneck 0 0 0
 40 | upperneck 0 0 0
 41 | head 0 0 0
 42 | rclavicle 0 0
 43 | rhumerus 0 0 0
 44 | rradius 0
 45 | rwrist 0
 46 | rhand 0 0
 47 | rfingers 0
 48 | rthumb 0 0
 49 | lclavicle 0 0
 50 | lhumerus 0 0 0
 51 | lradius 0
 52 | lwrist 0
 53 | lhand 0 0
 54 | lfingers 0
 55 | lthumb 0 0
 56 | rfemur 0 0 0
 57 | rtibia 0
 58 | rfoot 0 0
 59 | rtoes 0
 60 | lfemur 0 0 0
 61 | ltibia 0
 62 | lfoot 0 0
 63 | ltoes 0
 64 | 3
 65 | root 0 0 0 0 0 0
 66 | lowerback 0 0 0
 67 | upperback 0 0 0
 68 | thorax 0 0 0
 69 | lowerneck 0 0 0
 70 | upperneck 0 0 0
 71 | head 0 0 0
 72 | rclavicle 0 0
 73 | rhumerus 0 0 0
 74 | rradius 0
 75 | rwrist 0
 76 | rhand 0 0
 77 | rfingers 0
 78 | rthumb 0 0
 79 | lclavicle 0 0
 80 | lhumerus 0 0 0
 81 | lradius 0
 82 | lwrist 0
 83 | lhand 0 0
 84 | lfingers 0
 85 | lthumb 0 0
 86 | rfemur 0 0 0
 87 | rtibia 0
 88 | rfoot 0 0
 89 | rtoes 0
 90 | lfemur 0 0 0
 91 | ltibia 0
 92 | lfoot 0 0
 93 | ltoes 0
 94 | 4
 95 | root 0 0 0 0 0 0
 96 | lowerback 0 0 0
 97 | upperback 0 0 0
 98 | thorax 0 0 0
 99 | lowerneck 0 0 0
100 | upperneck 0 0 0
101 | head 0 0 0
102 | rclavicle 0 0
103 | rhumerus 0 0 0
104 | rradius 0
105 | rwrist 0
106 | rhand 0 0
107 | rfingers 0
108 | rthumb 0 0
109 | lclavicle 0 0
110 | lhumerus 0 0 0
111 | lradius 0
112 | lwrist 0
113 | lhand 0 0
114 | lfingers 0
115 | lthumb 0 0
116 | rfemur 0 0 0
117 | rtibia 0
118 | rfoot 0 0
119 | rtoes 0
120 | lfemur 0 0 0
121 | ltibia 0
122 | lfoot 0 0
123 | ltoes 0
124 | 5
125 | root 0 0 0 0 0 0
126 | lowerback 0 0 0
127 | upperback 0 0 0
128 | thorax 0 0 0
129 | lowerneck 0 0 0
130 | upperneck 0 0 0
131 | head 0 0 0
132 | rclavicle 0 0
133 | rhumerus 0 0 0
134 | rradius 0
135 | rwrist 0
136 | rhand 0 0
137 | rfingers 0
138 | rthumb 0 0
139 | lclavicle 0 0
140 | lhumerus 0 0 0
141 | lradius 0
142 | lwrist 0
143 | lhand 0 0
144 | lfingers 0
145 | lthumb 0 0
146 | rfemur 0 0 0
147 | rtibia 0
148 | rfoot 0 0
149 | rtoes 0
150 | lfemur 0 0 0
151 | ltibia 0
152 | lfoot 0 0
153 | ltoes 0
154 | 6
155 | root 0 0 0 0 0 0
156 | lowerback 0 0 0
157 | upperback 0 0 0
158 | thorax 0 0 0
159 | lowerneck 0 0 0
160 | upperneck 0 0 0
161 | head 0 0 0
162 | rclavicle 0 0
163 | rhumerus 0 0 0
164 | rradius 0
165 | rwrist 0
166 | rhand 0 0
167 | rfingers 0
168 | rthumb 0 0
169 | lclavicle 0 0
170 | lhumerus 0 0 0
171 | lradius 0
172 | lwrist 0
173 | lhand 0 0
174 | lfingers 0
175 | lthumb 0 0
176 | rfemur 0 0 0
177 | rtibia 0
178 | rfoot 0 0
179 | rtoes 0
180 | lfemur 0 0 0
181 | ltibia 0
182 | lfoot 0 0
183 | ltoes 0
184 | 7
185 | root 0 0 0 0 0 0
186 | lowerback 0 0 0
187 | upperback 0 0 0
188 | thorax 0 0 0
189 | lowerneck 0 0 0
190 | upperneck 0 0 0
191 | head 0 0 0
192 | rclavicle 0 0
193 | rhumerus 0 0 0
194 | rradius 0
195 | rwrist 0
196 | rhand 0 0
197 | rfingers 0
198 | rthumb 0 0
199 | lclavicle 0 0
200 | lhumerus 0 0 0
201 | lradius 0
202 | lwrist 0
203 | lhand 0 0
204 | lfingers 0
205 | lthumb 0 0
206 | rfemur 0 0 0
207 | rtibia 0
208 | rfoot 0 0
209 | rtoes 0
210 | lfemur 0 0 0
211 | ltibia 0
212 | lfoot 0 0
213 | ltoes 0
214 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/explore.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2018 The dm_control Authors.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #    http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Control suite environments explorer."""
16 | 
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | 
21 | from absl import app
22 | from absl import flags
23 | from dm_control import suite
24 | from dm_control.suite.wrappers import action_noise
25 | from six.moves import input
26 | 
27 | from dm_control import viewer
28 | 
29 | 
# One 'domain_name.task_name' string per entry of `suite.ALL_TASKS`, in the
# same order, so an index into this list is also an index into `ALL_TASKS`.
_ALL_NAMES = ['.'.join(domain_task) for domain_task in suite.ALL_TASKS]

# Command-line flags; `main` reads them through `FLAGS`.
flags.DEFINE_enum('environment_name', None, _ALL_NAMES,
                  'Optional \'domain_name.task_name\' pair specifying the '
                  'environment to load. If unspecified a prompt will appear to '
                  'select one.')
flags.DEFINE_bool('timeout', True, 'Whether episodes should have a time limit.')
flags.DEFINE_bool('visualize_reward', True,
                  'Whether to vary the colors of geoms according to the '
                  'current reward value.')
flags.DEFINE_float('action_noise', 0.,
                   'Standard deviation of Gaussian noise to apply to actions, '
                   'expressed as a fraction of the max-min range for each '
                   'action dimension. Defaults to 0, i.e. no noise.')
FLAGS = flags.FLAGS
45 | 
46 | 
def prompt_environment_name(prompt, values):
  """Asks for an environment name until a valid one is entered.

  Args:
    prompt: Text passed to `input` on every attempt.
    values: Container of the accepted environment names.

  Returns:
    The first entered string that is non-empty and contained in `values`.
  """
  while True:
    environment_name = input(prompt)
    # Use a membership test: `values.index(name)` raises `ValueError` for a
    # missing name instead of returning a negative number, so comparing its
    # result with 0 crashed on unknown names rather than re-prompting.
    if environment_name and environment_name in values:
      return environment_name
    print('"%s" is not a valid environment name.' % environment_name)
55 | 
56 | 
def main(argv):
  """Opens the viewer on the environment chosen by flag or by prompt.

  Args:
    argv: Ignored.
  """
  del argv
  environment_name = FLAGS.environment_name
  if environment_name is None:
    listing = ['Available environments:'] + _ALL_NAMES
    print('\n  '.join(listing))
    environment_name = prompt_environment_name(
        'Please select an environment name: ', _ALL_NAMES)

  # `_ALL_NAMES` is built from `suite.ALL_TASKS`, so positions line up.
  task_position = _ALL_NAMES.index(environment_name)
  domain_name, task_name = suite.ALL_TASKS[task_position]

  task_kwargs = {}
  if not FLAGS.timeout:
    task_kwargs['time_limit'] = float('inf')

  def loader():
    """Builds the environment, optionally wrapped with action noise."""
    env = suite.load(
        domain_name=domain_name, task_name=task_name, task_kwargs=task_kwargs)
    env.task.visualize_reward = FLAGS.visualize_reward
    noise_scale = FLAGS.action_noise
    if noise_scale > 0:
      return action_noise.Wrapper(env, scale=noise_scale)
    return env

  viewer.launch(loader)
81 | 
82 | 
if __name__ == '__main__':
  # Script entry point: hand `main` to absl's `app.run`.
  app.run(main)
85 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/finger.xml:
--------------------------------------------------------------------------------
 1 |  
 2 |   
 3 |   
 4 |   
 5 | 
 6 |   
 9 | 
10 |   
11 |     
12 |     
13 |     
14 |     
15 |       
16 |       
17 |     
18 |   
19 | 
20 |   
21 |     
22 |     
23 |     
24 |     
25 | 
26 |     
27 |       
28 |       
29 |       
30 |       
31 |         
32 |         
33 |         
34 |         
35 |         
36 |       
37 |     
38 | 
39 |     
40 |       
41 |       
42 |       
43 |       
44 |       
45 |     
46 | 
47 |     
48 |   
49 | 
50 |   
51 |     
52 |     
53 |   
54 | 
55 |   
56 |   
57 |     
58 |     
59 |     
60 |     
61 |     
62 |     
63 |     
64 |     
65 |     
66 |     
67 |     
68 |     
69 |   
70 | 
71 | 
72 | 
73 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/fish.xml:
--------------------------------------------------------------------------------
 1 | 
 2 |   
 3 |   
 4 |   
 5 |       
 6 |   
 7 | 
 8 | 
 9 |   
12 | 
13 |   
14 |     
15 |     
16 |       
17 |       
18 |     
19 |   
20 | 
21 |   
22 |     
23 |     
24 |     
25 |     
26 |     
27 |     
28 |     
29 |       
30 |       
31 |       
32 |       
33 |       
34 |       
35 |       
36 |       
37 |       
38 |       
39 |       
40 |         
41 |         
42 |         
43 |         
44 |           
45 |           
46 |         
47 |       
48 |       
49 |         
50 |         
51 |         
52 |       
53 |       
54 |         
55 |         
56 |         
57 |       
58 |     
59 |   
60 | 
61 |   
62 |     
63 |       
64 |       
65 |     
66 |     
67 |       
68 |       
69 |     
70 |   
71 | 
72 |   
73 |     
74 |     
75 |     
76 |     
77 |     
78 |   
79 | 
80 |   
81 |     
82 |     
83 |   
84 | 
85 | 
86 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/hopper.xml:
--------------------------------------------------------------------------------
 1 | 
 2 |   
 3 |   
 4 |   
 5 | 
 6 |   
 7 | 
 8 |   
 9 |     
10 |       
11 |       
12 |       
13 |     
14 |     
15 |       
16 |     
17 |     
18 |   
19 | 
20 |   
21 | 
22 |   
23 |     
24 |     
25 |     
26 |     
27 |       
28 |       
29 |       
30 |       
31 |       
32 |       
33 |       
34 |         
35 |         
36 |         
37 |           
38 |           
39 |           
40 |             
41 |             
42 |             
43 |               
44 |               
45 |               
46 |               
47 |             
48 |           
49 |         
50 |       
51 |     
52 |   
53 | 
54 |   
55 |     
56 |     
57 |     
58 |   
59 | 
60 |   
61 |     
62 |     
63 |     
64 |     
65 |   
66 | 
67 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/lqr.xml:
--------------------------------------------------------------------------------
 1 | 
 2 |   
 3 |   
 4 |   
 5 | 
 6 |   
 7 | 
 8 |   
 9 |     
10 |     
11 |     
12 |     
13 |   
14 | 
15 |   
18 | 
19 |   
20 |     
21 |     
22 |     
23 |     
24 |     
25 |   
26 | 
27 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/lqr_solver.py:
--------------------------------------------------------------------------------
  1 | # Copyright 2017 The dm_control Authors.
  2 | #
  3 | # Licensed under the Apache License, Version 2.0 (the "License");
  4 | # you may not use this file except in compliance with the License.
  5 | # You may obtain a copy of the License at
  6 | #
  7 | #    http://www.apache.org/licenses/LICENSE-2.0
  8 | #
  9 | # Unless required by applicable law or agreed to in writing, software
 10 | # distributed under the License is distributed on an "AS IS" BASIS,
 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
 12 | # See the License for the specific language governing permissions and
 13 | # limitations under the License.
 14 | # ============================================================================
 15 | 
 16 | r"""Optimal policy for LQR levels.
 17 | 
 18 | LQR control problem is described in
 19 | https://en.wikipedia.org/wiki/Linear-quadratic_regulator#Infinite-horizon.2C_discrete-time_LQR
 20 | """
 21 | 
 22 | from __future__ import absolute_import
 23 | from __future__ import division
 24 | from __future__ import print_function
 25 | 
 26 | from absl import logging
 27 | from dm_control.mujoco import wrapper
 28 | import numpy as np
 29 | from six.moves import range
 30 | 
 31 | try:
 32 |   import scipy.linalg as sp  # pylint: disable=g-import-not-at-top
 33 | except ImportError:
 34 |   sp = None
 35 | 
 36 | 
 37 | def _solve_dare(a, b, q, r):
 38 |   """Solves the Discrete-time Algebraic Riccati Equation (DARE) by iteration.
 39 | 
 40 |   Algebraic Riccati Equation:
 41 |   ```none
 42 |   P_{t-1} = Q + A' * P_{t} * A -
 43 |             A' * P_{t} * B * (R + B' * P_{t} * B)^{-1} * B' * P_{t} * A
 44 |   ```
 45 | 
 46 |   Args:
 47 |     a: A 2 dimensional numpy array, transition matrix A.
 48 |     b: A 2 dimensional numpy array, control matrix B.
 49 |     q: A 2 dimensional numpy array, symmetric positive definite cost matrix.
 50 |     r: A 2 dimensional numpy array, symmetric positive definite cost matrix
 51 | 
 52 |   Returns:
 53 |     A numpy array, a real symmetric matrix P which is the solution to DARE.
 54 | 
 55 |   Raises:
 56 |     RuntimeError: If the computed P matrix is not symmetric and
 57 |       positive-definite.
 58 |   """
 59 |   p = np.eye(len(a))
 60 |   for _ in range(1000000):
 61 |     a_p = a.T.dot(p)  # A' * P_t
 62 |     a_p_b = np.dot(a_p, b)  # A' * P_t * B
 63 |     # Algebraic Riccati Equation.
 64 |     p_next = q + np.dot(a_p, a) - a_p_b.dot(
 65 |         np.linalg.solve(b.T.dot(p.dot(b)) + r, a_p_b.T))
 66 |     p_next += p_next.T
 67 |     p_next *= .5
 68 |     if np.abs(p - p_next).max() < 1e-12:
 69 |       break
 70 |     p = p_next
 71 |   else:
 72 |     logging.warning('DARE solver did not converge')
 73 |   try:
 74 |     # Check that the result is symmetric and positive-definite.
 75 |     np.linalg.cholesky(p_next)
 76 |   except np.linalg.LinAlgError:
 77 |     raise RuntimeError('ARE solver failed: P matrix is not symmetric and '
 78 |                        'positive-definite.')
 79 |   return p_next
 80 | 
 81 | 
 82 | def solve(env):
 83 |   """Returns the optimal value and policy for LQR problem.
 84 | 
 85 |   Args:
 86 |     env: An instance of `control.EnvironmentV2` with LQR level.
 87 | 
 88 |   Returns:
 89 |     p: A numpy array, the Hessian of the optimal total cost-to-go (value
 90 |       function at state x) is V(x) = .5 * x' * p * x.
 91 |     k: A numpy array which gives the optimal linear policy u = k * x.
 92 |     beta: The maximum eigenvalue of (a + b * k). Under optimal policy, at
 93 |       timestep n the state tends to 0 like beta^n.
 94 | 
 95 |   Raises:
 96 |     RuntimeError: If the controlled system is unstable.
 97 |   """
 98 |   n = env.physics.model.nq  # number of DoFs
 99 |   m = env.physics.model.nu  # number of controls
100 | 
101 |   # Compute the mass matrix.
102 |   mass = np.zeros((n, n))
103 |   wrapper.mjbindings.mjlib.mj_fullM(env.physics.model.ptr, mass,
104 |                                     env.physics.data.qM)
105 | 
106 |   # Compute input matrices a, b, q and r to the DARE solvers.
107 |   # State transition matrix a.
108 |   stiffness = np.diag(env.physics.model.jnt_stiffness.ravel())
109 |   damping = np.diag(env.physics.model.dof_damping.ravel())
110 |   dt = env.physics.model.opt.timestep
111 | 
112 |   j = np.linalg.solve(-mass, np.hstack((stiffness, damping)))
113 |   a = np.eye(2 * n) + dt * np.vstack(
114 |       (dt * j + np.hstack((np.zeros((n, n)), np.eye(n))), j))
115 | 
116 |   # Control transition matrix b.
117 |   b = env.physics.data.actuator_moment.T
118 |   bc = np.linalg.solve(mass, b)
119 |   b = dt * np.vstack((dt * bc, bc))
120 | 
121 |   # State cost Hessian q.
122 |   q = np.diag(np.hstack([np.ones(n), np.zeros(n)]))
123 | 
124 |   # Control cost Hessian r.
125 |   r = env.task.control_cost_coef * np.eye(m)
126 | 
127 |   if sp:
128 |     # Use scipy's faster DARE solver if available.
129 |     solve_dare = sp.solve_discrete_are
130 |   else:
131 |     # Otherwise fall back on a slower internal implementation.
132 |     solve_dare = _solve_dare
133 | 
134 |   # Solve the discrete algebraic Riccati equation.
135 |   p = solve_dare(a, b, q, r)
136 |   k = -np.linalg.solve(b.T.dot(p.dot(b)) + r, b.T.dot(p.dot(a)))
137 | 
138 |   # Under optimal policy, state tends to 0 like beta^n_timesteps
139 |   beta = np.abs(np.linalg.eigvals(a + b.dot(k))).max()
140 |   if beta >= 1.0:
141 |     raise RuntimeError('Controlled system is unstable.')
142 |   return p, k, beta
143 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/pendulum.py:
--------------------------------------------------------------------------------
  1 | # Copyright 2017 The dm_control Authors.
  2 | #
  3 | # Licensed under the Apache License, Version 2.0 (the "License");
  4 | # you may not use this file except in compliance with the License.
  5 | # You may obtain a copy of the License at
  6 | #
  7 | #    http://www.apache.org/licenses/LICENSE-2.0
  8 | #
  9 | # Unless required by applicable law or agreed to in writing, software
 10 | # distributed under the License is distributed on an "AS IS" BASIS,
 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
 12 | # See the License for the specific language governing permissions and
 13 | # limitations under the License.
 14 | # ============================================================================
 15 | 
 16 | """Pendulum domain."""
 17 | 
 18 | from __future__ import absolute_import
 19 | from __future__ import division
 20 | from __future__ import print_function
 21 | 
 22 | import collections
 23 | 
 24 | from dm_control import mujoco
 25 | from dm_control.rl import control
 26 | from local_dm_control_suite import base
 27 | from local_dm_control_suite import common
 28 | from dm_control.utils import containers
 29 | from dm_control.utils import rewards
 30 | import numpy as np
 31 | 
 32 | 
# Default `time_limit` handed to `control.Environment` by `swingup`
# (presumably seconds -- confirm against `control.Environment`).
_DEFAULT_TIME_LIMIT = 20
# Angle in degrees (it is converted with `np.deg2rad` below).
_ANGLE_BOUND = 8
# Cosine of `_ANGLE_BOUND`; used by `SwingUp.get_reward` as the lower bound of
# the rewarded interval for the pole's vertical component.
_COSINE_BOUND = np.cos(np.deg2rad(_ANGLE_BOUND))
# Task registry for this domain; entries are added via `@SUITE.add(...)`.
SUITE = containers.TaggedTasks()
 37 | 
 38 | 
 39 | def get_model_and_assets():
 40 |   """Returns a tuple containing the model XML string and a dict of assets."""
 41 |   return common.read_model('pendulum.xml'), common.ASSETS
 42 | 
 43 | 
 44 | @SUITE.add('benchmarking')
 45 | def swingup(time_limit=_DEFAULT_TIME_LIMIT, random=None,
 46 |             environment_kwargs=None):
 47 |   """Returns pendulum swingup task ."""
 48 |   physics = Physics.from_xml_string(*get_model_and_assets())
 49 |   task = SwingUp(random=random)
 50 |   environment_kwargs = environment_kwargs or {}
 51 |   return control.Environment(
 52 |       physics, task, time_limit=time_limit, **environment_kwargs)
 53 | 
 54 | 
 55 | class Physics(mujoco.Physics):
 56 |   """Physics simulation with additional features for the Pendulum domain."""
 57 | 
 58 |   def pole_vertical(self):
 59 |     """Returns vertical (z) component of pole frame."""
 60 |     return self.named.data.xmat['pole', 'zz']
 61 | 
 62 |   def angular_velocity(self):
 63 |     """Returns the angular velocity of the pole."""
 64 |     return self.named.data.qvel['hinge'].copy()
 65 | 
 66 |   def pole_orientation(self):
 67 |     """Returns both horizontal and vertical components of pole frame."""
 68 |     return self.named.data.xmat['pole', ['zz', 'xz']]
 69 | 
 70 | 
 71 | class SwingUp(base.Task):
 72 |   """A Pendulum `Task` to swing up and balance the pole."""
 73 | 
 74 |   def __init__(self, random=None):
 75 |     """Initialize an instance of `Pendulum`.
 76 | 
 77 |     Args:
 78 |       random: Optional, either a `numpy.random.RandomState` instance, an
 79 |         integer seed for creating a new `RandomState`, or None to select a seed
 80 |         automatically (default).
 81 |     """
 82 |     super(SwingUp, self).__init__(random=random)
 83 | 
 84 |   def initialize_episode(self, physics):
 85 |     """Sets the state of the environment at the start of each episode.
 86 | 
 87 |     Pole is set to a random angle between [-pi, pi).
 88 | 
 89 |     Args:
 90 |       physics: An instance of `Physics`.
 91 | 
 92 |     """
 93 |     physics.named.data.qpos['hinge'] = self.random.uniform(-np.pi, np.pi)
 94 |     super(SwingUp, self).initialize_episode(physics)
 95 | 
 96 |   def get_observation(self, physics):
 97 |     """Returns an observation.
 98 | 
 99 |     Observations are states concatenating pole orientation and angular velocity
100 |     and pixels from fixed camera.
101 | 
102 |     Args:
103 |       physics: An instance of `physics`, Pendulum physics.
104 | 
105 |     Returns:
106 |       A `dict` of observation.
107 |     """
108 |     obs = collections.OrderedDict()
109 |     obs['orientation'] = physics.pole_orientation()
110 |     obs['velocity'] = physics.angular_velocity()
111 |     return obs
112 | 
113 |   def get_reward(self, physics):
114 |     return rewards.tolerance(physics.pole_vertical(), (_COSINE_BOUND, 1))
115 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/pendulum.xml:
--------------------------------------------------------------------------------
 1 | 
 2 |   
 3 |   
 4 |   
 5 | 
 6 |   
 9 | 
10 |   
11 |     
12 |     
13 |     
14 |     
15 |     
16 |       
17 |       
18 |       
19 |       
20 |     
21 |   
22 | 
23 |   
24 |     
25 |   
26 | 
27 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/point_mass.xml:
--------------------------------------------------------------------------------
 1 | 
 2 |   
 3 |   
 4 |   
 5 | 
 6 |   
 9 | 
10 |   
11 |     
12 |     
13 |   
14 | 
15 |   
16 |     
17 |     
18 |     
19 |     
20 |     
21 |     
22 |     
23 | 
24 |     
25 |       
26 |       
27 |       
28 |       
29 |     
30 | 
31 |     
32 |   
33 | 
34 |   
35 |     
36 |       
37 |       
38 |     
39 |     
40 |       
41 |       
42 |     
43 |   
44 | 
45 |   
46 |     
47 |     
48 |   
49 | 
50 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/reacher.py:
--------------------------------------------------------------------------------
  1 | # Copyright 2017 The dm_control Authors.
  2 | #
  3 | # Licensed under the Apache License, Version 2.0 (the "License");
  4 | # you may not use this file except in compliance with the License.
  5 | # You may obtain a copy of the License at
  6 | #
  7 | #    http://www.apache.org/licenses/LICENSE-2.0
  8 | #
  9 | # Unless required by applicable law or agreed to in writing, software
 10 | # distributed under the License is distributed on an "AS IS" BASIS,
 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
 12 | # See the License for the specific language governing permissions and
 13 | # limitations under the License.
 14 | # ============================================================================
 15 | 
 16 | """Reacher domain."""
 17 | 
 18 | from __future__ import absolute_import
 19 | from __future__ import division
 20 | from __future__ import print_function
 21 | 
 22 | import collections
 23 | 
 24 | from dm_control import mujoco
 25 | from dm_control.rl import control
 26 | from local_dm_control_suite import base
 27 | from local_dm_control_suite import common
 28 | from dm_control.suite.utils import randomizers
 29 | from dm_control.utils import containers
 30 | from dm_control.utils import rewards
 31 | import numpy as np
 32 | 
# Task registry for this domain; entries are added via `@SUITE.add(...)`.
SUITE = containers.TaggedTasks()
# Default `time_limit` handed to `control.Environment` by `easy` and `hard`
# (presumably seconds -- confirm against `control.Environment`).
_DEFAULT_TIME_LIMIT = 20
# `target_size` used by the `easy` task.
_BIG_TARGET = .05
# `target_size` used by the `hard` task.
_SMALL_TARGET = .015
 37 | 
 38 | 
 39 | def get_model_and_assets():
 40 |   """Returns a tuple containing the model XML string and a dict of assets."""
 41 |   return common.read_model('reacher.xml'), common.ASSETS
 42 | 
 43 | 
 44 | @SUITE.add('benchmarking', 'easy')
 45 | def easy(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
 46 |   """Returns reacher with sparse reward with 5e-2 tol and randomized target."""
 47 |   physics = Physics.from_xml_string(*get_model_and_assets())
 48 |   task = Reacher(target_size=_BIG_TARGET, random=random)
 49 |   environment_kwargs = environment_kwargs or {}
 50 |   return control.Environment(
 51 |       physics, task, time_limit=time_limit, **environment_kwargs)
 52 | 
 53 | 
 54 | @SUITE.add('benchmarking')
 55 | def hard(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
 56 |   """Returns reacher with sparse reward with 1e-2 tol and randomized target."""
 57 |   physics = Physics.from_xml_string(*get_model_and_assets())
 58 |   task = Reacher(target_size=_SMALL_TARGET, random=random)
 59 |   environment_kwargs = environment_kwargs or {}
 60 |   return control.Environment(
 61 |       physics, task, time_limit=time_limit, **environment_kwargs)
 62 | 
 63 | 
 64 | class Physics(mujoco.Physics):
 65 |   """Physics simulation with additional features for the Reacher domain."""
 66 | 
 67 |   def finger_to_target(self):
 68 |     """Returns the vector from target to finger in global coordinates."""
 69 |     return (self.named.data.geom_xpos['target', :2] -
 70 |             self.named.data.geom_xpos['finger', :2])
 71 | 
 72 |   def finger_to_target_dist(self):
 73 |     """Returns the signed distance between the finger and target surface."""
 74 |     return np.linalg.norm(self.finger_to_target())
 75 | 
 76 | 
 77 | class Reacher(base.Task):
 78 |   """A reacher `Task` to reach the target."""
 79 | 
 80 |   def __init__(self, target_size, random=None):
 81 |     """Initialize an instance of `Reacher`.
 82 | 
 83 |     Args:
 84 |       target_size: A `float`, tolerance to determine whether finger reached the
 85 |           target.
 86 |       random: Optional, either a `numpy.random.RandomState` instance, an
 87 |         integer seed for creating a new `RandomState`, or None to select a seed
 88 |         automatically (default).
 89 |     """
 90 |     self._target_size = target_size
 91 |     super(Reacher, self).__init__(random=random)
 92 | 
 93 |   def initialize_episode(self, physics):
 94 |     """Sets the state of the environment at the start of each episode."""
 95 |     physics.named.model.geom_size['target', 0] = self._target_size
 96 |     randomizers.randomize_limited_and_rotational_joints(physics, self.random)
 97 | 
 98 |     # Randomize target position
 99 |     angle = self.random.uniform(0, 2 * np.pi)
100 |     radius = self.random.uniform(.05, .20)
101 |     physics.named.model.geom_pos['target', 'x'] = radius * np.sin(angle)
102 |     physics.named.model.geom_pos['target', 'y'] = radius * np.cos(angle)
103 | 
104 |     super(Reacher, self).initialize_episode(physics)
105 | 
106 |   def get_observation(self, physics):
107 |     """Returns an observation of the state and the target position."""
108 |     obs = collections.OrderedDict()
109 |     obs['position'] = physics.position()
110 |     obs['to_target'] = physics.finger_to_target()
111 |     obs['velocity'] = physics.velocity()
112 |     return obs
113 | 
114 |   def get_reward(self, physics):
115 |     radii = physics.named.model.geom_size[['target', 'finger'], 0].sum()
116 |     return rewards.tolerance(physics.finger_to_target_dist(), (0, radii))
117 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/reacher.xml:
--------------------------------------------------------------------------------
 1 | 
 2 |   
 3 |   
 4 |   
 5 | 
 6 |   
 9 | 
10 |   
11 |     
12 |     
13 |   
14 | 
15 |   
16 |     
17 |     
18 |     
19 |     
20 |     
21 |     
22 |     
23 |     
24 | 
25 |     
26 |     
27 |     
28 |       
29 |       
30 |       
31 |         
32 |         
33 |         
34 |           
35 |           
36 |         
37 |       
38 |     
39 |     
40 |     
41 |   
42 | 
43 |   
44 |     
45 |     
46 |   
47 | 
48 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/swimmer.xml:
--------------------------------------------------------------------------------
 1 | 
 2 |   
 3 |   
 4 |   
 5 | 
 6 |   
 9 | 
10 |   
11 |     
12 |       
13 |       
14 |         
15 |       
16 |       
17 |         
18 |       
19 |       
20 |     
21 |     
22 |       
23 |     
24 |     
25 |   
26 | 
27 |   
28 |     
29 |     
30 |       
31 |       
32 |       
33 |       
34 |       
35 |       
36 |       
37 |       
38 |       
39 |       
40 |       
41 |       
42 |       
43 |     
44 |     
45 |     
46 |   
47 | 
48 |   
49 |     
50 |     
51 |     
52 |     
53 |     
54 |     
55 |   
56 | 
57 | 
58 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/tests/loader_test.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2017 The dm_control Authors.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #    http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | 
16 | """Tests for the dm_control.suite loader."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | # Internal dependencies.
23 | 
24 | from absl.testing import absltest
25 | 
26 | from dm_control import suite
27 | from dm_control.rl import control
28 | 
29 | 
class LoaderTest(absltest.TestCase):
  """Checks that `suite.load` returns a `control.Environment`."""

  def test_load_without_kwargs(self):
    """Loads cartpole swingup with default arguments."""
    loaded = suite.load('cartpole', 'swingup')
    self.assertIsInstance(loaded, control.Environment)

  def test_load_with_kwargs(self):
    """Loads cartpole swingup with explicit `task_kwargs`."""
    task_kwargs = {'time_limit': 40, 'random': 99}
    loaded = suite.load('cartpole', 'swingup', task_kwargs=task_kwargs)
    self.assertIsInstance(loaded, control.Environment)
40 | 
41 | 
class LoaderConstantsTest(absltest.TestCase):
  """Checks that the suite-level task collections are populated."""

  def testSuiteConstants(self):
    """Asserts each suite constant is non-empty, in the listed order."""
    for constant_name in ('BENCHMARKING', 'EASY', 'HARD', 'EXTRA'):
      self.assertNotEmpty(getattr(suite, constant_name))
49 | 
50 | 
51 | if __name__ == '__main__':
52 |   absltest.main()
53 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/tests/lqr_test.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2017 The dm_control Authors.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #    http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | 
16 | """Tests specific to the LQR domain."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import  math
23 | import unittest
24 | 
25 | # Internal dependencies.
26 | from absl import logging
27 | 
28 | from absl.testing import absltest
29 | from absl.testing import parameterized
30 | 
31 | from local_dm_control_suite import lqr
32 | from local_dm_control_suite import lqr_solver
33 | 
34 | import numpy as np
35 | from six.moves import range
36 | 
37 | 
class LqrTest(parameterized.TestCase):
  """Compares the solver's predicted LQR cost with a simulated rollout."""

  @parameterized.named_parameters(
      ('lqr_2_1', lqr.lqr_2_1),
      ('lqr_6_2', lqr.lqr_6_2))
  def test_lqr_optimal_policy(self, make_env):
    """Solves the environment with the default solver and checks the cost."""
    env = make_env()
    solution = lqr_solver.solve(env)
    self.assertPolicyisOptimal(env, *solution)

  @parameterized.named_parameters(
      ('lqr_2_1', lqr.lqr_2_1),
      ('lqr_6_2', lqr.lqr_6_2))
  @unittest.skipUnless(
      condition=lqr_solver.sp,
      reason='scipy is not available, so non-scipy DARE solver is the default.')
  def test_lqr_optimal_policy_no_scipy(self, make_env):
    """Same check, with `lqr_solver.sp` temporarily set to None."""
    env = make_env()
    saved_sp = lqr_solver.sp
    try:
      # Force the solver to use the non-scipy code path.
      lqr_solver.sp = None
      solution = lqr_solver.solve(env)
    finally:
      # Always restore the module attribute, even if solving fails.
      lqr_solver.sp = saved_sp
    self.assertPolicyisOptimal(env, *solution)

  def assertPolicyisOptimal(self, env, p, k, beta):
    """Rolls out u = k * x and compares the cost with .5 * x' * p * x.

    Args:
      env: The LQR environment to step.
      p: Cost-to-go Hessian returned by `lqr_solver.solve`.
      k: Linear policy gain returned by `lqr_solver.solve`.
      beta: Convergence rate returned by `lqr_solver.solve`; sets the number
        of rollout steps needed to reach the tolerance.
    """

    def full_state(step):
      # State vector is position followed by velocity.
      observation = step.observation
      return np.hstack((observation['position'], observation['velocity']))

    tolerance = 1e-3
    horizon = int(math.ceil(math.log10(tolerance) / math.log10(beta)))
    logging.info('%d timesteps for %g convergence.', horizon, tolerance)

    timestep = env.reset()
    initial_state = full_state(timestep)
    logging.info('Measuring total cost over %d steps.', horizon)
    measured_cost = 0.0
    for _ in range(horizon):
      # u = k*x is the optimal policy
      control_input = k.dot(full_state(timestep))
      # A missing (None) reward is counted as 0.0.
      measured_cost += 1 - (timestep.reward or 0.0)
      timestep = env.step(control_input)

    logging.info('Analytical expected total cost is .5*x^T*p*x.')
    predicted_cost = .5 * initial_state.T.dot(p).dot(initial_state)
    logging.info('Comparing measured and predicted costs.')
    np.testing.assert_allclose(predicted_cost, measured_cost, rtol=tolerance)
86 | 
87 | if __name__ == '__main__':
88 |   absltest.main()
89 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/utils/__init__.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2017 The dm_control Authors.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #    http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | 
16 | """Utility functions used in the control suite."""
17 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/utils/parse_amc_test.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2017 The dm_control Authors.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #    http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | 
16 | """Tests for parse_amc utility."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import os
23 | 
24 | # Internal dependencies.
25 | 
26 | from absl.testing import absltest
27 | from local_dm_control_suite import humanoid_CMU
28 | from dm_control.suite.utils import parse_amc
29 | 
30 | from dm_control.utils import io as resources
31 | 
# Resource path of the `zeros.amc` clip in the sibling `demos` directory,
# resolved relative to this test file.
_TEST_AMC_PATH = resources.GetResourceFilename(
    os.path.join(os.path.dirname(__file__), '../demos/zeros.amc'))
34 | 
35 | 
class ParseAMCTest(absltest.TestCase):
  """Checks the array sizes produced by `parse_amc.convert`."""

  def test_sizes_of_parsed_data(self):
    """Converts the test clip at two timesteps and checks the shapes."""

    # Instantiate the humanoid environment.
    env = humanoid_CMU.stand()

    # Parse and convert specified clip.
    converted = parse_amc.convert(
        _TEST_AMC_PATH, env.physics, env.control_timestep())

    self.assertEqual(converted.qpos.shape[0], 63)
    self.assertEqual(converted.qvel.shape[0], 62)
    self.assertEqual(converted.time.shape[0], converted.qpos.shape[1])
    self.assertEqual(converted.qpos.shape[1],
                     converted.qvel.shape[1] + 1)

    # Parse and convert specified clip -- WITH SMALLER TIMESTEP
    converted2 = parse_amc.convert(
        _TEST_AMC_PATH, env.physics, 0.5 * env.control_timestep())

    self.assertEqual(converted2.qpos.shape[0], 63)
    self.assertEqual(converted2.qvel.shape[0], 62)
    self.assertEqual(converted2.time.shape[0], converted2.qpos.shape[1])
    # This check must target `converted2`; asserting on `converted` again
    # would leave the half-timestep result unverified.
    self.assertEqual(converted2.qpos.shape[1],
                     converted2.qvel.shape[1] + 1)

    # Compare sizes of parsed objects for different timesteps
    self.assertEqual(converted.qpos.shape[1] * 2, converted2.qpos.shape[1])
65 | 
66 | 
67 | if __name__ == '__main__':
68 |   absltest.main()
69 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/utils/randomizers.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2017 The dm_control Authors.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #    http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | 
16 | """Randomization functions."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | from dm_control.mujoco.wrapper import mjbindings
23 | import numpy as np
24 | from six.moves import range
25 | 
26 | 
def random_limited_quaternion(random, limit):
  """Generates a random quaternion limited to the specified rotations.

  Args:
    random: Random source providing `randn` and `rand` (for example a
      `np.random.RandomState`).
    limit: Upper bound on the rotation angle; the angle is `rand() * limit`.

  Returns:
    A length-4 numpy array filled in by `mju_axisAngle2Quat`.
  """
  # Draw order (axis first, then angle) is kept so seeded runs match.
  direction = random.randn(3)
  unit_axis = direction / np.linalg.norm(direction)
  rotation_angle = random.rand() * limit

  result = np.zeros(4)
  mjbindings.mjlib.mju_axisAngle2Quat(result, unit_axis, rotation_angle)
  return result
37 | 
38 | 
def randomize_limited_and_rotational_joints(physics, random=None):
  """Randomizes the positions of joints defined in the physics body.

  The following randomization rules apply:
    - Bounded joints (hinges or sliders) are sampled uniformly in the bounds.
    - Unbounded hinges are sampled uniformly in [-pi, pi]
    - Quaternions for unlimited free joints and ball joints are sampled
      uniformly on the unit 3-sphere.
    - Quaternions for limited ball joints are sampled uniformly on a sector
      of the unit 3-sphere.
    - The linear degrees of freedom of free joints are not randomized.

  Joints matching none of the rules above (for example unlimited sliders) are
  left untouched.

  Args:
    physics: Instance of 'Physics' class that holds a loaded model.
    random: Optional instance of 'np.random.RandomState'. Defaults to the global
      NumPy random state.
  """
  random = random or np.random

  hinge = mjbindings.enums.mjtJoint.mjJNT_HINGE
  slide = mjbindings.enums.mjtJoint.mjJNT_SLIDE
  ball = mjbindings.enums.mjtJoint.mjJNT_BALL
  free = mjbindings.enums.mjtJoint.mjJNT_FREE

  qpos = physics.named.data.qpos

  for joint_id in range(physics.model.njnt):
    joint_name = physics.model.id2name(joint_id, 'joint')
    joint_type = physics.model.jnt_type[joint_id]
    is_limited = physics.model.jnt_limited[joint_id]
    range_min, range_max = physics.model.jnt_range[joint_id]

    if is_limited:
      if joint_type == hinge or joint_type == slide:
        qpos[joint_name] = random.uniform(range_min, range_max)

      elif joint_type == ball:
        # For ball joints the upper range value bounds the rotation angle.
        qpos[joint_name] = random_limited_quaternion(random, range_max)

    else:
      if joint_type == hinge:
        qpos[joint_name] = random.uniform(-np.pi, np.pi)

      elif joint_type == ball:
        # A normalized 4-D Gaussian sample is uniform on the unit 3-sphere.
        quat = random.randn(4)
        quat /= np.linalg.norm(quat)
        qpos[joint_name] = quat

      elif joint_type == free:
        # Must be `randn`, not `rand`: uniform [0, 1) components would only
        # ever produce quaternions in the all-non-negative orthant rather
        # than covering the whole unit 3-sphere.
        quat = random.randn(4)
        quat /= np.linalg.norm(quat)
        # Entries 3: hold the orientation; the first three (linear) entries
        # are deliberately left unchanged.
        qpos[joint_name][3:] = quat
91 | 
92 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/walker.xml:
--------------------------------------------------------------------------------
 1 | 
 2 |   
 3 |   
 4 |   
 5 | 
 6 |   
 7 | 
 8 |   
 9 | 
10 |   
11 |     
12 |     
13 |     
14 |     
15 |     
16 |       
17 |       
18 |     
19 |   
20 | 
21 |   
22 |     
23 |     
24 |       
25 |       
26 |       
27 |       
28 |       
29 |       
30 |       
31 |       
32 |         
33 |         
34 |         
35 |           
36 |           
37 |           
38 |             
39 |             
40 |           
41 |         
42 |       
43 |       
44 |         
45 |         
46 |         
47 |           
48 |           
49 |           
50 |             
51 |             
52 |           
53 |         
54 |       
55 |     
56 |   
57 | 
58 |   
59 |     
60 |   
61 | 
62 |   
63 |     
64 |     
65 |     
66 |     
67 |     
68 |     
69 |   
70 | 
71 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/wrappers/__init__.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2018 The dm_control Authors.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #    http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | 
16 | """Environment wrappers used to extend or modify environment behaviour."""
17 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/wrappers/action_noise.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2018 The dm_control Authors.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #    http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | 
16 | """Wrapper control suite environments that adds Gaussian noise to actions."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import dm_env
23 | import numpy as np
24 | 
25 | 
# Error message template (formatted with `action_spec`) for the `ValueError`
# raised by `Wrapper.__init__` when any action bound is not finite.
_BOUNDS_MUST_BE_FINITE = (
    'All bounds in `env.action_spec()` must be finite, got: {action_spec}')
28 | 
29 | 
class Wrapper(dm_env.Environment):
  """Wraps a control environment and adds Gaussian noise to actions."""

  def __init__(self, env, scale=0.01):
    """Initializes a new action noise Wrapper.

    Args:
      env: The control suite environment to wrap.
      scale: The standard deviation of the noise, expressed as a fraction
        of the max-min range for each action dimension.

    Raises:
      ValueError: If any of the action dimensions of the wrapped environment are
        unbounded.
    """
    action_spec = env.action_spec()
    if not (np.all(np.isfinite(action_spec.minimum)) and
            np.all(np.isfinite(action_spec.maximum))):
      raise ValueError(_BOUNDS_MUST_BE_FINITE.format(action_spec=action_spec))
    self._minimum = action_spec.minimum
    self._maximum = action_spec.maximum
    self._noise_std = scale * (action_spec.maximum - action_spec.minimum)
    self._env = env

  def step(self, action):
    """Steps the wrapped environment with a noisy, clipped copy of `action`.

    Noise is drawn from the wrapped task's random state
    (`self._env.task.random`), so it follows the task's seeding.

    Args:
      action: The action to perturb.

    Returns:
      Whatever the wrapped environment's `step` returns.
    """
    noisy_action = action + self._env.task.random.normal(scale=self._noise_std)
    # Clip the noisy actions in place so that they fall within the bounds
    # specified by the `action_spec`. Note that MuJoCo implicitly clips out-of-
    # bounds control inputs, but we also clip here in case the actions do not
    # correspond directly to MuJoCo actuators, or if there are other wrapper
    # layers that expect the actions to be within bounds.
    np.clip(noisy_action, self._minimum, self._maximum, out=noisy_action)
    return self._env.step(noisy_action)

  def reset(self):
    """Resets the wrapped environment and returns its result."""
    return self._env.reset()

  def observation_spec(self):
    """Returns the wrapped environment's observation spec."""
    return self._env.observation_spec()

  def action_spec(self):
    """Returns the wrapped environment's action spec."""
    return self._env.action_spec()

  def __getattr__(self, name):
    """Delegates lookups of unknown attributes to the wrapped environment.

    Raises:
      AttributeError: If `_env` itself is requested, which only happens when
        the instance has no `_env` yet (for example while `copy` or `pickle`
        is rebuilding it). Without this guard the lookup of `self._env` below
        would re-enter `__getattr__` until a `RecursionError`.
    """
    if name == '_env':
      raise AttributeError(name)
    return getattr(self._env, name)
75 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/wrappers/pixels.py:
--------------------------------------------------------------------------------
  1 | # Copyright 2017 The dm_control Authors.
  2 | #
  3 | # Licensed under the Apache License, Version 2.0 (the "License");
  4 | # you may not use this file except in compliance with the License.
  5 | # You may obtain a copy of the License at
  6 | #
  7 | #    http://www.apache.org/licenses/LICENSE-2.0
  8 | #
  9 | # Unless required by applicable law or agreed to in writing, software
 10 | # distributed under the License is distributed on an "AS IS" BASIS,
 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
 12 | # See the License for the specific language governing permissions and
 13 | # limitations under the License.
 14 | # ============================================================================
 15 | 
 16 | """Wrapper that adds pixel observations to a control environment."""
 17 | 
 18 | from __future__ import absolute_import
 19 | from __future__ import division
 20 | from __future__ import print_function
 21 | 
 22 | import collections
 23 | 
 24 | import dm_env
 25 | from dm_env import specs
 26 | 
# Key under which a non-dict (single array) observation from the wrapped
# environment is stored when it is returned alongside the pixel observation.
STATE_KEY = 'state'
 28 | 
 29 | 
 30 | class Wrapper(dm_env.Environment):
 31 |   """Wraps a control environment and adds a rendered pixel observation."""
 32 | 
 33 |   def __init__(self, env, pixels_only=True, render_kwargs=None,
 34 |                observation_key='pixels'):
 35 |     """Initializes a new pixel Wrapper.
 36 | 
 37 |     Args:
 38 |       env: The environment to wrap.
 39 |       pixels_only: If True (default), the original set of 'state' observations
 40 |         returned by the wrapped environment will be discarded, and the
 41 |         `OrderedDict` of observations will only contain pixels. If False, the
 42 |         `OrderedDict` will contain the original observations as well as the
 43 |         pixel observations.
 44 |       render_kwargs: Optional `dict` containing keyword arguments passed to the
 45 |         `mujoco.Physics.render` method.
 46 |       observation_key: Optional custom string specifying the pixel observation's
 47 |         key in the `OrderedDict` of observations. Defaults to 'pixels'.
 48 | 
 49 |     Raises:
 50 |       ValueError: If `env`'s observation spec is not compatible with the
 51 |         wrapper. Supported formats are a single array, or a dict of arrays.
 52 |       ValueError: If `env`'s observation already contains the specified
 53 |         `observation_key`.
 54 |     """
 55 |     if render_kwargs is None:
 56 |       render_kwargs = {}
 57 | 
 58 |     wrapped_observation_spec = env.observation_spec()
 59 | 
 60 |     if isinstance(wrapped_observation_spec, specs.Array):
 61 |       self._observation_is_dict = False
 62 |       invalid_keys = set([STATE_KEY])
 63 |     elif isinstance(wrapped_observation_spec, collections.MutableMapping):
 64 |       self._observation_is_dict = True
 65 |       invalid_keys = set(wrapped_observation_spec.keys())
 66 |     else:
 67 |       raise ValueError('Unsupported observation spec structure.')
 68 | 
 69 |     if not pixels_only and observation_key in invalid_keys:
 70 |       raise ValueError('Duplicate or reserved observation key {!r}.'
 71 |                        .format(observation_key))
 72 | 
 73 |     if pixels_only:
 74 |       self._observation_spec = collections.OrderedDict()
 75 |     elif self._observation_is_dict:
 76 |       self._observation_spec = wrapped_observation_spec.copy()
 77 |     else:
 78 |       self._observation_spec = collections.OrderedDict()
 79 |       self._observation_spec[STATE_KEY] = wrapped_observation_spec
 80 | 
 81 |     # Extend observation spec.
 82 |     pixels = env.physics.render(**render_kwargs)
 83 |     pixels_spec = specs.Array(
 84 |         shape=pixels.shape, dtype=pixels.dtype, name=observation_key)
 85 |     self._observation_spec[observation_key] = pixels_spec
 86 | 
 87 |     self._env = env
 88 |     self._pixels_only = pixels_only
 89 |     self._render_kwargs = render_kwargs
 90 |     self._observation_key = observation_key
 91 | 
 92 |   def reset(self):
 93 |     time_step = self._env.reset()
 94 |     return self._add_pixel_observation(time_step)
 95 | 
 96 |   def step(self, action):
 97 |     time_step = self._env.step(action)
 98 |     return self._add_pixel_observation(time_step)
 99 | 
100 |   def observation_spec(self):
101 |     return self._observation_spec
102 | 
103 |   def action_spec(self):
104 |     return self._env.action_spec()
105 | 
106 |   def _add_pixel_observation(self, time_step):
107 |     if self._pixels_only:
108 |       observation = collections.OrderedDict()
109 |     elif self._observation_is_dict:
110 |       observation = type(time_step.observation)(time_step.observation)
111 |     else:
112 |       observation = collections.OrderedDict()
113 |       observation[STATE_KEY] = time_step.observation
114 | 
115 |     pixels = self._env.physics.render(**self._render_kwargs)
116 |     observation[self._observation_key] = pixels
117 |     return time_step._replace(observation=observation)
118 | 
119 |   def __getattr__(self, name):
120 |     return getattr(self._env, name)
121 | 
--------------------------------------------------------------------------------
/local_dm_control_suite/wrappers/pixels_test.py:
--------------------------------------------------------------------------------
  1 | # Copyright 2017 The dm_control Authors.
  2 | #
  3 | # Licensed under the Apache License, Version 2.0 (the "License");
  4 | # you may not use this file except in compliance with the License.
  5 | # You may obtain a copy of the License at
  6 | #
  7 | #    http://www.apache.org/licenses/LICENSE-2.0
  8 | #
  9 | # Unless required by applicable law or agreed to in writing, software
 10 | # distributed under the License is distributed on an "AS IS" BASIS,
 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
 12 | # See the License for the specific language governing permissions and
 13 | # limitations under the License.
 14 | # ============================================================================
 15 | 
 16 | """Tests for the pixel wrapper."""
 17 | 
 18 | from __future__ import absolute_import
 19 | from __future__ import division
 20 | from __future__ import print_function
 21 | 
 22 | import collections
 23 | 
 24 | # Internal dependencies.
 25 | from absl.testing import absltest
 26 | from absl.testing import parameterized
 27 | from local_dm_control_suite import cartpole
 28 | from dm_control.suite.wrappers import pixels
 29 | import dm_env
 30 | from dm_env import specs
 31 | 
 32 | import numpy as np
 33 | 
 34 | 
 35 | class FakePhysics(object):
 36 | 
 37 |   def render(self, *args, **kwargs):
 38 |     del args
 39 |     del kwargs
 40 |     return np.zeros((4, 5, 3), dtype=np.uint8)
 41 | 
 42 | 
 43 | class FakeArrayObservationEnvironment(dm_env.Environment):
 44 | 
 45 |   def __init__(self):
 46 |     self.physics = FakePhysics()
 47 | 
 48 |   def reset(self):
 49 |     return dm_env.restart(np.zeros((2,)))
 50 | 
 51 |   def step(self, action):
 52 |     del action
 53 |     return dm_env.transition(0.0, np.zeros((2,)))
 54 | 
 55 |   def action_spec(self):
 56 |     pass
 57 | 
 58 |   def observation_spec(self):
 59 |     return specs.Array(shape=(2,), dtype=np.float)
 60 | 
 61 | 
 62 | class PixelsTest(parameterized.TestCase):
 63 | 
 64 |   @parameterized.parameters(True, False)
 65 |   def test_dict_observation(self, pixels_only):
 66 |     pixel_key = 'rgb'
 67 | 
 68 |     env = cartpole.swingup()
 69 | 
 70 |     # Make sure we are testing the right environment for the test.
 71 |     observation_spec = env.observation_spec()
 72 |     self.assertIsInstance(observation_spec, collections.OrderedDict)
 73 | 
 74 |     width = 320
 75 |     height = 240
 76 | 
 77 |     # The wrapper should only add one observation.
 78 |     wrapped = pixels.Wrapper(env,
 79 |                              observation_key=pixel_key,
 80 |                              pixels_only=pixels_only,
 81 |                              render_kwargs={'width': width, 'height': height})
 82 | 
 83 |     wrapped_observation_spec = wrapped.observation_spec()
 84 |     self.assertIsInstance(wrapped_observation_spec, collections.OrderedDict)
 85 | 
 86 |     if pixels_only:
 87 |       self.assertLen(wrapped_observation_spec, 1)
 88 |       self.assertEqual([pixel_key], list(wrapped_observation_spec.keys()))
 89 |     else:
 90 |       expected_length = len(observation_spec) + 1
 91 |       self.assertLen(wrapped_observation_spec, expected_length)
 92 |       expected_keys = list(observation_spec.keys()) + [pixel_key]
 93 |       self.assertEqual(expected_keys, list(wrapped_observation_spec.keys()))
 94 | 
 95 |     # Check that the added spec item is consistent with the added observation.
 96 |     time_step = wrapped.reset()
 97 |     rgb_observation = time_step.observation[pixel_key]
 98 |     wrapped_observation_spec[pixel_key].validate(rgb_observation)
 99 | 
100 |     self.assertEqual(rgb_observation.shape, (height, width, 3))
101 |     self.assertEqual(rgb_observation.dtype, np.uint8)
102 | 
103 |   @parameterized.parameters(True, False)
104 |   def test_single_array_observation(self, pixels_only):
105 |     pixel_key = 'depth'
106 | 
107 |     env = FakeArrayObservationEnvironment()
108 |     observation_spec = env.observation_spec()
109 |     self.assertIsInstance(observation_spec, specs.Array)
110 | 
111 |     wrapped = pixels.Wrapper(env, observation_key=pixel_key,
112 |                              pixels_only=pixels_only)
113 |     wrapped_observation_spec = wrapped.observation_spec()
114 |     self.assertIsInstance(wrapped_observation_spec, collections.OrderedDict)
115 | 
116 |     if pixels_only:
117 |       self.assertLen(wrapped_observation_spec, 1)
118 |       self.assertEqual([pixel_key], list(wrapped_observation_spec.keys()))
119 |     else:
120 |       self.assertLen(wrapped_observation_spec, 2)
121 |       self.assertEqual([pixels.STATE_KEY, pixel_key],
122 |                        list(wrapped_observation_spec.keys()))
123 | 
124 |     time_step = wrapped.reset()
125 | 
126 |     depth_observation = time_step.observation[pixel_key]
127 |     wrapped_observation_spec[pixel_key].validate(depth_observation)
128 | 
129 |     self.assertEqual(depth_observation.shape, (4, 5, 3))
130 |     self.assertEqual(depth_observation.dtype, np.uint8)
131 | 
# Run the tests through absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()
134 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/README.md:
--------------------------------------------------------------------------------
 1 | # DeepMind Control Suite.
 2 | 
 3 | This submodule contains the domains and tasks described in the
 4 | [DeepMind Control Suite tech report](https://arxiv.org/abs/1801.00690).
 5 | 
 6 | ## Quickstart
 7 | 
 8 | ```python
 9 | from dm_control import suite
10 | import numpy as np
11 | 
12 | # Load one task:
13 | env = suite.load(domain_name="cartpole", task_name="swingup")
14 | 
15 | # Iterate over a task set:
16 | for domain_name, task_name in suite.BENCHMARKING:
17 |   env = suite.load(domain_name, task_name)
18 | 
19 | # Step through an episode and print out reward, discount and observation.
20 | action_spec = env.action_spec()
21 | time_step = env.reset()
22 | while not time_step.last():
23 |   action = np.random.uniform(action_spec.minimum,
24 |                              action_spec.maximum,
25 |                              size=action_spec.shape)
26 |   time_step = env.step(action)
27 |   print(time_step.reward, time_step.discount, time_step.observation)
28 | ```
29 | 
30 | ## Illustration video
31 | 
32 | Below is a video montage of solved Control Suite tasks, with reward
33 | visualisation enabled.
34 | 
 35 | [Video montage of solved Control Suite tasks](https://www.youtube.com/watch?v=rAai4QzcYbs)
36 | 
37 | 
38 | ### Quadruped domain [April 2019]
39 | 
40 | Roughly based on the 'ant' model introduced by [Schulman et al. 2015](https://arxiv.org/abs/1506.02438). Main modifications to the body are:
41 | 
42 | - 4 DoFs per leg, 1 constraining tendon.
43 | - 3 actuators per leg: 'yaw', 'lift', 'extend'.
44 | - Filtered position actuators with timescale of 100ms.
45 | - Sensors include an IMU, force/torque sensors, and rangefinders.
46 | 
47 | Four tasks:
48 | 
49 | - `walk` and `run`: self-right the body then move forward at a desired speed.
50 | - `escape`: escape a bowl-shaped random terrain (uses rangefinders).
 51 | - `fetch`: go to a moving ball and bring it to a target.
52 | 
53 | All behaviors in the video below were trained with [Abdolmaleki et al's
54 | MPO](https://arxiv.org/abs/1806.06920).
55 | 
 56 | [Video of the trained Quadruped behaviors](https://www.youtube.com/watch?v=RhRLjbb7pBE)
57 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/acrobot.py:
--------------------------------------------------------------------------------
  1 | # Copyright 2017 The dm_control Authors.
  2 | #
  3 | # Licensed under the Apache License, Version 2.0 (the "License");
  4 | # you may not use this file except in compliance with the License.
  5 | # You may obtain a copy of the License at
  6 | #
  7 | #    http://www.apache.org/licenses/LICENSE-2.0
  8 | #
  9 | # Unless required by applicable law or agreed to in writing, software
 10 | # distributed under the License is distributed on an "AS IS" BASIS,
 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
 12 | # See the License for the specific language governing permissions and
 13 | # limitations under the License.
 14 | # ============================================================================
 15 | 
 16 | """Acrobot domain."""
 17 | 
 18 | from __future__ import absolute_import
 19 | from __future__ import division
 20 | from __future__ import print_function
 21 | 
 22 | import collections
 23 | 
 24 | from dm_control import mujoco
 25 | from dm_control.rl import control
 26 | from local_dm_control_suite import base
 27 | from local_dm_control_suite import common
 28 | from dm_control.utils import containers
 29 | from dm_control.utils import rewards
 30 | import numpy as np
 31 | 
# Default `time_limit` handed to `control.Environment` by the task constructors
# below (presumably in seconds, as in the sibling domains -- TODO confirm).
_DEFAULT_TIME_LIMIT = 10
# Container in which the task constructors below register themselves, with
# tags, through the `@SUITE.add(...)` decorator.
SUITE = containers.TaggedTasks()
 34 | 
 35 | 
 36 | def get_model_and_assets():
 37 |   """Returns a tuple containing the model XML string and a dict of assets."""
 38 |   return common.read_model('acrobot.xml'), common.ASSETS
 39 | 
 40 | 
 41 | @SUITE.add('benchmarking')
 42 | def swingup(time_limit=_DEFAULT_TIME_LIMIT, random=None,
 43 |             environment_kwargs=None):
 44 |   """Returns Acrobot balance task."""
 45 |   physics = Physics.from_xml_string(*get_model_and_assets())
 46 |   task = Balance(sparse=False, random=random)
 47 |   environment_kwargs = environment_kwargs or {}
 48 |   return control.Environment(
 49 |       physics, task, time_limit=time_limit, **environment_kwargs)
 50 | 
 51 | 
 52 | @SUITE.add('benchmarking')
 53 | def swingup_sparse(time_limit=_DEFAULT_TIME_LIMIT, random=None,
 54 |                    environment_kwargs=None):
 55 |   """Returns Acrobot sparse balance."""
 56 |   physics = Physics.from_xml_string(*get_model_and_assets())
 57 |   task = Balance(sparse=True, random=random)
 58 |   environment_kwargs = environment_kwargs or {}
 59 |   return control.Environment(
 60 |       physics, task, time_limit=time_limit, **environment_kwargs)
 61 | 
 62 | 
 63 | class Physics(mujoco.Physics):
 64 |   """Physics simulation with additional features for the Acrobot domain."""
 65 | 
 66 |   def horizontal(self):
 67 |     """Returns horizontal (x) component of body frame z-axes."""
 68 |     return self.named.data.xmat[['upper_arm', 'lower_arm'], 'xz']
 69 | 
 70 |   def vertical(self):
 71 |     """Returns vertical (z) component of body frame z-axes."""
 72 |     return self.named.data.xmat[['upper_arm', 'lower_arm'], 'zz']
 73 | 
 74 |   def to_target(self):
 75 |     """Returns the distance from the tip to the target."""
 76 |     tip_to_target = (self.named.data.site_xpos['target'] -
 77 |                      self.named.data.site_xpos['tip'])
 78 |     return np.linalg.norm(tip_to_target)
 79 | 
 80 |   def orientations(self):
 81 |     """Returns the sines and cosines of the pole angles."""
 82 |     return np.concatenate((self.horizontal(), self.vertical()))
 83 | 
 84 | 
 85 | class Balance(base.Task):
 86 |   """An Acrobot `Task` to swing up and balance the pole."""
 87 | 
 88 |   def __init__(self, sparse, random=None):
 89 |     """Initializes an instance of `Balance`.
 90 | 
 91 |     Args:
 92 |       sparse: A `bool` specifying whether to use a sparse (indicator) reward.
 93 |       random: Optional, either a `numpy.random.RandomState` instance, an
 94 |         integer seed for creating a new `RandomState`, or None to select a seed
 95 |         automatically (default).
 96 |     """
 97 |     self._sparse = sparse
 98 |     super(Balance, self).__init__(random=random)
 99 | 
100 |   def initialize_episode(self, physics):
101 |     """Sets the state of the environment at the start of each episode.
102 | 
103 |     Shoulder and elbow are set to a random position between [-pi, pi).
104 | 
105 |     Args:
106 |       physics: An instance of `Physics`.
107 |     """
108 |     physics.named.data.qpos[
109 |         ['shoulder', 'elbow']] = self.random.uniform(-np.pi, np.pi, 2)
110 |     super(Balance, self).initialize_episode(physics)
111 | 
112 |   def get_observation(self, physics):
113 |     """Returns an observation of pole orientation and angular velocities."""
114 |     obs = collections.OrderedDict()
115 |     obs['orientations'] = physics.orientations()
116 |     obs['velocity'] = physics.velocity()
117 |     return obs
118 | 
119 |   def _get_reward(self, physics, sparse):
120 |     target_radius = physics.named.model.site_size['target', 0]
121 |     return rewards.tolerance(physics.to_target(),
122 |                              bounds=(0, target_radius),
123 |                              margin=0 if sparse else 1)
124 | 
125 |   def get_reward(self, physics):
126 |     """Returns a sparse or a smooth reward, as specified in the constructor."""
127 |     return self._get_reward(physics, sparse=self._sparse)
128 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/acrobot.xml:
--------------------------------------------------------------------------------
 1 | 
 8 | 
 9 |   
10 |   
11 |   
12 | 
13 |   
14 |     
15 |     
16 |   
17 | 
18 |   
21 | 
22 |   
23 |     
24 |     
25 |     
26 |     
27 |     
28 |     
29 |       
30 |       
31 |       
32 |       
33 |         
34 |         
35 |         
36 |       
37 |     
38 |   
39 | 
40 |    
41 |     
42 |   
43 | 
44 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/ball_in_cup.py:
--------------------------------------------------------------------------------
  1 | # Copyright 2017 The dm_control Authors.
  2 | #
  3 | # Licensed under the Apache License, Version 2.0 (the "License");
  4 | # you may not use this file except in compliance with the License.
  5 | # You may obtain a copy of the License at
  6 | #
  7 | #    http://www.apache.org/licenses/LICENSE-2.0
  8 | #
  9 | # Unless required by applicable law or agreed to in writing, software
 10 | # distributed under the License is distributed on an "AS IS" BASIS,
 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
 12 | # See the License for the specific language governing permissions and
 13 | # limitations under the License.
 14 | # ============================================================================
 15 | 
 16 | """Ball-in-Cup Domain."""
 17 | 
 18 | from __future__ import absolute_import
 19 | from __future__ import division
 20 | from __future__ import print_function
 21 | 
 22 | import collections
 23 | 
 24 | from dm_control import mujoco
 25 | from dm_control.rl import control
 26 | from local_dm_control_suite import base
 27 | from local_dm_control_suite import common
 28 | from dm_control.utils import containers
 29 | 
# Default `time_limit` handed to `control.Environment` by `catch`.
_DEFAULT_TIME_LIMIT = 20  # (seconds)
# Passed to `control.Environment` as `control_timestep` by `catch`.
_CONTROL_TIMESTEP = .02   # (seconds)


# Container in which the task constructors below register themselves, with
# tags, through the `@SUITE.add(...)` decorator.
SUITE = containers.TaggedTasks()
 35 | 
 36 | 
 37 | def get_model_and_assets():
 38 |   """Returns a tuple containing the model XML string and a dict of assets."""
 39 |   return common.read_model('ball_in_cup.xml'), common.ASSETS
 40 | 
 41 | 
 42 | @SUITE.add('benchmarking', 'easy')
 43 | def catch(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
 44 |   """Returns the Ball-in-Cup task."""
 45 |   physics = Physics.from_xml_string(*get_model_and_assets())
 46 |   task = BallInCup(random=random)
 47 |   environment_kwargs = environment_kwargs or {}
 48 |   return control.Environment(
 49 |       physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
 50 |       **environment_kwargs)
 51 | 
 52 | 
 53 | class Physics(mujoco.Physics):
 54 |   """Physics with additional features for the Ball-in-Cup domain."""
 55 | 
 56 |   def ball_to_target(self):
 57 |     """Returns the vector from the ball to the target."""
 58 |     target = self.named.data.site_xpos['target', ['x', 'z']]
 59 |     ball = self.named.data.xpos['ball', ['x', 'z']]
 60 |     return target - ball
 61 | 
 62 |   def in_target(self):
 63 |     """Returns 1 if the ball is in the target, 0 otherwise."""
 64 |     ball_to_target = abs(self.ball_to_target())
 65 |     target_size = self.named.model.site_size['target', [0, 2]]
 66 |     ball_size = self.named.model.geom_size['ball', 0]
 67 |     return float(all(ball_to_target < target_size - ball_size))
 68 | 
 69 | 
 70 | class BallInCup(base.Task):
 71 |   """The Ball-in-Cup task. Put the ball in the cup."""
 72 | 
 73 |   def initialize_episode(self, physics):
 74 |     """Sets the state of the environment at the start of each episode.
 75 | 
 76 |     Args:
 77 |       physics: An instance of `Physics`.
 78 | 
 79 |     """
 80 |     # Find a collision-free random initial position of the ball.
 81 |     penetrating = True
 82 |     while penetrating:
 83 |       # Assign a random ball position.
 84 |       physics.named.data.qpos['ball_x'] = self.random.uniform(-.2, .2)
 85 |       physics.named.data.qpos['ball_z'] = self.random.uniform(.2, .5)
 86 |       # Check for collisions.
 87 |       physics.after_reset()
 88 |       penetrating = physics.data.ncon > 0
 89 |     super(BallInCup, self).initialize_episode(physics)
 90 | 
 91 |   def get_observation(self, physics):
 92 |     """Returns an observation of the state."""
 93 |     obs = collections.OrderedDict()
 94 |     obs['position'] = physics.position()
 95 |     obs['velocity'] = physics.velocity()
 96 |     return obs
 97 | 
 98 |   def get_reward(self, physics):
 99 |     """Returns a sparse reward."""
100 |     return physics.in_target()
101 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/ball_in_cup.xml:
--------------------------------------------------------------------------------
 1 | 
 2 | 
 3 |   
 4 |   
 5 |   
 6 | 
 7 |   
 8 |     
 9 |     
10 |       
11 |       
12 |     
13 |   
14 | 
15 |   
16 |     
17 |     
18 |     
19 |     
20 | 
21 |     
22 |       
23 |       
24 |       
25 |       
26 |       
27 |       
28 |       
29 |       
30 |       
31 |     
32 | 
33 |     
34 |       
35 |       
36 |       
37 |       
38 |     
39 |   
40 | 
41 |   
42 |     
43 |     
44 |   
45 | 
46 |   
47 |     
48 |       
49 |       
50 |     
51 |   
52 | 
53 | 
54 | 
55 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/base.py:
--------------------------------------------------------------------------------
  1 | # Copyright 2017 The dm_control Authors.
  2 | #
  3 | # Licensed under the Apache License, Version 2.0 (the "License");
  4 | # you may not use this file except in compliance with the License.
  5 | # You may obtain a copy of the License at
  6 | #
  7 | #    http://www.apache.org/licenses/LICENSE-2.0
  8 | #
  9 | # Unless required by applicable law or agreed to in writing, software
 10 | # distributed under the License is distributed on an "AS IS" BASIS,
 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
 12 | # See the License for the specific language governing permissions and
 13 | # limitations under the License.
 14 | # ============================================================================
 15 | 
 16 | """Base class for tasks in the Control Suite."""
 17 | 
 18 | from __future__ import absolute_import
 19 | from __future__ import division
 20 | from __future__ import print_function
 21 | 
 22 | from dm_control import mujoco
 23 | from dm_control.rl import control
 24 | 
 25 | import numpy as np
 26 | 
 27 | 
 28 | class Task(control.Task):
 29 |   """Base class for tasks in the Control Suite.
 30 | 
 31 |   Actions are mapped directly to the states of MuJoCo actuators: each element of
 32 |   the action array is used to set the control input for a single actuator. The
 33 |   ordering of the actuators is the same as in the corresponding MJCF XML file.
 34 | 
 35 |   Attributes:
 36 |     random: A `numpy.random.RandomState` instance. This should be used to
 37 |       generate all random variables associated with the task, such as random
 38 |       starting states, observation noise* etc.
 39 | 
 40 |   *If sensor noise is enabled in the MuJoCo model then this will be generated
 41 |   using MuJoCo's internal RNG, which has its own independent state.
 42 |   """
 43 | 
 44 |   def __init__(self, random=None):
 45 |     """Initializes a new continuous control task.
 46 | 
 47 |     Args:
 48 |       random: Optional, either a `numpy.random.RandomState` instance, an integer
 49 |         seed for creating a new `RandomState`, or None to select a seed
 50 |         automatically (default).
 51 |     """
 52 |     if not isinstance(random, np.random.RandomState):
 53 |       random = np.random.RandomState(random)
 54 |     self._random = random
 55 |     self._visualize_reward = False
 56 | 
 57 |   @property
 58 |   def random(self):
 59 |     """Task-specific `numpy.random.RandomState` instance."""
 60 |     return self._random
 61 | 
 62 |   def action_spec(self, physics):
 63 |     """Returns a `BoundedArraySpec` matching the `physics` actuators."""
 64 |     return mujoco.action_spec(physics)
 65 | 
 66 |   def initialize_episode(self, physics):
 67 |     """Resets geom colors to their defaults after starting a new episode.
 68 | 
 69 |     Subclasses of `base.Task` must delegate to this method after performing
 70 |     their own initialization.
 71 | 
 72 |     Args:
 73 |       physics: An instance of `mujoco.Physics`.
 74 |     """
 75 |     self.after_step(physics)
 76 | 
 77 |   def before_step(self, action, physics):
 78 |     """Sets the control signal for the actuators to values in `action`."""
 79 |     # Support legacy internal code.
 80 |     action = getattr(action, "continuous_actions", action)
 81 |     physics.set_control(action)
 82 | 
 83 |   def after_step(self, physics):
 84 |     """Modifies colors according to the reward."""
 85 |     if self._visualize_reward:
 86 |       reward = np.clip(self.get_reward(physics), 0.0, 1.0)
 87 |       _set_reward_colors(physics, reward)
 88 | 
 89 |   @property
 90 |   def visualize_reward(self):
 91 |     return self._visualize_reward
 92 | 
 93 |   @visualize_reward.setter
 94 |   def visualize_reward(self, value):
 95 |     if not isinstance(value, bool):
 96 |       raise ValueError("Expected a boolean, got {}.".format(type(value)))
 97 |     self._visualize_reward = value
 98 | 
 99 | 
# Names of the materials recolored by `_set_reward_colors`.
_MATERIALS = ["self", "effector", "target"]
# Names of the materials holding each one's default and highlight color, in
# the same order as `_MATERIALS`.
_DEFAULT = [name + "_default" for name in _MATERIALS]
_HIGHLIGHT = [name + "_highlight" for name in _MATERIALS]
103 | 
104 | 
def _set_reward_colors(physics, reward):
  """Blends the self, effector and target colors according to the reward.

  Args:
    physics: Physics instance whose named material colors are updated in place.
    reward: A value in [0, 1]; 0 gives the default colors and 1 the highlight
      colors.
  """
  assert 0.0 <= reward <= 1.0
  mat_rgba = physics.named.model.mat_rgba
  default_rgba = mat_rgba[_DEFAULT]
  highlight_rgba = mat_rgba[_HIGHLIGHT]
  # The fourth power gives better color distinction near high rewards.
  weight = reward ** 4
  mat_rgba[_MATERIALS] = weight * highlight_rgba + (1.0 - weight) * default_rgba
113 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/cartpole.xml:
--------------------------------------------------------------------------------
 1 | 
 2 |   
 3 |   
 4 |   
 5 | 
 6 |   
 9 | 
10 |   
11 |     
12 |       
13 |       
14 |     
15 |   
16 | 
17 |   
18 |     
19 |     
20 |     
21 |     
22 |     
23 |     
24 |     
25 |       
26 |       
27 |       
28 |         
29 |         
30 |       
31 |     
32 |   
33 | 
34 |   
35 |     
36 |   
37 | 
38 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/cheetah.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2017 The dm_control Authors.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #    http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | 
16 | """Cheetah Domain."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import collections
23 | 
24 | from dm_control import mujoco
25 | from dm_control.rl import control
26 | from local_dm_control_suite import base
27 | from local_dm_control_suite import common
28 | from dm_control.utils import containers
29 | from dm_control.utils import rewards
30 | 
31 | 
# How long the simulation will run, in seconds.
_DEFAULT_TIME_LIMIT = 10

# Running speed above which reward is 1.
_RUN_SPEED = 10

# Container in which the task constructors below register themselves, with
# tags, through the `@SUITE.add(...)` decorator.
SUITE = containers.TaggedTasks()
39 | 
40 | 
def get_model_and_assets():
  """Returns the Cheetah model XML string paired with the shared assets dict."""
  model_xml = common.read_model('cheetah.xml')
  return model_xml, common.ASSETS
44 | 
45 | 
@SUITE.add('benchmarking')
def run(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
  """Builds the Cheetah run environment.

  Args:
    time_limit: Forwarded to `control.Environment` as `time_limit`.
    random: Forwarded to the `Cheetah` task as its `random` argument.
    environment_kwargs: Optional dict of extra keyword arguments for
      `control.Environment`. A falsy value means no extra arguments.

  Returns:
    A `control.Environment` running the `Cheetah` task.
  """
  model_xml, assets = get_model_and_assets()
  physics = Physics.from_xml_string(model_xml, assets)
  task = Cheetah(random=random)
  if not environment_kwargs:
    environment_kwargs = {}
  return control.Environment(
      physics, task, time_limit=time_limit, **environment_kwargs)
54 | 
55 | 
class Physics(mujoco.Physics):
  """MuJoCo physics extended with a Cheetah-specific speed query."""

  def speed(self):
    """Returns the horizontal speed of the Cheetah.

    This is the first component of the 'torso_subtreelinvel' sensor reading.
    """
    torso_linvel = self.named.data.sensordata['torso_subtreelinvel']
    return torso_linvel[0]
62 | 
63 | 
class Cheetah(base.Task):
  """A `Task` that rewards the Cheetah for running."""

  def initialize_episode(self, physics):
    """Randomizes the limited joints, lets the model settle, resets time."""
    model = physics.model
    # The indexing below assumes that all joints have a single DOF.
    assert model.nq == model.njnt
    limited_mask = model.jnt_limited == 1
    range_lower, range_upper = model.jnt_range[limited_mask].T
    physics.data.qpos[limited_mask] = self.random.uniform(
        range_lower, range_upper)

    # Stabilize the model before the actual simulation.
    for _ in range(200):
      physics.step()

    # The settling steps above advanced the simulation time; start from zero.
    physics.data.time = 0
    self._timeout_progress = 0
    super(Cheetah, self).initialize_episode(physics)

  def get_observation(self, physics):
    """Returns position (without its first element) and velocity."""
    # Ignores horizontal position to maintain translational invariance.
    joint_positions = physics.data.qpos[1:].copy()
    return collections.OrderedDict([
        ('position', joint_positions),
        ('velocity', physics.velocity()),
    ])

  def get_reward(self, physics):
    """Returns a reward that rises linearly with speed up to `_RUN_SPEED`."""
    current_speed = physics.speed()
    return rewards.tolerance(current_speed,
                             bounds=(_RUN_SPEED, float('inf')),
                             margin=_RUN_SPEED,
                             value_at_margin=0,
                             sigmoid='linear')
98 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/cheetah.xml:
--------------------------------------------------------------------------------
 1 | 
 2 |   
 3 |   
 4 |   
 5 | 
 6 |   
 7 | 
 8 |   
 9 |     
10 |       
11 |       
12 |     
13 |     
14 |       
15 |     
16 |     
17 |   
18 | 
19 |   
20 | 
21 |   
22 | 
23 |   
24 |     
25 |     
26 |       
27 |       
28 |       
29 |       
30 |       
31 |       
32 |       
33 |       
34 |       
35 |         
36 |         
37 |         
38 |           
39 |           
40 |           
41 |             
42 |             
43 |           
44 |         
45 |       
46 |       
47 |         
48 |         
49 |         
50 |           
51 |           
52 |           
53 |             
54 |             
55 |           
56 |         
57 |       
58 |     
59 |   
60 | 
61 |   
62 |     
63 |   
64 | 
65 |   
66 |     
67 |     
68 |     
69 |     
70 |     
71 |     
72 |   
73 | 
74 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/common/__init__.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2017 The dm_control Authors.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #    http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | 
16 | """Functions to manage the common assets for domains."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import os
23 | from dm_control.utils import io as resources
24 | 
# Directory two levels above this file, i.e. the suite package directory that
# contains the `common/` sub-package and the model XML files.
_SUITE_DIR = os.path.dirname(os.path.dirname(__file__))

# Shared asset files, given relative to `_SUITE_DIR`. The same relative
# strings are used as the keys of `ASSETS`.
_FILENAMES = [
    "./common/materials.xml",
    "./common/materials_white_floor.xml",
    "./common/skybox.xml",
    "./common/visual.xml",
]

# Maps each relative asset filename to its contents, loaded once at import
# time through `resources.GetResource`.
ASSETS = {filename: resources.GetResource(os.path.join(_SUITE_DIR, filename))
          for filename in _FILENAMES}
def read_model(model_filename):
  """Loads a model XML file that lives in the suite directory.

  Args:
    model_filename: File name of the model, resolved relative to `_SUITE_DIR`.

  Returns:
    The file contents as returned by `resources.GetResource`.
  """
  model_path = os.path.join(_SUITE_DIR, model_filename)
  return resources.GetResource(model_path)
40 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/common/materials.xml:
--------------------------------------------------------------------------------
 1 | 
 6 | 
 7 |   
 8 |     
 9 |     
10 |     
11 |     
12 |     
13 |     
14 |     
15 |     
16 |     
17 |     
18 |     
19 |     
20 |     
21 |     
22 |   
23 | 
24 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/common/materials_white_floor.xml:
--------------------------------------------------------------------------------
 1 | 
 6 | 
 7 |   
 8 |     
 9 |     
10 |     
11 |     
12 |     
13 |     
14 |     
15 |     
16 |     
17 |     
18 |     
19 |     
20 |     
21 |     
22 |   
23 | 
24 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/common/skybox.xml:
--------------------------------------------------------------------------------
1 | 
2 |   
3 |       
5 |   
6 | 
7 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/common/visual.xml:
--------------------------------------------------------------------------------
1 | 
2 |   
3 |     
4 |     
5 |     
6 |   
7 | 
8 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/demos/mocap_demo.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2017 The dm_control Authors.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #    http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | 
16 | """Demonstration of amc parsing for CMU mocap database.
17 | 
18 | To run the demo, supply a path to a `.amc` file:
19 | 
20 |     python mocap_demo --filename='path/to/mocap.amc'
21 | 
22 | CMU motion capture clips are available at mocap.cs.cmu.edu
23 | """
24 | 
25 | from __future__ import absolute_import
26 | from __future__ import division
27 | from __future__ import print_function
28 | 
29 | import time
30 | # Internal dependencies.
31 | 
32 | from absl import app
33 | from absl import flags
34 | 
35 | from local_dm_control_suite import humanoid_CMU
36 | from dm_control.suite.utils import parse_amc
37 | 
38 | import matplotlib.pyplot as plt
39 | import numpy as np
40 | 
# Command-line flags. `filename` has no default; it is marked as required in
# the `__main__` guard at the bottom of the file.
FLAGS = flags.FLAGS
flags.DEFINE_string('filename', None, 'amc file to be converted.')
# Upper bound on the number of frames rendered and played back by `main`.
flags.DEFINE_integer('max_num_frames', 90,
                     'Maximum number of frames for plotting/playback')
45 | 
46 | 
def main(unused_argv):
  """Converts an AMC clip, renders it from two cameras and plays it back.

  The clip named by `--filename` is converted with `parse_amc.convert`, up to
  `--max_num_frames` frames are rendered side by side from cameras 0 and 1,
  and the frames are then shown with matplotlib until a button is pressed.

  Args:
    unused_argv: Positional command-line arguments supplied by `app.run`;
      ignored.
  """
  env = humanoid_CMU.stand()

  # Parse and convert the requested clip.
  converted = parse_amc.convert(FLAGS.filename,
                                env.physics, env.control_timestep())

  num_frames = min(FLAGS.max_num_frames, converted.qpos.shape[1] - 1)

  width = 480
  height = 480
  # Each frame holds the two camera views next to each other.
  video = np.zeros((num_frames, height, 2 * width, 3), dtype=np.uint8)

  for frame_index in range(num_frames):
    frame_qpos = converted.qpos[:, frame_index]
    with env.physics.reset_context():
      env.physics.data.qpos[:] = frame_qpos
    views = [env.physics.render(height, width, camera_id=camera)
             for camera in (0, 1)]
    video[frame_index] = np.hstack(views)

  img = None
  tic = time.time()
  for frame in video:
    if img is None:
      # First frame: create the image artist; later frames only update it.
      img = plt.imshow(frame)
    else:
      img.set_data(frame)
    toc = time.time()
    clock_dt = toc - tic
    tic = time.time()
    # Real-time playback is not always possible, since clock_dt may exceed
    # .03; the pause is clamped because the display time must be > 0.0.
    plt.pause(max(0.01, 0.03 - clock_dt))
    plt.draw()
  plt.waitforbuttonpress()
80 | 
81 | 
# Script entry point: `--filename` must be supplied, then absl parses the
# flags and calls `main`.
if __name__ == '__main__':
  flags.mark_flag_as_required('filename')
  app.run(main)
85 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/demos/zeros.amc:
--------------------------------------------------------------------------------
  1 | #DUMMY AMC for testing
  2 | :FULLY-SPECIFIED
  3 | :DEGREES
  4 | 1
  5 | root 0 0 0 0 0 0
  6 | lowerback 0 0 0
  7 | upperback 0 0 0
  8 | thorax 0 0 0
  9 | lowerneck 0 0 0
 10 | upperneck 0 0 0
 11 | head 0 0 0
 12 | rclavicle 0 0
 13 | rhumerus 0 0 0
 14 | rradius 0
 15 | rwrist 0
 16 | rhand 0 0
 17 | rfingers 0
 18 | rthumb 0 0
 19 | lclavicle 0 0
 20 | lhumerus 0 0 0
 21 | lradius 0
 22 | lwrist 0
 23 | lhand 0 0
 24 | lfingers 0
 25 | lthumb 0 0
 26 | rfemur 0 0 0
 27 | rtibia 0
 28 | rfoot 0 0
 29 | rtoes 0
 30 | lfemur 0 0 0
 31 | ltibia 0
 32 | lfoot 0 0
 33 | ltoes 0
 34 | 2
 35 | root 0 0 0 0 0 0
 36 | lowerback 0 0 0
 37 | upperback 0 0 0
 38 | thorax 0 0 0
 39 | lowerneck 0 0 0
 40 | upperneck 0 0 0
 41 | head 0 0 0
 42 | rclavicle 0 0
 43 | rhumerus 0 0 0
 44 | rradius 0
 45 | rwrist 0
 46 | rhand 0 0
 47 | rfingers 0
 48 | rthumb 0 0
 49 | lclavicle 0 0
 50 | lhumerus 0 0 0
 51 | lradius 0
 52 | lwrist 0
 53 | lhand 0 0
 54 | lfingers 0
 55 | lthumb 0 0
 56 | rfemur 0 0 0
 57 | rtibia 0
 58 | rfoot 0 0
 59 | rtoes 0
 60 | lfemur 0 0 0
 61 | ltibia 0
 62 | lfoot 0 0
 63 | ltoes 0
 64 | 3
 65 | root 0 0 0 0 0 0
 66 | lowerback 0 0 0
 67 | upperback 0 0 0
 68 | thorax 0 0 0
 69 | lowerneck 0 0 0
 70 | upperneck 0 0 0
 71 | head 0 0 0
 72 | rclavicle 0 0
 73 | rhumerus 0 0 0
 74 | rradius 0
 75 | rwrist 0
 76 | rhand 0 0
 77 | rfingers 0
 78 | rthumb 0 0
 79 | lclavicle 0 0
 80 | lhumerus 0 0 0
 81 | lradius 0
 82 | lwrist 0
 83 | lhand 0 0
 84 | lfingers 0
 85 | lthumb 0 0
 86 | rfemur 0 0 0
 87 | rtibia 0
 88 | rfoot 0 0
 89 | rtoes 0
 90 | lfemur 0 0 0
 91 | ltibia 0
 92 | lfoot 0 0
 93 | ltoes 0
 94 | 4
 95 | root 0 0 0 0 0 0
 96 | lowerback 0 0 0
 97 | upperback 0 0 0
 98 | thorax 0 0 0
 99 | lowerneck 0 0 0
100 | upperneck 0 0 0
101 | head 0 0 0
102 | rclavicle 0 0
103 | rhumerus 0 0 0
104 | rradius 0
105 | rwrist 0
106 | rhand 0 0
107 | rfingers 0
108 | rthumb 0 0
109 | lclavicle 0 0
110 | lhumerus 0 0 0
111 | lradius 0
112 | lwrist 0
113 | lhand 0 0
114 | lfingers 0
115 | lthumb 0 0
116 | rfemur 0 0 0
117 | rtibia 0
118 | rfoot 0 0
119 | rtoes 0
120 | lfemur 0 0 0
121 | ltibia 0
122 | lfoot 0 0
123 | ltoes 0
124 | 5
125 | root 0 0 0 0 0 0
126 | lowerback 0 0 0
127 | upperback 0 0 0
128 | thorax 0 0 0
129 | lowerneck 0 0 0
130 | upperneck 0 0 0
131 | head 0 0 0
132 | rclavicle 0 0
133 | rhumerus 0 0 0
134 | rradius 0
135 | rwrist 0
136 | rhand 0 0
137 | rfingers 0
138 | rthumb 0 0
139 | lclavicle 0 0
140 | lhumerus 0 0 0
141 | lradius 0
142 | lwrist 0
143 | lhand 0 0
144 | lfingers 0
145 | lthumb 0 0
146 | rfemur 0 0 0
147 | rtibia 0
148 | rfoot 0 0
149 | rtoes 0
150 | lfemur 0 0 0
151 | ltibia 0
152 | lfoot 0 0
153 | ltoes 0
154 | 6
155 | root 0 0 0 0 0 0
156 | lowerback 0 0 0
157 | upperback 0 0 0
158 | thorax 0 0 0
159 | lowerneck 0 0 0
160 | upperneck 0 0 0
161 | head 0 0 0
162 | rclavicle 0 0
163 | rhumerus 0 0 0
164 | rradius 0
165 | rwrist 0
166 | rhand 0 0
167 | rfingers 0
168 | rthumb 0 0
169 | lclavicle 0 0
170 | lhumerus 0 0 0
171 | lradius 0
172 | lwrist 0
173 | lhand 0 0
174 | lfingers 0
175 | lthumb 0 0
176 | rfemur 0 0 0
177 | rtibia 0
178 | rfoot 0 0
179 | rtoes 0
180 | lfemur 0 0 0
181 | ltibia 0
182 | lfoot 0 0
183 | ltoes 0
184 | 7
185 | root 0 0 0 0 0 0
186 | lowerback 0 0 0
187 | upperback 0 0 0
188 | thorax 0 0 0
189 | lowerneck 0 0 0
190 | upperneck 0 0 0
191 | head 0 0 0
192 | rclavicle 0 0
193 | rhumerus 0 0 0
194 | rradius 0
195 | rwrist 0
196 | rhand 0 0
197 | rfingers 0
198 | rthumb 0 0
199 | lclavicle 0 0
200 | lhumerus 0 0 0
201 | lradius 0
202 | lwrist 0
203 | lhand 0 0
204 | lfingers 0
205 | lthumb 0 0
206 | rfemur 0 0 0
207 | rtibia 0
208 | rfoot 0 0
209 | rtoes 0
210 | lfemur 0 0 0
211 | ltibia 0
212 | lfoot 0 0
213 | ltoes 0
214 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/explore.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2018 The dm_control Authors.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #    http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Control suite environments explorer."""
16 | 
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | 
21 | from absl import app
22 | from absl import flags
23 | from dm_control import suite
24 | from dm_control.suite.wrappers import action_noise
25 | from six.moves import input
26 | 
27 | from dm_control import viewer
28 | 
29 | 
# 'domain.task' strings, one per entry of `suite.ALL_TASKS` and in the same
# order, so an index into this list is also valid for `suite.ALL_TASKS`.
_ALL_NAMES = ['.'.join(domain_task) for domain_task in suite.ALL_TASKS]

# Command-line flags. `environment_name` is restricted to `_ALL_NAMES` and
# defaults to None, in which case `main` prompts for a name.
flags.DEFINE_enum('environment_name', None, _ALL_NAMES,
                  'Optional \'domain_name.task_name\' pair specifying the '
                  'environment to load. If unspecified a prompt will appear to '
                  'select one.')
flags.DEFINE_bool('timeout', True, 'Whether episodes should have a time limit.')
flags.DEFINE_bool('visualize_reward', True,
                  'Whether to vary the colors of geoms according to the '
                  'current reward value.')
flags.DEFINE_float('action_noise', 0.,
                   'Standard deviation of Gaussian noise to apply to actions, '
                   'expressed as a fraction of the max-min range for each '
                   'action dimension. Defaults to 0, i.e. no noise.')
FLAGS = flags.FLAGS
45 | 
46 | 
def prompt_environment_name(prompt, values):
  """Prompts repeatedly until the user enters one of the allowed names.

  An empty or unknown entry prints a diagnostic and asks again. The check uses
  a membership test: `values.index(name)` raises `ValueError` for a missing
  item rather than returning a negative index, so comparing its result with 0
  would abort on the first mistyped name instead of re-prompting.

  Args:
    prompt: Text shown to the user on every attempt.
    values: Collection of accepted environment names.

  Returns:
    The first entered string that is non-empty and contained in `values`.
  """
  while True:
    environment_name = input(prompt)
    if environment_name and environment_name in values:
      return environment_name
    print('"%s" is not a valid environment name.' % environment_name)
55 | 
56 | 
def main(argv):
  """Launches the viewer on the selected control-suite environment.

  The environment comes from `--environment_name`; when that flag is unset
  the available names are listed and the user is prompted for one.

  Args:
    argv: Positional command-line arguments supplied by `app.run`; unused.
  """
  del argv  # Unused.
  selected_name = FLAGS.environment_name
  if selected_name is None:
    listing = ['Available environments:'] + _ALL_NAMES
    print('\n  '.join(listing))
    selected_name = prompt_environment_name(
        'Please select an environment name: ', _ALL_NAMES)

  # `_ALL_NAMES` is built from `suite.ALL_TASKS` in order, so they share
  # indices.
  domain_name, task_name = suite.ALL_TASKS[_ALL_NAMES.index(selected_name)]

  # Without a timeout, episodes get an infinite time limit.
  task_kwargs = {} if FLAGS.timeout else {'time_limit': float('inf')}

  def loader():
    """Builds the environment, optionally wrapped with action noise."""
    env = suite.load(
        domain_name=domain_name, task_name=task_name, task_kwargs=task_kwargs)
    env.task.visualize_reward = FLAGS.visualize_reward
    noise_scale = FLAGS.action_noise
    if noise_scale > 0:
      return action_noise.Wrapper(env, scale=noise_scale)
    return env

  viewer.launch(loader)
81 | 
82 | 
# Script entry point: absl parses the flags and then calls `main`.
if __name__ == '__main__':
  app.run(main)
85 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/finger.xml:
--------------------------------------------------------------------------------
 1 |  
 2 |   
 3 |   
 4 |   
 5 | 
 6 |   
 9 | 
10 |   
11 |     
12 |     
13 |     
14 |     
15 |       
16 |       
17 |     
18 |   
19 | 
20 |   
21 |     
22 |     
23 |     
24 |     
25 | 
26 |     
27 |       
28 |       
29 |       
30 |       
31 |         
32 |         
33 |         
34 |         
35 |         
36 |       
37 |     
38 | 
39 |     
40 |       
41 |       
42 |       
43 |       
44 |       
45 |     
46 | 
47 |     
48 |   
49 | 
50 |   
51 |     
52 |     
53 |   
54 | 
55 |   
56 |   
57 |     
58 |     
59 |     
60 |     
61 |     
62 |     
63 |     
64 |     
65 |     
66 |     
67 |     
68 |     
69 |   
70 | 
71 | 
72 | 
73 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/fish.xml:
--------------------------------------------------------------------------------
 1 | 
 2 |   
 3 |   
 4 |   
 5 |       
 6 |   
 7 | 
 8 | 
 9 |   
12 | 
13 |   
14 |     
15 |     
16 |       
17 |       
18 |     
19 |   
20 | 
21 |   
22 |     
23 |     
24 |     
25 |     
26 |     
27 |     
28 |     
29 |       
30 |       
31 |       
32 |       
33 |       
34 |       
35 |       
36 |       
37 |       
38 |       
39 |       
40 |         
41 |         
42 |         
43 |         
44 |           
45 |           
46 |         
47 |       
48 |       
49 |         
50 |         
51 |         
52 |       
53 |       
54 |         
55 |         
56 |         
57 |       
58 |     
59 |   
60 | 
61 |   
62 |     
63 |       
64 |       
65 |     
66 |     
67 |       
68 |       
69 |     
70 |   
71 | 
72 |   
73 |     
74 |     
75 |     
76 |     
77 |     
78 |   
79 | 
80 |   
81 |     
82 |     
83 |   
84 | 
85 | 
86 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/hopper.xml:
--------------------------------------------------------------------------------
 1 | 
 2 |   
 3 |   
 4 |   
 5 | 
 6 |   
 7 | 
 8 |   
 9 |     
10 |       
11 |       
12 |       
13 |     
14 |     
15 |       
16 |     
17 |     
18 |   
19 | 
20 |   
21 | 
22 |   
23 |     
24 |     
25 |     
26 |     
27 |       
28 |       
29 |       
30 |       
31 |       
32 |       
33 |       
34 |         
35 |         
36 |         
37 |           
38 |           
39 |           
40 |             
41 |             
42 |             
43 |               
44 |               
45 |               
46 |               
47 |             
48 |           
49 |         
50 |       
51 |     
52 |   
53 | 
54 |   
55 |     
56 |     
57 |     
58 |   
59 | 
60 |   
61 |     
62 |     
63 |     
64 |     
65 |   
66 | 
67 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/lqr.xml:
--------------------------------------------------------------------------------
 1 | 
 2 |   
 3 |   
 4 |   
 5 | 
 6 |   
 7 | 
 8 |   
 9 |     
10 |     
11 |     
12 |     
13 |   
14 | 
15 |   
18 | 
19 |   
20 |     
21 |     
22 |     
23 |     
24 |     
25 |   
26 | 
27 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/pendulum.py:
--------------------------------------------------------------------------------
  1 | # Copyright 2017 The dm_control Authors.
  2 | #
  3 | # Licensed under the Apache License, Version 2.0 (the "License");
  4 | # you may not use this file except in compliance with the License.
  5 | # You may obtain a copy of the License at
  6 | #
  7 | #    http://www.apache.org/licenses/LICENSE-2.0
  8 | #
  9 | # Unless required by applicable law or agreed to in writing, software
 10 | # distributed under the License is distributed on an "AS IS" BASIS,
 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
 12 | # See the License for the specific language governing permissions and
 13 | # limitations under the License.
 14 | # ============================================================================
 15 | 
 16 | """Pendulum domain."""
 17 | 
 18 | from __future__ import absolute_import
 19 | from __future__ import division
 20 | from __future__ import print_function
 21 | 
 22 | import collections
 23 | 
 24 | from dm_control import mujoco
 25 | from dm_control.rl import control
 26 | from local_dm_control_suite import base
 27 | from local_dm_control_suite import common
 28 | from dm_control.utils import containers
 29 | from dm_control.utils import rewards
 30 | import numpy as np
 31 | 
 32 | 
# Default episode time limit passed to `control.Environment`.
_DEFAULT_TIME_LIMIT = 20
# Angle in degrees (converted with `np.deg2rad` below) whose cosine is the
# lower reward bound on the pole's vertical component.
_ANGLE_BOUND = 8
_COSINE_BOUND = np.cos(np.deg2rad(_ANGLE_BOUND))
# Container in which this module's task constructors register themselves via
# the `@SUITE.add(...)` decorator.
SUITE = containers.TaggedTasks()
 37 | 
 38 | 
 39 | def get_model_and_assets():
 40 |   """Returns a tuple containing the model XML string and a dict of assets."""
 41 |   return common.read_model('pendulum.xml'), common.ASSETS
 42 | 
 43 | 
 44 | @SUITE.add('benchmarking')
 45 | def swingup(time_limit=_DEFAULT_TIME_LIMIT, random=None,
 46 |             environment_kwargs=None):
 47 |   """Returns pendulum swingup task ."""
 48 |   physics = Physics.from_xml_string(*get_model_and_assets())
 49 |   task = SwingUp(random=random)
 50 |   environment_kwargs = environment_kwargs or {}
 51 |   return control.Environment(
 52 |       physics, task, time_limit=time_limit, **environment_kwargs)
 53 | 
 54 | 
 55 | class Physics(mujoco.Physics):
 56 |   """Physics simulation with additional features for the Pendulum domain."""
 57 | 
 58 |   def pole_vertical(self):
 59 |     """Returns vertical (z) component of pole frame."""
 60 |     return self.named.data.xmat['pole', 'zz']
 61 | 
 62 |   def angular_velocity(self):
 63 |     """Returns the angular velocity of the pole."""
 64 |     return self.named.data.qvel['hinge'].copy()
 65 | 
 66 |   def pole_orientation(self):
 67 |     """Returns both horizontal and vertical components of pole frame."""
 68 |     return self.named.data.xmat['pole', ['zz', 'xz']]
 69 | 
 70 | 
 71 | class SwingUp(base.Task):
 72 |   """A Pendulum `Task` to swing up and balance the pole."""
 73 | 
 74 |   def __init__(self, random=None):
 75 |     """Initialize an instance of `Pendulum`.
 76 | 
 77 |     Args:
 78 |       random: Optional, either a `numpy.random.RandomState` instance, an
 79 |         integer seed for creating a new `RandomState`, or None to select a seed
 80 |         automatically (default).
 81 |     """
 82 |     super(SwingUp, self).__init__(random=random)
 83 | 
 84 |   def initialize_episode(self, physics):
 85 |     """Sets the state of the environment at the start of each episode.
 86 | 
 87 |     Pole is set to a random angle between [-pi, pi).
 88 | 
 89 |     Args:
 90 |       physics: An instance of `Physics`.
 91 | 
 92 |     """
 93 |     physics.named.data.qpos['hinge'] = self.random.uniform(-np.pi, np.pi)
 94 |     super(SwingUp, self).initialize_episode(physics)
 95 | 
 96 |   def get_observation(self, physics):
 97 |     """Returns an observation.
 98 | 
 99 |     Observations are states concatenating pole orientation and angular velocity
100 |     and pixels from fixed camera.
101 | 
102 |     Args:
103 |       physics: An instance of `physics`, Pendulum physics.
104 | 
105 |     Returns:
106 |       A `dict` of observation.
107 |     """
108 |     obs = collections.OrderedDict()
109 |     obs['orientation'] = physics.pole_orientation()
110 |     obs['velocity'] = physics.angular_velocity()
111 |     return obs
112 | 
113 |   def get_reward(self, physics):
114 |     return rewards.tolerance(physics.pole_vertical(), (_COSINE_BOUND, 1))
115 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/pendulum.xml:
--------------------------------------------------------------------------------
 1 | 
 2 |   
 3 |   
 4 |   
 5 | 
 6 |   
 9 | 
10 |   
11 |     
12 |     
13 |     
14 |     
15 |     
16 |       
17 |       
18 |       
19 |       
20 |     
21 |   
22 | 
23 |   
24 |     
25 |   
26 | 
27 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/point_mass.xml:
--------------------------------------------------------------------------------
 1 | 
 2 |   
 3 |   
 4 |   
 5 | 
 6 |   
 9 | 
10 |   
11 |     
12 |     
13 |   
14 | 
15 |   
16 |     
17 |     
18 |     
19 |     
20 |     
21 |     
22 |     
23 | 
24 |     
25 |       
26 |       
27 |       
28 |       
29 |     
30 | 
31 |     
32 |   
33 | 
34 |   
35 |     
36 |       
37 |       
38 |     
39 |     
40 |       
41 |       
42 |     
43 |   
44 | 
45 |   
46 |     
47 |     
48 |   
49 | 
50 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/reacher.py:
--------------------------------------------------------------------------------
  1 | # Copyright 2017 The dm_control Authors.
  2 | #
  3 | # Licensed under the Apache License, Version 2.0 (the "License");
  4 | # you may not use this file except in compliance with the License.
  5 | # You may obtain a copy of the License at
  6 | #
  7 | #    http://www.apache.org/licenses/LICENSE-2.0
  8 | #
  9 | # Unless required by applicable law or agreed to in writing, software
 10 | # distributed under the License is distributed on an "AS IS" BASIS,
 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
 12 | # See the License for the specific language governing permissions and
 13 | # limitations under the License.
 14 | # ============================================================================
 15 | 
 16 | """Reacher domain."""
 17 | 
 18 | from __future__ import absolute_import
 19 | from __future__ import division
 20 | from __future__ import print_function
 21 | 
 22 | import collections
 23 | 
 24 | from dm_control import mujoco
 25 | from dm_control.rl import control
 26 | from local_dm_control_suite import base
 27 | from local_dm_control_suite import common
 28 | from dm_control.suite.utils import randomizers
 29 | from dm_control.utils import containers
 30 | from dm_control.utils import rewards
 31 | import numpy as np
 32 | 
# Container in which this module's task constructors register themselves via
# the `@SUITE.add(...)` decorator.
SUITE = containers.TaggedTasks()
# Default episode time limit passed to `control.Environment`.
_DEFAULT_TIME_LIMIT = 20
# Target sizes used by the `easy` and `hard` tasks respectively; each is
# written to the first size component of the 'target' geom.
_BIG_TARGET = .05
_SMALL_TARGET = .015
 37 | 
 38 | 
 39 | def get_model_and_assets():
 40 |   """Returns a tuple containing the model XML string and a dict of assets."""
 41 |   return common.read_model('reacher.xml'), common.ASSETS
 42 | 
 43 | 
 44 | @SUITE.add('benchmarking', 'easy')
 45 | def easy(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
 46 |   """Returns reacher with sparse reward with 5e-2 tol and randomized target."""
 47 |   physics = Physics.from_xml_string(*get_model_and_assets())
 48 |   task = Reacher(target_size=_BIG_TARGET, random=random)
 49 |   environment_kwargs = environment_kwargs or {}
 50 |   return control.Environment(
 51 |       physics, task, time_limit=time_limit, **environment_kwargs)
 52 | 
 53 | 
 54 | @SUITE.add('benchmarking')
 55 | def hard(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
 56 |   """Returns reacher with sparse reward with 1e-2 tol and randomized target."""
 57 |   physics = Physics.from_xml_string(*get_model_and_assets())
 58 |   task = Reacher(target_size=_SMALL_TARGET, random=random)
 59 |   environment_kwargs = environment_kwargs or {}
 60 |   return control.Environment(
 61 |       physics, task, time_limit=time_limit, **environment_kwargs)
 62 | 
 63 | 
 64 | class Physics(mujoco.Physics):
 65 |   """Physics simulation with additional features for the Reacher domain."""
 66 | 
 67 |   def finger_to_target(self):
 68 |     """Returns the vector from target to finger in global coordinates."""
 69 |     return (self.named.data.geom_xpos['target', :2] -
 70 |             self.named.data.geom_xpos['finger', :2])
 71 | 
 72 |   def finger_to_target_dist(self):
 73 |     """Returns the signed distance between the finger and target surface."""
 74 |     return np.linalg.norm(self.finger_to_target())
 75 | 
 76 | 
 77 | class Reacher(base.Task):
 78 |   """A reacher `Task` to reach the target."""
 79 | 
 80 |   def __init__(self, target_size, random=None):
 81 |     """Initialize an instance of `Reacher`.
 82 | 
 83 |     Args:
 84 |       target_size: A `float`, tolerance to determine whether finger reached the
 85 |           target.
 86 |       random: Optional, either a `numpy.random.RandomState` instance, an
 87 |         integer seed for creating a new `RandomState`, or None to select a seed
 88 |         automatically (default).
 89 |     """
 90 |     self._target_size = target_size
 91 |     super(Reacher, self).__init__(random=random)
 92 | 
 93 |   def initialize_episode(self, physics):
 94 |     """Sets the state of the environment at the start of each episode."""
 95 |     physics.named.model.geom_size['target', 0] = self._target_size
 96 |     randomizers.randomize_limited_and_rotational_joints(physics, self.random)
 97 | 
 98 |     # Randomize target position
 99 |     angle = self.random.uniform(0, 2 * np.pi)
100 |     radius = self.random.uniform(.05, .20)
101 |     physics.named.model.geom_pos['target', 'x'] = radius * np.sin(angle)
102 |     physics.named.model.geom_pos['target', 'y'] = radius * np.cos(angle)
103 | 
104 |     super(Reacher, self).initialize_episode(physics)
105 | 
106 |   def get_observation(self, physics):
107 |     """Returns an observation of the state and the target position."""
108 |     obs = collections.OrderedDict()
109 |     obs['position'] = physics.position()
110 |     obs['to_target'] = physics.finger_to_target()
111 |     obs['velocity'] = physics.velocity()
112 |     return obs
113 | 
114 |   def get_reward(self, physics):
115 |     radii = physics.named.model.geom_size[['target', 'finger'], 0].sum()
116 |     return rewards.tolerance(physics.finger_to_target_dist(), (0, radii))
117 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/reacher.xml:
--------------------------------------------------------------------------------
 1 | 
 2 |   
 3 |   
 4 |   
 5 | 
 6 |   
 9 | 
10 |   
11 |     
12 |     
13 |   
14 | 
15 |   
16 |     
17 |     
18 |     
19 |     
20 |     
21 |     
22 |     
23 |     
24 | 
25 |     
26 |     
27 |     
28 |       
29 |       
30 |       
31 |         
32 |         
33 |         
34 |           
35 |           
36 |         
37 |       
38 |     
39 |     
40 |     
41 |   
42 | 
43 |   
44 |     
45 |     
46 |   
47 | 
48 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/swimmer.xml:
--------------------------------------------------------------------------------
 1 | 
 2 |   
 3 |   
 4 |   
 5 | 
 6 |   
 9 | 
10 |   
11 |     
12 |       
13 |       
14 |         
15 |       
16 |       
17 |         
18 |       
19 |       
20 |     
21 |     
22 |       
23 |     
24 |     
25 |   
26 | 
27 |   
28 |     
29 |     
30 |       
31 |       
32 |       
33 |       
34 |       
35 |       
36 |       
37 |       
38 |       
39 |       
40 |       
41 |       
42 |       
43 |     
44 |     
45 |     
46 |   
47 | 
48 |   
49 |     
50 |     
51 |     
52 |     
53 |     
54 |     
55 |   
56 | 
57 | 
58 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/tests/loader_test.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2017 The dm_control Authors.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #    http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | 
16 | """Tests for the dm_control.suite loader."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | # Internal dependencies.
23 | 
24 | from absl.testing import absltest
25 | 
26 | from dm_control import suite
27 | from dm_control.rl import control
28 | 
29 | 
class LoaderTest(absltest.TestCase):
  """Smoke tests for `suite.load`."""

  def test_load_without_kwargs(self):
    # Domain and task name alone must be enough to build an Environment.
    loaded = suite.load('cartpole', 'swingup')
    self.assertIsInstance(loaded, control.Environment)

  def test_load_with_kwargs(self):
    # Extra task keyword arguments are forwarded through `task_kwargs`.
    task_options = {'time_limit': 40, 'random': 99}
    loaded = suite.load('cartpole', 'swingup', task_kwargs=task_options)
    self.assertIsInstance(loaded, control.Environment)
40 | 
41 | 
class LoaderConstantsTest(absltest.TestCase):
  """Checks the task groupings exported by the suite module."""

  def testSuiteConstants(self):
    # Every exported task grouping must contain at least one entry.
    task_groups = (suite.BENCHMARKING, suite.EASY, suite.HARD, suite.EXTRA)
    for task_group in task_groups:
      self.assertNotEmpty(task_group)
49 | 
50 | 
# Hand control to absltest's runner when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()
53 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/tests/lqr_test.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2017 The dm_control Authors.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #    http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | 
16 | """Tests specific to the LQR domain."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import  math
23 | import unittest
24 | 
25 | # Internal dependencies.
26 | from absl import logging
27 | 
28 | from absl.testing import absltest
29 | from absl.testing import parameterized
30 | 
31 | from local_dm_control_suite import lqr
32 | from local_dm_control_suite import lqr_solver
33 | 
34 | import numpy as np
35 | from six.moves import range
36 | 
37 | 
class LqrTest(parameterized.TestCase):
  """Verifies that solved LQR policies incur the predicted total cost."""

  @parameterized.named_parameters(
      ('lqr_2_1', lqr.lqr_2_1),
      ('lqr_6_2', lqr.lqr_6_2))
  def test_lqr_optimal_policy(self, make_env):
    # Solve with whichever code path `lqr_solver` selects by default.
    task_env = make_env()
    value_matrix, gain, decay = lqr_solver.solve(task_env)
    self.assertPolicyisOptimal(task_env, value_matrix, gain, decay)

  @parameterized.named_parameters(
      ('lqr_2_1', lqr.lqr_2_1),
      ('lqr_6_2', lqr.lqr_6_2))
  @unittest.skipUnless(
      condition=lqr_solver.sp,
      reason='scipy is not available, so non-scipy DARE solver is the default.')
  def test_lqr_optimal_policy_no_scipy(self, make_env):
    task_env = make_env()
    saved_sp = lqr_solver.sp
    # Force the solver to use the non-scipy code path, restoring the module
    # attribute afterwards even if solving raises.
    lqr_solver.sp = None
    try:
      value_matrix, gain, decay = lqr_solver.solve(task_env)
    finally:
      lqr_solver.sp = saved_sp
    self.assertPolicyisOptimal(task_env, value_matrix, gain, decay)

  def assertPolicyisOptimal(self, env, p, k, beta):
    """Rolls out `u = k.dot(x)` and compares the cost to `.5*x0^T*p*x0`.

    Args:
      env: Environment to roll out in; its observations must provide
        'position' and 'velocity' entries.
      p: Matrix of the quadratic form giving the expected total cost.
      k: Gain applied to the stacked state to produce each action.
      beta: Value whose base-10 logarithm, together with the tolerance,
        determines how many steps are simulated.
    """

    def stacked_state(ts):
      # State vector is position followed by velocity.
      return np.hstack((ts.observation['position'],
                        ts.observation['velocity']))

    tolerance = 1e-3
    n_steps = int(math.ceil(math.log10(tolerance) / math.log10(beta)))
    logging.info('%d timesteps for %g convergence.', n_steps, tolerance)

    timestep = env.reset()
    initial_state = stacked_state(timestep)
    logging.info('Measuring total cost over %d steps.', n_steps)
    total_loss = 0.0
    for _ in range(n_steps):
      # u = k*x is the optimal policy
      action = k.dot(stacked_state(timestep))
      # The cost of the current timestep is booked before stepping; a missing
      # reward (None) is treated as 0.0.
      total_loss += 1 - (timestep.reward or 0.0)
      timestep = env.step(action)

    logging.info('Analytical expected total cost is .5*x^T*p*x.')
    expected_loss = .5 * initial_state.T.dot(p).dot(initial_state)
    logging.info('Comparing measured and predicted costs.')
    np.testing.assert_allclose(expected_loss, total_loss, rtol=tolerance)
86 | 
# Hand control to absltest's runner when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()
89 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/utils/__init__.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2017 The dm_control Authors.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #    http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | 
16 | """Utility functions used in the control suite."""
17 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/utils/parse_amc_test.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2017 The dm_control Authors.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #    http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | 
16 | """Tests for parse_amc utility."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import os
23 | 
24 | # Internal dependencies.
25 | 
26 | from absl.testing import absltest
27 | from local_dm_control_suite import humanoid_CMU
28 | from dm_control.suite.utils import parse_amc
29 | 
30 | from dm_control.utils import io as resources
31 | 
# Resource path of the `zeros.amc` clip, located in the `demos` directory one
# level above the directory containing this test file.
_TEST_AMC_PATH = resources.GetResourceFilename(
    os.path.join(os.path.dirname(__file__), '../demos/zeros.amc'))
34 | 
35 | 
class ParseAMCTest(absltest.TestCase):
  """Checks the array sizes produced by `parse_amc.convert`."""

  def test_sizes_of_parsed_data(self):

    # Instantiate the humanoid environment.
    env = humanoid_CMU.stand()

    # Parse and convert specified clip.
    converted = parse_amc.convert(
        _TEST_AMC_PATH, env.physics, env.control_timestep())

    self.assertEqual(converted.qpos.shape[0], 63)
    self.assertEqual(converted.qvel.shape[0], 62)
    self.assertEqual(converted.time.shape[0], converted.qpos.shape[1])
    self.assertEqual(converted.qpos.shape[1],
                     converted.qvel.shape[1] + 1)

    # Parse and convert specified clip -- WITH SMALLER TIMESTEP
    converted2 = parse_amc.convert(
        _TEST_AMC_PATH, env.physics, 0.5 * env.control_timestep())

    self.assertEqual(converted2.qpos.shape[0], 63)
    self.assertEqual(converted2.qvel.shape[0], 62)
    self.assertEqual(converted2.time.shape[0], converted2.qpos.shape[1])
    # Check the qpos/qvel length relationship on the resampled clip itself,
    # not on the first clip again.
    self.assertEqual(converted2.qpos.shape[1],
                     converted2.qvel.shape[1] + 1)

    # Compare sizes of parsed objects for different timesteps
    self.assertEqual(converted.qpos.shape[1] * 2, converted2.qpos.shape[1])
65 | 
66 | 
# Hand control to absltest's runner when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()
69 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/utils/randomizers.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2017 The dm_control Authors.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #    http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | 
16 | """Randomization functions."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | from dm_control.mujoco.wrapper import mjbindings
23 | import numpy as np
24 | from six.moves import range
25 | 
26 | 
def random_limited_quaternion(random, limit):
  """Returns a quaternion for a random axis and a rotation angle below `limit`.

  Args:
    random: Random state providing `randn` and `rand`.
    limit: Upper bound of the rotation angle; the angle is `rand() * limit`.

  Returns:
    A length-4 array filled in by MuJoCo's `mju_axisAngle2Quat`.
  """
  # Normalizing a Gaussian draw gives the rotation axis as a unit vector.
  direction = random.randn(3)
  unit_axis = direction / np.linalg.norm(direction)
  rotation = limit * random.rand()

  # MuJoCo writes the result into the preallocated output buffer.
  result = np.zeros(4)
  mjbindings.mjlib.mju_axisAngle2Quat(result, unit_axis, rotation)
  return result
37 | 
38 | 
def randomize_limited_and_rotational_joints(physics, random=None):
  """Randomizes the positions of joints defined in the physics body.

  The following randomization rules apply:
    - Bounded joints (hinges or sliders) are sampled uniformly in the bounds.
    - Unbounded hinges are sampled uniformly in [-pi, pi].
    - Quaternions for unlimited ball joints are normalized Gaussian draws,
      i.e. sampled uniformly on the unit 3-sphere.
    - Quaternions for unlimited free joints are normalized `random.rand(4)`
      draws, so all four components are non-negative; see the note in the
      body.
    - Quaternions for limited ball joints come from
      `random_limited_quaternion`, with the upper joint range as the angle
      limit.
    - The linear degrees of freedom of free joints are not randomized.
    - Joint types not matched by a branch below (e.g. unbounded sliders) are
      left unchanged.

  Args:
    physics: Instance of 'Physics' class that holds a loaded model.
    random: Optional instance of 'np.random.RandomState'. Defaults to the global
      NumPy random state.
  """
  random = random or np.random

  # MuJoCo joint-type enum values, compared against `jnt_type` below.
  hinge = mjbindings.enums.mjtJoint.mjJNT_HINGE
  slide = mjbindings.enums.mjtJoint.mjJNT_SLIDE
  ball = mjbindings.enums.mjtJoint.mjJNT_BALL
  free = mjbindings.enums.mjtJoint.mjJNT_FREE

  # Name-indexed view of the joint positions; assignments write into physics.
  qpos = physics.named.data.qpos

  for joint_id in range(physics.model.njnt):
    joint_name = physics.model.id2name(joint_id, 'joint')
    joint_type = physics.model.jnt_type[joint_id]
    is_limited = physics.model.jnt_limited[joint_id]
    range_min, range_max = physics.model.jnt_range[joint_id]

    if is_limited:
      if joint_type == hinge or joint_type == slide:
        qpos[joint_name] = random.uniform(range_min, range_max)

      elif joint_type == ball:
        # Only the upper range bound is used as the rotation-angle limit.
        qpos[joint_name] = random_limited_quaternion(random, range_max)

    else:
      if joint_type == hinge:
        qpos[joint_name] = random.uniform(-np.pi, np.pi)

      elif joint_type == ball:
        quat = random.randn(4)
        quat /= np.linalg.norm(quat)
        qpos[joint_name] = quat

      elif joint_type == free:
        # NOTE(review): `rand` draws from [0, 1), so this covers only the
        # non-negative part of the unit 3-sphere; `randn` (as used for ball
        # joints above) would be uniform. Switching changes every sampled
        # initial state, so confirm the impact on existing results first.
        quat = random.rand(4)
        quat /= np.linalg.norm(quat)
        # Elements 0-2 (the linear position) are deliberately left untouched.
        qpos[joint_name][3:] = quat
91 | 
92 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/walker.xml:
--------------------------------------------------------------------------------
 1 | 
 2 |   
 3 |   
 4 |   
 5 | 
 6 |   
 7 | 
 8 |   
 9 | 
10 |   
11 |     
12 |     
13 |     
14 |     
15 |     
16 |       
17 |       
18 |     
19 |   
20 | 
21 |   
22 |     
23 |     
24 |       
25 |       
26 |       
27 |       
28 |       
29 |       
30 |       
31 |       
32 |       
33 |         
34 |         
35 |         
36 |           
37 |           
38 |           
39 |             
40 |             
41 |           
42 |         
43 |       
44 |       
45 |         
46 |         
47 |         
48 |           
49 |           
50 |           
51 |             
52 |             
53 |           
54 |         
55 |       
56 |     
57 |   
58 | 
59 |   
60 |     
61 |   
62 | 
63 |   
64 |     
65 |     
66 |     
67 |     
68 |     
69 |     
70 |   
71 | 
72 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/wrappers/__init__.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2018 The dm_control Authors.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #    http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | 
16 | """Environment wrappers used to extend or modify environment behaviour."""
17 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/wrappers/action_noise.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2018 The dm_control Authors.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #    http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | 
"""Control suite environment wrapper that adds Gaussian noise to actions."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import dm_env
23 | import numpy as np
24 | 
25 | 
# Error message template for a wrapped environment whose action spec has a
# non-finite minimum or maximum; formatted with the offending `action_spec`.
_BOUNDS_MUST_BE_FINITE = (
    'All bounds in `env.action_spec()` must be finite, got: {action_spec}')
28 | 
29 | 
class Wrapper(dm_env.Environment):
  """Environment wrapper that perturbs every action with Gaussian noise."""

  def __init__(self, env, scale=0.01):
    """Builds the noise wrapper around `env`.

    Args:
      env: Control suite environment to be wrapped.
      scale: Noise standard deviation, given as a fraction of each action
        dimension's (maximum - minimum) range.

    Raises:
      ValueError: If the wrapped environment's action spec has a non-finite
        minimum or maximum.
    """
    spec = env.action_spec()
    if (not np.all(np.isfinite(spec.minimum)) or
        not np.all(np.isfinite(spec.maximum))):
      raise ValueError(_BOUNDS_MUST_BE_FINITE.format(action_spec=spec))
    self._minimum = spec.minimum
    self._maximum = spec.maximum
    self._noise_std = scale * (spec.maximum - spec.minimum)
    self._env = env

  def step(self, action):
    """Adds noise to `action`, clips it to the bounds and steps the env."""
    # The noise is drawn from the wrapped task's own random state.
    noise = self._env.task.random.normal(scale=self._noise_std)
    perturbed = action + noise
    # MuJoCo implicitly clips out-of-bounds controls, but the clipping is
    # repeated here (in place) for actions that do not map directly onto
    # MuJoCo actuators and for other wrapper layers that expect in-bounds
    # actions.
    np.clip(perturbed, self._minimum, self._maximum, out=perturbed)
    return self._env.step(perturbed)

  def reset(self):
    """Resets the wrapped environment and returns its first timestep."""
    return self._env.reset()

  def observation_spec(self):
    """Returns the wrapped environment's observation spec unchanged."""
    return self._env.observation_spec()

  def action_spec(self):
    """Returns the wrapped environment's action spec unchanged."""
    return self._env.action_spec()

  def __getattr__(self, name):
    # Attributes not found on the wrapper are looked up on the wrapped env.
    return getattr(self._env, name)
75 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/wrappers/pixels.py:
--------------------------------------------------------------------------------
  1 | # Copyright 2017 The dm_control Authors.
  2 | #
  3 | # Licensed under the Apache License, Version 2.0 (the "License");
  4 | # you may not use this file except in compliance with the License.
  5 | # You may obtain a copy of the License at
  6 | #
  7 | #    http://www.apache.org/licenses/LICENSE-2.0
  8 | #
  9 | # Unless required by applicable law or agreed to in writing, software
 10 | # distributed under the License is distributed on an "AS IS" BASIS,
 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
 12 | # See the License for the specific language governing permissions and
 13 | # limitations under the License.
 14 | # ============================================================================
 15 | 
 16 | """Wrapper that adds pixel observations to a control environment."""
 17 | 
 18 | from __future__ import absolute_import
 19 | from __future__ import division
 20 | from __future__ import print_function
 21 | 
 22 | import collections
 23 | 
 24 | import dm_env
 25 | from dm_env import specs
 26 | 
# Key under which a single-array (non-dict) observation is stored when it is
# kept alongside the pixel observation; reserved in that case.
STATE_KEY = 'state'
 28 | 
 29 | 
 30 | class Wrapper(dm_env.Environment):
 31 |   """Wraps a control environment and adds a rendered pixel observation."""
 32 | 
 33 |   def __init__(self, env, pixels_only=True, render_kwargs=None,
 34 |                observation_key='pixels'):
 35 |     """Initializes a new pixel Wrapper.
 36 | 
 37 |     Args:
 38 |       env: The environment to wrap.
 39 |       pixels_only: If True (default), the original set of 'state' observations
 40 |         returned by the wrapped environment will be discarded, and the
 41 |         `OrderedDict` of observations will only contain pixels. If False, the
 42 |         `OrderedDict` will contain the original observations as well as the
 43 |         pixel observations.
 44 |       render_kwargs: Optional `dict` containing keyword arguments passed to the
 45 |         `mujoco.Physics.render` method.
 46 |       observation_key: Optional custom string specifying the pixel observation's
 47 |         key in the `OrderedDict` of observations. Defaults to 'pixels'.
 48 | 
 49 |     Raises:
 50 |       ValueError: If `env`'s observation spec is not compatible with the
 51 |         wrapper. Supported formats are a single array, or a dict of arrays.
 52 |       ValueError: If `env`'s observation already contains the specified
 53 |         `observation_key`.
 54 |     """
 55 |     if render_kwargs is None:
 56 |       render_kwargs = {}
 57 | 
 58 |     wrapped_observation_spec = env.observation_spec()
 59 | 
 60 |     if isinstance(wrapped_observation_spec, specs.Array):
 61 |       self._observation_is_dict = False
 62 |       invalid_keys = set([STATE_KEY])
 63 |     elif isinstance(wrapped_observation_spec, collections.MutableMapping):
 64 |       self._observation_is_dict = True
 65 |       invalid_keys = set(wrapped_observation_spec.keys())
 66 |     else:
 67 |       raise ValueError('Unsupported observation spec structure.')
 68 | 
 69 |     if not pixels_only and observation_key in invalid_keys:
 70 |       raise ValueError('Duplicate or reserved observation key {!r}.'
 71 |                        .format(observation_key))
 72 | 
 73 |     if pixels_only:
 74 |       self._observation_spec = collections.OrderedDict()
 75 |     elif self._observation_is_dict:
 76 |       self._observation_spec = wrapped_observation_spec.copy()
 77 |     else:
 78 |       self._observation_spec = collections.OrderedDict()
 79 |       self._observation_spec[STATE_KEY] = wrapped_observation_spec
 80 | 
 81 |     # Extend observation spec.
 82 |     pixels = env.physics.render(**render_kwargs)
 83 |     pixels_spec = specs.Array(
 84 |         shape=pixels.shape, dtype=pixels.dtype, name=observation_key)
 85 |     self._observation_spec[observation_key] = pixels_spec
 86 | 
 87 |     self._env = env
 88 |     self._pixels_only = pixels_only
 89 |     self._render_kwargs = render_kwargs
 90 |     self._observation_key = observation_key
 91 | 
 92 |   def reset(self):
 93 |     time_step = self._env.reset()
 94 |     return self._add_pixel_observation(time_step)
 95 | 
 96 |   def step(self, action):
 97 |     time_step = self._env.step(action)
 98 |     return self._add_pixel_observation(time_step)
 99 | 
100 |   def observation_spec(self):
101 |     return self._observation_spec
102 | 
103 |   def action_spec(self):
104 |     return self._env.action_spec()
105 | 
106 |   def _add_pixel_observation(self, time_step):
107 |     if self._pixels_only:
108 |       observation = collections.OrderedDict()
109 |     elif self._observation_is_dict:
110 |       observation = type(time_step.observation)(time_step.observation)
111 |     else:
112 |       observation = collections.OrderedDict()
113 |       observation[STATE_KEY] = time_step.observation
114 | 
115 |     pixels = self._env.physics.render(**self._render_kwargs)
116 |     observation[self._observation_key] = pixels
117 |     return time_step._replace(observation=observation)
118 | 
119 |   def __getattr__(self, name):
120 |     return getattr(self._env, name)
121 | 
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/wrappers/pixels_test.py:
--------------------------------------------------------------------------------
  1 | # Copyright 2017 The dm_control Authors.
  2 | #
  3 | # Licensed under the Apache License, Version 2.0 (the "License");
  4 | # you may not use this file except in compliance with the License.
  5 | # You may obtain a copy of the License at
  6 | #
  7 | #    http://www.apache.org/licenses/LICENSE-2.0
  8 | #
  9 | # Unless required by applicable law or agreed to in writing, software
 10 | # distributed under the License is distributed on an "AS IS" BASIS,
 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
 12 | # See the License for the specific language governing permissions and
 13 | # limitations under the License.
 14 | # ============================================================================
 15 | 
 16 | """Tests for the pixel wrapper."""
 17 | 
 18 | from __future__ import absolute_import
 19 | from __future__ import division
 20 | from __future__ import print_function
 21 | 
 22 | import collections
 23 | 
 24 | # Internal dependencies.
 25 | from absl.testing import absltest
 26 | from absl.testing import parameterized
 27 | from local_dm_control_suite import cartpole
 28 | from dm_control.suite.wrappers import pixels
 29 | import dm_env
 30 | from dm_env import specs
 31 | 
 32 | import numpy as np
 33 | 
 34 | 
 35 | class FakePhysics(object):
 36 | 
 37 |   def render(self, *args, **kwargs):
 38 |     del args
 39 |     del kwargs
 40 |     return np.zeros((4, 5, 3), dtype=np.uint8)
 41 | 
 42 | 
 43 | class FakeArrayObservationEnvironment(dm_env.Environment):
 44 | 
 45 |   def __init__(self):
 46 |     self.physics = FakePhysics()
 47 | 
 48 |   def reset(self):
 49 |     return dm_env.restart(np.zeros((2,)))
 50 | 
 51 |   def step(self, action):
 52 |     del action
 53 |     return dm_env.transition(0.0, np.zeros((2,)))
 54 | 
 55 |   def action_spec(self):
 56 |     pass
 57 | 
 58 |   def observation_spec(self):
 59 |     return specs.Array(shape=(2,), dtype=np.float)
 60 | 
 61 | 
 62 | class PixelsTest(parameterized.TestCase):
 63 | 
 64 |   @parameterized.parameters(True, False)
 65 |   def test_dict_observation(self, pixels_only):
 66 |     pixel_key = 'rgb'
 67 | 
 68 |     env = cartpole.swingup()
 69 | 
 70 |     # Make sure we are testing the right environment for the test.
 71 |     observation_spec = env.observation_spec()
 72 |     self.assertIsInstance(observation_spec, collections.OrderedDict)
 73 | 
 74 |     width = 320
 75 |     height = 240
 76 | 
 77 |     # The wrapper should only add one observation.
 78 |     wrapped = pixels.Wrapper(env,
 79 |                              observation_key=pixel_key,
 80 |                              pixels_only=pixels_only,
 81 |                              render_kwargs={'width': width, 'height': height})
 82 | 
 83 |     wrapped_observation_spec = wrapped.observation_spec()
 84 |     self.assertIsInstance(wrapped_observation_spec, collections.OrderedDict)
 85 | 
 86 |     if pixels_only:
 87 |       self.assertLen(wrapped_observation_spec, 1)
 88 |       self.assertEqual([pixel_key], list(wrapped_observation_spec.keys()))
 89 |     else:
 90 |       expected_length = len(observation_spec) + 1
 91 |       self.assertLen(wrapped_observation_spec, expected_length)
 92 |       expected_keys = list(observation_spec.keys()) + [pixel_key]
 93 |       self.assertEqual(expected_keys, list(wrapped_observation_spec.keys()))
 94 | 
 95 |     # Check that the added spec item is consistent with the added observation.
 96 |     time_step = wrapped.reset()
 97 |     rgb_observation = time_step.observation[pixel_key]
 98 |     wrapped_observation_spec[pixel_key].validate(rgb_observation)
 99 | 
100 |     self.assertEqual(rgb_observation.shape, (height, width, 3))
101 |     self.assertEqual(rgb_observation.dtype, np.uint8)
102 | 
103 |   @parameterized.parameters(True, False)
104 |   def test_single_array_observation(self, pixels_only):
105 |     pixel_key = 'depth'
106 | 
107 |     env = FakeArrayObservationEnvironment()
108 |     observation_spec = env.observation_spec()
109 |     self.assertIsInstance(observation_spec, specs.Array)
110 | 
111 |     wrapped = pixels.Wrapper(env, observation_key=pixel_key,
112 |                              pixels_only=pixels_only)
113 |     wrapped_observation_spec = wrapped.observation_spec()
114 |     self.assertIsInstance(wrapped_observation_spec, collections.OrderedDict)
115 | 
116 |     if pixels_only:
117 |       self.assertLen(wrapped_observation_spec, 1)
118 |       self.assertEqual([pixel_key], list(wrapped_observation_spec.keys()))
119 |     else:
120 |       self.assertLen(wrapped_observation_spec, 2)
121 |       self.assertEqual([pixels.STATE_KEY, pixel_key],
122 |                        list(wrapped_observation_spec.keys()))
123 | 
124 |     time_step = wrapped.reset()
125 | 
126 |     depth_observation = time_step.observation[pixel_key]
127 |     wrapped_observation_spec[pixel_key].validate(depth_observation)
128 | 
129 |     self.assertEqual(depth_observation.shape, (4, 5, 3))
130 |     self.assertEqual(depth_observation.dtype, np.uint8)
131 | 
# Hand control to absltest's runner when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()
134 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
 1 | absl-py==0.12.0
 2 | cachetools==4.2.2
 3 | certifi==2021.5.30
 4 | cffi==1.15.0
 5 | chardet==4.0.0
 6 | cloudpickle==1.6.0
 7 | cycler==0.10.0
 8 | Cython==0.29.24
 9 | dataclasses==0.8
10 | decorator==4.4.2
11 | fasteners==0.16.3
12 | flatbuffers==1.12
13 | future==0.18.2
14 | gast==0.2.2
15 | gin-config==0.4.0
16 | glfw==2.1.0
17 | google-auth==1.28.1
18 | google-auth-oauthlib==0.4.4
19 | google-pasta==0.2.0
20 | grpcio==1.32.0
21 | gym==0.18.0
22 | h5py==2.10.0
23 | idna==2.10
24 | imageio==2.9.0
25 | imageio-ffmpeg==0.4.2
26 | importlib-metadata==3.10.0
27 | kiwisolver==1.3.1
28 | labmaze==1.0.5
29 | lockfile==0.12.2
30 | Markdown==3.3.4
31 | matplotlib==3.3.4
32 | mujoco-py==2.0.2.13
33 | networkx==2.5.1
34 | numpy==1.19.1
35 | oauthlib==3.1.0
36 | opencv-python==4.5.1.48
37 | opt-einsum==3.3.0
38 | pandas==1.1.5
39 | Pillow>=8.3.2
40 | Pillow-SIMD==7.0.0.post3
41 | protobuf==3.17.1
42 | pyasn1==0.4.8
43 | pyasn1-modules==0.2.8
44 | pybullet==3.2.0
45 | pycparser==2.20
46 | pyglet==1.5.0
47 | PyOpenGL==3.1.5
48 | pyparsing==2.4.7
49 | python-dateutil==2.8.1
50 | pytz==2021.1
51 | PyWavelets==1.1.1
52 | PyYAML==6.0
53 | requests==2.25.1
54 | requests-oauthlib==1.3.0
55 | rsa==4.7.2
56 | scikit-image==0.17.2
57 | scipy==1.5.4
58 | seaborn==0.11.1
59 | six==1.16.0
60 | sortedcontainers==2.4.0
61 | tensorboard==2.0.2
62 | tensorboard-data-server==0.6.0
63 | tensorboard-plugin-wit==1.8.0
64 | termcolor==1.1.0
65 | tifffile==2020.9.3
66 | torch==1.8.1
67 | tqdm==4.60.0
68 | typing-extensions==3.7.4.3
69 | urllib3>=1.26.5
70 | Werkzeug==1.0.1
71 | wrapt==1.12.1
72 | zipp==3.4.1
73 | 
--------------------------------------------------------------------------------
/scripts/run.sh:
--------------------------------------------------------------------------------
 1 | CUDA_VISIBLE_DEVICES=0 python train.py \
 2 |     --domain_name $3 \
 3 |     --task_name $4 --case $1 \
 4 |     --encoder_type pixel --work_dir ./tmp/test \
 5 |     --action_repeat 8 --num_eval_episodes 10 \
 6 |     --pre_transform_image_size 100 --image_size 84 --replay_buffer_capacity 100 \
 7 |     --frame_stack 3 --data_augs no_aug  \
 8 |     --seed 23 --critic_lr 1e-3 --actor_lr 1e-3 --eval_freq 10000 \
 9 |     --batch_size 16 --num_train_steps 200000 --metric_loss \
10 |     --init_steps 1000 \
11 |     --resource_files './distractors/driving/*.mp4' --img_source 'video' --total_frames 50 \
12 |     --horizon $2 --save_model
13 | 
--------------------------------------------------------------------------------
/video.py:
--------------------------------------------------------------------------------
 1 | import imageio
 2 | import os
 3 | import numpy as np
 4 | 
 5 | 
 6 | class VideoRecorder(object):
 7 |     def __init__(self, dir_name, height=256, width=256, camera_id=0, fps=30):
 8 |         self.dir_name = dir_name
 9 |         self.height = height
10 |         self.width = width
11 |         self.camera_id = camera_id
12 |         self.fps = fps
13 |         self.frames = []
14 | 
15 |     def init(self, enabled=True):
16 |         self.frames = []
17 |         self.enabled = self.dir_name is not None and enabled
18 | 
19 |     def record(self, env):
20 |         if self.enabled:
21 |             try:
22 |                 frame = env.render(
23 |                     mode='rgb_array',
24 |                     height=self.height,
25 |                     width=self.width,
26 |                     camera_id=self.camera_id
27 |                 )
28 |             except:
29 |                 frame = env.render(
30 |                     mode='rgb_array',
31 |                 )
32 |     
33 |             self.frames.append(frame)
34 | 
35 |     def save(self, file_name):
36 |         if self.enabled:
37 |             path = os.path.join(self.dir_name, file_name)
38 |             imageio.mimsave(path, self.frames, fps=self.fps)
39 | 
--------------------------------------------------------------------------------