├── .gitignore
├── LICENSE
├── README.md
├── TransformLayer.py
├── agents
│   ├── agent_sac.py
│   ├── agent_sac_SPR.py
│   ├── agent_sac_base.py
│   ├── agent_sac_contrastive.py
│   ├── agent_sac_noreward.py
│   ├── agent_sac_norewtotransition.py
│   ├── agent_sac_notransition.py
│   ├── agent_sac_reconstruction.py
│   ├── agent_sac_value.py
│   └── auxiliary_funcs.py
├── curl_sac.py
├── data_augs.py
├── distractors
│   └── driving
│       └── 1.mp4
├── dmc2gym
│   ├── __init__.py
│   ├── natural_imgsource.py
│   └── wrappers.py
├── encoder.py
├── local_dm_control_suite
│   ├── README.md
│   ├── __init__.py
│   ├── acrobot.py
│   ├── acrobot.xml
│   ├── ball_in_cup.py
│   ├── ball_in_cup.xml
│   ├── base.py
│   ├── cartpole.py
│   ├── cartpole.xml
│   ├── cheetah.py
│   ├── cheetah.xml
│   ├── common
│   │   ├── __init__.py
│   │   ├── materials.xml
│   │   ├── materials_white_floor.xml
│   │   ├── skybox.xml
│   │   └── visual.xml
│   ├── demos
│   │   ├── mocap_demo.py
│   │   └── zeros.amc
│   ├── explore.py
│   ├── finger.py
│   ├── finger.xml
│   ├── fish.py
│   ├── fish.xml
│   ├── hopper.py
│   ├── hopper.xml
│   ├── humanoid.py
│   ├── humanoid.xml
│   ├── humanoid_CMU.py
│   ├── humanoid_CMU.xml
│   ├── lqr.py
│   ├── lqr.xml
│   ├── lqr_solver.py
│   ├── manipulator.py
│   ├── manipulator.xml
│   ├── pendulum.py
│   ├── pendulum.xml
│   ├── point_mass.py
│   ├── point_mass.xml
│   ├── quadruped.py
│   ├── quadruped.xml
│   ├── reacher.py
│   ├── reacher.xml
│   ├── stacker.py
│   ├── stacker.xml
│   ├── swimmer.py
│   ├── swimmer.xml
│   ├── tests
│   │   ├── domains_test.py
│   │   ├── loader_test.py
│   │   └── lqr_test.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── parse_amc.py
│   │   ├── parse_amc_test.py
│   │   ├── randomizers.py
│   │   └── randomizers_test.py
│   ├── walker.py
│   ├── walker.xml
│   └── wrappers
│       ├── __init__.py
│       ├── action_noise.py
│       ├── action_noise_test.py
│       ├── pixels.py
│       └── pixels_test.py
├── local_dm_control_suite_off_center
│   ├── README.md
│   ├── __init__.py
│   ├── acrobot.py
│   ├── acrobot.xml
│   ├── ball_in_cup.py
│   ├── ball_in_cup.xml
│   ├── base.py
│   ├── cartpole.py
│   ├── cartpole.xml
│   ├── cheetah.py
│   ├── cheetah.xml
│   ├── common
│   │   ├── __init__.py
│   │   ├── materials.xml
│   │   ├── materials_white_floor.xml
│   │   ├── skybox.xml
│   │   └── visual.xml
│   ├── demos
│   │   ├── mocap_demo.py
│   │   └── zeros.amc
│   ├── explore.py
│   ├── finger.py
│   ├── finger.xml
│   ├── fish.py
│   ├── fish.xml
│   ├── hopper.py
│   ├── hopper.xml
│   ├── humanoid.py
│   ├── humanoid.xml
│   ├── humanoid_CMU.py
│   ├── humanoid_CMU.xml
│   ├── lqr.py
│   ├── lqr.xml
│   ├── lqr_solver.py
│   ├── manipulator.py
│   ├── manipulator.xml
│   ├── pendulum.py
│   ├── pendulum.xml
│   ├── point_mass.py
│   ├── point_mass.xml
│   ├── quadruped.py
│   ├── quadruped.xml
│   ├── reacher.py
│   ├── reacher.xml
│   ├── stacker.py
│   ├── stacker.xml
│   ├── swimmer.py
│   ├── swimmer.xml
│   ├── tests
│   │   ├── domains_test.py
│   │   ├── loader_test.py
│   │   └── lqr_test.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── parse_amc.py
│   │   ├── parse_amc_test.py
│   │   ├── randomizers.py
│   │   └── randomizers_test.py
│   ├── walker.py
│   ├── walker.xml
│   └── wrappers
│       ├── __init__.py
│       ├── action_noise.py
│       ├── action_noise_test.py
│       ├── pixels.py
│       └── pixels_test.py
├── logger.py
├── multistep_dynamics.py
├── multistep_replay.py
├── multistep_utils.py
├── requirements.txt
├── scripts
│   └── run.sh
├── train.py
├── utils.py
└── video.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | **/tb/**
3 | **/__pycache__/
4 | **/tmp/**
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Utkarsh Mishra
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/agents/agent_sac.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from numpy.core.numeric import tensordot
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | import utils
8 | import data_augs as rad
9 | from agents.auxiliary_funcs import BaseSacAgent
10 | from multistep_dynamics import MultiStepDynamicsModel
11 | import multistep_utils as mutils
12 |
13 | class PixelSacAgent(BaseSacAgent):
14 | """Learning Representations of Pixel Observations with SAC + Self-Supervised Techniques.."""
15 | def __init__(
16 | self,
17 | obs_shape,
18 | action_shape,
19 | horizon,
20 | device,
21 | hidden_dim=256,
22 | discount=0.99,
23 | init_temperature=0.01,
24 | alpha_lr=1e-3,
25 | alpha_beta=0.9,
26 | actor_lr=1e-3,
27 | actor_beta=0.9,
28 | actor_log_std_min=-10,
29 | actor_log_std_max=2,
30 | actor_update_freq=2,
31 | critic_lr=1e-3,
32 | critic_beta=0.9,
33 | critic_tau=0.005,
34 | critic_target_update_freq=2,
35 | encoder_type='pixel',
36 | encoder_feature_dim=50,
37 | encoder_lr=1e-3,
38 | encoder_tau=0.005,
39 | decoder_lr=1e-3,
40 | decoder_weight_lambda=0.0,
41 | num_layers=4,
42 | num_filters=32,
43 | cpc_update_freq=1,
44 | log_interval=100,
45 | detach_encoder=False,
46 | latent_dim=128,
47 | data_augs = '',
48 | use_metric_loss=False
49 | ):
50 |
51 | print('########################################################################')
52 | print('################### Starting Case 0: Baseline Agent ####################')
53 | print('###################### Reward: No; Transition: No ######################')
54 | print('########################################################################')
55 |
56 | super().__init__(
57 | obs_shape,
58 | action_shape,
59 | horizon,
60 | device,
61 | hidden_dim,
62 | discount,
63 | init_temperature,
64 | alpha_lr,
65 | alpha_beta,
66 | actor_lr,
67 | actor_beta,
68 | actor_log_std_min,
69 | actor_log_std_max,
70 | actor_update_freq,
71 | critic_lr,
72 | critic_beta,
73 | critic_tau,
74 | critic_target_update_freq,
75 | encoder_type,
76 | encoder_feature_dim,
77 | encoder_lr,
78 | encoder_tau,
79 | decoder_lr,
80 | decoder_weight_lambda,
81 | num_layers,
82 | num_filters,
83 | cpc_update_freq,
84 | log_interval,
85 | detach_encoder,
86 | latent_dim,
87 | data_augs,
88 | use_metric_loss
89 | )
90 |
91 | self.train()
92 | self.critic_target.train()
93 |
94 | def update(self, replay_buffer, L, step):
95 | if self.encoder_type == 'pixel':
96 |
97 | batch_obs, batch_action, batch_reward, batch_not_done = replay_buffer.sample_multistep()
98 |
99 | obs = batch_obs[0]
100 | action = batch_action[0]
101 | next_obs = batch_obs[1]
102 | reward = batch_reward[0].unsqueeze(-1)
103 | not_done = batch_not_done[0].unsqueeze(-1)
104 |
105 | self.update_critic(obs, action, reward, next_obs, not_done, L, step)
106 |
107 | encoded_batch_obs = []
108 |
109 | for en_iter in range(batch_obs.size(0)):
110 | encoded_obs = self.critic.encoder(batch_obs[en_iter])
111 | encoded_batch_obs.append(encoded_obs)
112 |
113 | encoded_batch_obs = torch.stack(encoded_batch_obs)
114 |
115 | else:
116 | obs, action, reward, next_obs, not_done = replay_buffer.sample_proprio()
117 | self.update_critic(obs, action, reward, next_obs, not_done, L, step)
118 |
119 | if step % self.log_interval == 0:
120 | L.log('train/batch_reward', reward.mean(), step)
121 |
122 | if step % self.actor_update_freq == 0:
123 | self.update_actor_and_alpha(obs, L, step)
124 |
125 | if step % self.critic_target_update_freq == 0:
126 | utils.soft_update_params(
127 | self.critic.Q1, self.critic_target.Q1, self.critic_tau
128 | )
129 | utils.soft_update_params(
130 | self.critic.Q2, self.critic_target.Q2, self.critic_tau
131 | )
132 | utils.soft_update_params(
133 | self.critic.encoder, self.critic_target.encoder,
134 | self.encoder_tau
135 | )
--------------------------------------------------------------------------------
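
A minimal, self-contained sketch (not part of the repository) of the multistep batch layout that `PixelSacAgent.update()` appears to assume: `replay_buffer.sample_multistep()` is taken to return tensors whose leading dimension indexes time, so `batch_obs[0]` and `batch_obs[1]` form the `(obs, next_obs)` pair of the first transition. All shapes below are illustrative placeholders, not values read from the repository.

```python
import torch

# Hypothetical shapes: T time steps, batch B, 3 stacked RGB frames of 84x84 pixels.
T, B, C, H, W = 3, 8, 9, 84, 84
batch_obs = torch.randn(T, B, C, H, W)        # stand-in for sample_multistep() observations
batch_action = torch.randn(T - 1, B, 6)       # action dimension 6 is a placeholder
batch_reward = torch.randn(T - 1, B)
batch_not_done = torch.ones(T - 1, B)

# Mirrors the slicing done at the top of PixelSacAgent.update():
obs, next_obs = batch_obs[0], batch_obs[1]
action = batch_action[0]
reward = batch_reward[0].unsqueeze(-1)        # shape (B, 1)
not_done = batch_not_done[0].unsqueeze(-1)    # shape (B, 1)
print(obs.shape, next_obs.shape, reward.shape)
```
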
/distractors/driving/1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UtkarshMishra04/pixel-representations-RL/8f457adcf41eb3b8975eaa5a752736c07b908c63/distractors/driving/1.mp4
--------------------------------------------------------------------------------
/dmc2gym/__init__.py:
--------------------------------------------------------------------------------
1 | import gym
2 | from gym.envs.registration import register
3 |
4 |
5 | def make(
6 | domain_name,
7 | task_name,
8 | resource_files,
9 | img_source,
10 | total_frames,
11 | seed=1,
12 | visualize_reward=True,
13 | from_pixels=False,
14 | height=84,
15 | width=84,
16 | camera_id=0,
17 | frame_skip=1,
18 | episode_length=1000,
19 | off_center=False,
20 | environment_kwargs=None
21 | ):
22 | env_id = 'dmc_%s_%s_%s-v1' % (domain_name, task_name, seed)
23 |
24 | if from_pixels:
25 | assert not visualize_reward, 'cannot use visualize reward when learning from pixels'
26 |
27 | # shorten episode length
28 | max_episode_steps = (episode_length + frame_skip - 1) // frame_skip
29 |
30 | if not env_id in gym.envs.registry.env_specs:
31 | register(
32 | id=env_id,
33 | entry_point='dmc2gym.wrappers:DMCWrapper',
34 | kwargs={
35 | 'domain_name': domain_name,
36 | 'task_name': task_name,
37 | 'resource_files': resource_files,
38 | 'img_source': img_source,
39 | 'total_frames': total_frames,
40 | 'task_kwargs': {
41 | 'random': seed
42 | },
43 | 'environment_kwargs': environment_kwargs,
44 | 'visualize_reward': visualize_reward,
45 | 'from_pixels': from_pixels,
46 | 'height': height,
47 | 'width': width,
48 | 'camera_id': camera_id,
49 | 'frame_skip': frame_skip,
50 | 'off_center': off_center,
51 | },
52 | max_episode_steps=max_episode_steps
53 | )
54 | return gym.make(env_id)
55 |
--------------------------------------------------------------------------------
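
For reference, a hedged usage sketch of the `dmc2gym.make()` registration above. The argument values are illustrative only: the domain/task pair, frame skip, and `img_source='video'` with the bundled distractor clip are assumptions, not values taken from the repository's training scripts.

```python
import dmc2gym

# Illustrative call; visualize_reward must be False when learning from pixels (see the assert in make()).
env = dmc2gym.make(
    domain_name='cheetah',
    task_name='run',
    resource_files='distractors/driving/*.mp4',  # assumed glob for the bundled video distractor
    img_source='video',                          # assumed source type accepted by the wrapper
    total_frames=1000,
    seed=1,
    visualize_reward=False,
    from_pixels=True,
    height=84,
    width=84,
    frame_skip=4,
)

obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())
```
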
/encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def tie_weights(src, trg):
6 | assert type(src) == type(trg)
7 | trg.weight = src.weight
8 | trg.bias = src.bias
9 |
10 |
11 | ### Works with Finism SAC
12 | # OUT_DIM = {2: 39, 4: 35, 6: 31}
13 | # OUT_DIM_84 = {2: 39, 4: 35, 6: 31}
14 | # OUT_DIM_100 = {2: 39, 4: 43, 6: 31}
15 |
16 | ### Works with RAD SAC
17 | OUT_DIM = {2: 39, 4: 35, 6: 31}
18 | OUT_DIM_84 = {2: 29, 4: 35, 6: 21}
19 | OUT_DIM_100 = {4: 47}
20 |
21 |
22 | class PixelEncoder(nn.Module):
23 | """Convolutional encoder of pixels observations."""
24 | def __init__(self, obs_shape, feature_dim, num_layers=2, num_filters=32,output_logits=False):
25 | super().__init__()
26 |
27 | assert len(obs_shape) == 3
28 | self.obs_shape = obs_shape
29 | self.feature_dim = feature_dim
30 | self.num_layers = num_layers
31 |         # try 2 5x5s with strides 2x2. with same padding, it should reduce 84 to 21, so with valid, it should be even smaller than 21.
32 | self.convs = nn.ModuleList(
33 | [nn.Conv2d(obs_shape[0], num_filters, 3, stride=2)]
34 | )
35 | for i in range(num_layers - 1):
36 | self.convs.append(nn.Conv2d(num_filters, num_filters, 3, stride=1))
37 |
38 | if obs_shape[-1] == 100:
39 | assert num_layers in OUT_DIM_100
40 | out_dim = OUT_DIM_100[num_layers]
41 | elif obs_shape[-1] == 84:
42 | out_dim = OUT_DIM_84[num_layers]
43 | else:
44 | out_dim = OUT_DIM[num_layers]
45 |
46 | self.fc = nn.Linear(num_filters * out_dim * out_dim, self.feature_dim)
47 | self.ln = nn.LayerNorm(self.feature_dim)
48 |
49 | self.outputs = dict()
50 | self.output_logits = output_logits
51 |
52 | def reparameterize(self, mu, logstd):
53 | std = torch.exp(logstd)
54 | eps = torch.randn_like(std)
55 | return mu + eps * std
56 |
57 | def forward_conv(self, obs):
58 | if obs.max() > 1.:
59 | obs = obs / 255.
60 |
61 | self.outputs['obs'] = obs
62 |
63 | conv = torch.relu(self.convs[0](obs))
64 | self.outputs['conv1'] = conv
65 |
66 | for i in range(1, self.num_layers):
67 | conv = torch.relu(self.convs[i](conv))
68 | self.outputs['conv%s' % (i + 1)] = conv
69 |
70 | h = conv.view(conv.size(0), -1)
71 |
72 | return h
73 |
74 | def forward(self, obs, detach=False):
75 | h = self.forward_conv(obs)
76 |
77 | if detach:
78 | h = h.detach()
79 |
80 | h_fc = self.fc(h)
81 | self.outputs['fc'] = h_fc
82 |
83 | h_norm = self.ln(h_fc)
84 | self.outputs['ln'] = h_norm
85 |
86 | if self.output_logits:
87 | out = h_norm
88 | else:
89 | out = torch.tanh(h_norm)
90 | self.outputs['tanh'] = out
91 |
92 | return out
93 |
94 | def copy_conv_weights_from(self, source):
95 | """Tie convolutional layers"""
96 | # only tie conv layers
97 | for i in range(self.num_layers):
98 | tie_weights(src=source.convs[i], trg=self.convs[i])
99 |
100 | def log(self, L, step, log_freq):
101 | if step % log_freq != 0:
102 | return
103 |
104 | for k, v in self.outputs.items():
105 | L.log_histogram('train_encoder/%s_hist' % k, v, step)
106 | if len(v.shape) > 2:
107 | L.log_image('train_encoder/%s_img' % k, v[0], step)
108 |
109 | for i in range(self.num_layers):
110 | L.log_param('train_encoder/conv%s' % (i + 1), self.convs[i], step)
111 | L.log_param('train_encoder/fc', self.fc, step)
112 | L.log_param('train_encoder/ln', self.ln, step)
113 |
114 |
115 | class IdentityEncoder(nn.Module):
116 | def __init__(self, obs_shape, feature_dim, num_layers, num_filters,*args):
117 | super().__init__()
118 |
119 | assert len(obs_shape) == 1
120 | self.feature_dim = obs_shape[0]
121 |
122 | def forward(self, obs, detach=False):
123 | return obs
124 |
125 | def copy_conv_weights_from(self, source):
126 | pass
127 |
128 | def log(self, L, step, log_freq):
129 | pass
130 |
131 |
132 | _AVAILABLE_ENCODERS = {'pixel': PixelEncoder, 'identity': IdentityEncoder}
133 |
134 |
135 | def make_encoder(
136 | encoder_type, obs_shape, feature_dim, num_layers, num_filters, output_logits=False
137 | ):
138 | assert encoder_type in _AVAILABLE_ENCODERS
139 | return _AVAILABLE_ENCODERS[encoder_type](
140 | obs_shape, feature_dim, num_layers, num_filters, output_logits
141 | )
142 |
--------------------------------------------------------------------------------
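
A small sanity-check sketch of `make_encoder` from the file above. The observation shape (9 stacked channels at 84x84) and the 50-dimensional feature size are example values only; with 4 layers, an 84x84 input uses `OUT_DIM_84[4] = 35` for the flattened conv output.

```python
import torch
from encoder import make_encoder

# 9 channels = 3 stacked RGB frames; 4 conv layers on 84x84 yield a 35x35 feature map.
enc = make_encoder('pixel', obs_shape=(9, 84, 84), feature_dim=50,
                   num_layers=4, num_filters=32, output_logits=False)

obs = torch.randint(0, 256, (2, 9, 84, 84), dtype=torch.float32)  # fake pixel batch
z = enc(obs)        # forward() rescales to [0, 1], applies the conv stack, fc, LayerNorm, tanh
print(z.shape)      # torch.Size([2, 50])
```
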
/local_dm_control_suite/README.md:
--------------------------------------------------------------------------------
1 | # DeepMind Control Suite.
2 |
3 | This submodule contains the domains and tasks described in the
4 | [DeepMind Control Suite tech report](https://arxiv.org/abs/1801.00690).
5 |
6 | ## Quickstart
7 |
8 | ```python
9 | from dm_control import suite
10 | import numpy as np
11 |
12 | # Load one task:
13 | env = suite.load(domain_name="cartpole", task_name="swingup")
14 |
15 | # Iterate over a task set:
16 | for domain_name, task_name in suite.BENCHMARKING:
17 | env = suite.load(domain_name, task_name)
18 |
19 | # Step through an episode and print out reward, discount and observation.
20 | action_spec = env.action_spec()
21 | time_step = env.reset()
22 | while not time_step.last():
23 | action = np.random.uniform(action_spec.minimum,
24 | action_spec.maximum,
25 | size=action_spec.shape)
26 | time_step = env.step(action)
27 | print(time_step.reward, time_step.discount, time_step.observation)
28 | ```
29 |
30 | ## Illustration video
31 |
32 | Below is a video montage of solved Control Suite tasks, with reward
33 | visualisation enabled.
34 |
35 | [Video montage of solved Control Suite tasks](https://www.youtube.com/watch?v=rAai4QzcYbs)
36 |
37 |
38 | ### Quadruped domain [April 2019]
39 |
40 | Roughly based on the 'ant' model introduced by [Schulman et al. 2015](https://arxiv.org/abs/1506.02438). Main modifications to the body are:
41 |
42 | - 4 DoFs per leg, 1 constraining tendon.
43 | - 3 actuators per leg: 'yaw', 'lift', 'extend'.
44 | - Filtered position actuators with timescale of 100ms.
45 | - Sensors include an IMU, force/torque sensors, and rangefinders.
46 |
47 | Four tasks:
48 |
49 | - `walk` and `run`: self-right the body then move forward at a desired speed.
50 | - `escape`: escape a bowl-shaped random terrain (uses rangefinders).
51 | - `fetch`: go to a moving ball and bring it to a target.
52 |
53 | All behaviors in the video below were trained with [Abdolmaleki et al's
54 | MPO](https://arxiv.org/abs/1806.06920).
55 |
56 | [Video of quadruped behaviours trained with MPO](https://www.youtube.com/watch?v=RhRLjbb7pBE)
57 |
--------------------------------------------------------------------------------
/local_dm_control_suite/acrobot.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Acrobot domain."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import collections
23 |
24 | from dm_control import mujoco
25 | from dm_control.rl import control
26 | from local_dm_control_suite import base
27 | from local_dm_control_suite import common
28 | from dm_control.utils import containers
29 | from dm_control.utils import rewards
30 | import numpy as np
31 |
32 | _DEFAULT_TIME_LIMIT = 10
33 | SUITE = containers.TaggedTasks()
34 |
35 |
36 | def get_model_and_assets():
37 | """Returns a tuple containing the model XML string and a dict of assets."""
38 | return common.read_model('acrobot.xml'), common.ASSETS
39 |
40 |
41 | @SUITE.add('benchmarking')
42 | def swingup(time_limit=_DEFAULT_TIME_LIMIT, random=None,
43 | environment_kwargs=None):
44 | """Returns Acrobot balance task."""
45 | physics = Physics.from_xml_string(*get_model_and_assets())
46 | task = Balance(sparse=False, random=random)
47 | environment_kwargs = environment_kwargs or {}
48 | return control.Environment(
49 | physics, task, time_limit=time_limit, **environment_kwargs)
50 |
51 |
52 | @SUITE.add('benchmarking')
53 | def swingup_sparse(time_limit=_DEFAULT_TIME_LIMIT, random=None,
54 | environment_kwargs=None):
55 | """Returns Acrobot sparse balance."""
56 | physics = Physics.from_xml_string(*get_model_and_assets())
57 | task = Balance(sparse=True, random=random)
58 | environment_kwargs = environment_kwargs or {}
59 | return control.Environment(
60 | physics, task, time_limit=time_limit, **environment_kwargs)
61 |
62 |
63 | class Physics(mujoco.Physics):
64 | """Physics simulation with additional features for the Acrobot domain."""
65 |
66 | def horizontal(self):
67 | """Returns horizontal (x) component of body frame z-axes."""
68 | return self.named.data.xmat[['upper_arm', 'lower_arm'], 'xz']
69 |
70 | def vertical(self):
71 | """Returns vertical (z) component of body frame z-axes."""
72 | return self.named.data.xmat[['upper_arm', 'lower_arm'], 'zz']
73 |
74 | def to_target(self):
75 | """Returns the distance from the tip to the target."""
76 | tip_to_target = (self.named.data.site_xpos['target'] -
77 | self.named.data.site_xpos['tip'])
78 | return np.linalg.norm(tip_to_target)
79 |
80 | def orientations(self):
81 | """Returns the sines and cosines of the pole angles."""
82 | return np.concatenate((self.horizontal(), self.vertical()))
83 |
84 |
85 | class Balance(base.Task):
86 | """An Acrobot `Task` to swing up and balance the pole."""
87 |
88 | def __init__(self, sparse, random=None):
89 | """Initializes an instance of `Balance`.
90 |
91 | Args:
92 | sparse: A `bool` specifying whether to use a sparse (indicator) reward.
93 | random: Optional, either a `numpy.random.RandomState` instance, an
94 | integer seed for creating a new `RandomState`, or None to select a seed
95 | automatically (default).
96 | """
97 | self._sparse = sparse
98 | super(Balance, self).__init__(random=random)
99 |
100 | def initialize_episode(self, physics):
101 | """Sets the state of the environment at the start of each episode.
102 |
103 | Shoulder and elbow are set to a random position between [-pi, pi).
104 |
105 | Args:
106 | physics: An instance of `Physics`.
107 | """
108 | physics.named.data.qpos[
109 | ['shoulder', 'elbow']] = self.random.uniform(-np.pi, np.pi, 2)
110 | super(Balance, self).initialize_episode(physics)
111 |
112 | def get_observation(self, physics):
113 | """Returns an observation of pole orientation and angular velocities."""
114 | obs = collections.OrderedDict()
115 | obs['orientations'] = physics.orientations()
116 | obs['velocity'] = physics.velocity()
117 | return obs
118 |
119 | def _get_reward(self, physics, sparse):
120 | target_radius = physics.named.model.site_size['target', 0]
121 | return rewards.tolerance(physics.to_target(),
122 | bounds=(0, target_radius),
123 | margin=0 if sparse else 1)
124 |
125 | def get_reward(self, physics):
126 | """Returns a sparse or a smooth reward, as specified in the constructor."""
127 | return self._get_reward(physics, sparse=self._sparse)
128 |
--------------------------------------------------------------------------------
/local_dm_control_suite/acrobot.xml:
--------------------------------------------------------------------------------
[MuJoCo model XML for the acrobot domain; markup not preserved in this dump]
--------------------------------------------------------------------------------
/local_dm_control_suite/ball_in_cup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Ball-in-Cup Domain."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import collections
23 |
24 | from dm_control import mujoco
25 | from dm_control.rl import control
26 | from local_dm_control_suite import base
27 | from local_dm_control_suite import common
28 | from dm_control.utils import containers
29 |
30 | _DEFAULT_TIME_LIMIT = 20 # (seconds)
31 | _CONTROL_TIMESTEP = .02 # (seconds)
32 |
33 |
34 | SUITE = containers.TaggedTasks()
35 |
36 |
37 | def get_model_and_assets():
38 | """Returns a tuple containing the model XML string and a dict of assets."""
39 | return common.read_model('ball_in_cup.xml'), common.ASSETS
40 |
41 |
42 | @SUITE.add('benchmarking', 'easy')
43 | def catch(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
44 | """Returns the Ball-in-Cup task."""
45 | physics = Physics.from_xml_string(*get_model_and_assets())
46 | task = BallInCup(random=random)
47 | environment_kwargs = environment_kwargs or {}
48 | return control.Environment(
49 | physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
50 | **environment_kwargs)
51 |
52 |
53 | class Physics(mujoco.Physics):
54 | """Physics with additional features for the Ball-in-Cup domain."""
55 |
56 | def ball_to_target(self):
57 | """Returns the vector from the ball to the target."""
58 | target = self.named.data.site_xpos['target', ['x', 'z']]
59 | ball = self.named.data.xpos['ball', ['x', 'z']]
60 | return target - ball
61 |
62 | def in_target(self):
63 | """Returns 1 if the ball is in the target, 0 otherwise."""
64 | ball_to_target = abs(self.ball_to_target())
65 | target_size = self.named.model.site_size['target', [0, 2]]
66 | ball_size = self.named.model.geom_size['ball', 0]
67 | return float(all(ball_to_target < target_size - ball_size))
68 |
69 |
70 | class BallInCup(base.Task):
71 | """The Ball-in-Cup task. Put the ball in the cup."""
72 |
73 | def initialize_episode(self, physics):
74 | """Sets the state of the environment at the start of each episode.
75 |
76 | Args:
77 | physics: An instance of `Physics`.
78 |
79 | """
80 | # Find a collision-free random initial position of the ball.
81 | penetrating = True
82 | while penetrating:
83 | # Assign a random ball position.
84 | physics.named.data.qpos['ball_x'] = self.random.uniform(-.2, .2)
85 | physics.named.data.qpos['ball_z'] = self.random.uniform(.2, .5)
86 | # Check for collisions.
87 | physics.after_reset()
88 | penetrating = physics.data.ncon > 0
89 | super(BallInCup, self).initialize_episode(physics)
90 |
91 | def get_observation(self, physics):
92 | """Returns an observation of the state."""
93 | obs = collections.OrderedDict()
94 | obs['position'] = physics.position()
95 | obs['velocity'] = physics.velocity()
96 | return obs
97 |
98 | def get_reward(self, physics):
99 | """Returns a sparse reward."""
100 | return physics.in_target()
101 |
--------------------------------------------------------------------------------
/local_dm_control_suite/ball_in_cup.xml:
--------------------------------------------------------------------------------
[MuJoCo model XML for the ball_in_cup domain; markup not preserved in this dump]
--------------------------------------------------------------------------------
/local_dm_control_suite/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Base class for tasks in the Control Suite."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from dm_control import mujoco
23 | from dm_control.rl import control
24 |
25 | import numpy as np
26 |
27 |
28 | class Task(control.Task):
29 | """Base class for tasks in the Control Suite.
30 |
31 | Actions are mapped directly to the states of MuJoCo actuators: each element of
32 | the action array is used to set the control input for a single actuator. The
33 | ordering of the actuators is the same as in the corresponding MJCF XML file.
34 |
35 | Attributes:
36 | random: A `numpy.random.RandomState` instance. This should be used to
37 | generate all random variables associated with the task, such as random
38 | starting states, observation noise* etc.
39 |
40 | *If sensor noise is enabled in the MuJoCo model then this will be generated
41 | using MuJoCo's internal RNG, which has its own independent state.
42 | """
43 |
44 | def __init__(self, random=None):
45 | """Initializes a new continuous control task.
46 |
47 | Args:
48 | random: Optional, either a `numpy.random.RandomState` instance, an integer
49 | seed for creating a new `RandomState`, or None to select a seed
50 | automatically (default).
51 | """
52 | if not isinstance(random, np.random.RandomState):
53 | random = np.random.RandomState(random)
54 | self._random = random
55 | self._visualize_reward = False
56 |
57 | @property
58 | def random(self):
59 | """Task-specific `numpy.random.RandomState` instance."""
60 | return self._random
61 |
62 | def action_spec(self, physics):
63 | """Returns a `BoundedArraySpec` matching the `physics` actuators."""
64 | return mujoco.action_spec(physics)
65 |
66 | def initialize_episode(self, physics):
67 | """Resets geom colors to their defaults after starting a new episode.
68 |
69 | Subclasses of `base.Task` must delegate to this method after performing
70 | their own initialization.
71 |
72 | Args:
73 | physics: An instance of `mujoco.Physics`.
74 | """
75 | self.after_step(physics)
76 |
77 | def before_step(self, action, physics):
78 | """Sets the control signal for the actuators to values in `action`."""
79 | # Support legacy internal code.
80 | action = getattr(action, "continuous_actions", action)
81 | physics.set_control(action)
82 |
83 | def after_step(self, physics):
84 | """Modifies colors according to the reward."""
85 | if self._visualize_reward:
86 | reward = np.clip(self.get_reward(physics), 0.0, 1.0)
87 | _set_reward_colors(physics, reward)
88 |
89 | @property
90 | def visualize_reward(self):
91 | return self._visualize_reward
92 |
93 | @visualize_reward.setter
94 | def visualize_reward(self, value):
95 | if not isinstance(value, bool):
96 | raise ValueError("Expected a boolean, got {}.".format(type(value)))
97 | self._visualize_reward = value
98 |
99 |
100 | _MATERIALS = ["self", "effector", "target"]
101 | _DEFAULT = [name + "_default" for name in _MATERIALS]
102 | _HIGHLIGHT = [name + "_highlight" for name in _MATERIALS]
103 |
104 |
105 | def _set_reward_colors(physics, reward):
106 | """Sets the highlight, effector and target colors according to the reward."""
107 | assert 0.0 <= reward <= 1.0
108 | colors = physics.named.model.mat_rgba
109 | default = colors[_DEFAULT]
110 | highlight = colors[_HIGHLIGHT]
111 | blend_coef = reward ** 4 # Better color distinction near high rewards.
112 | colors[_MATERIALS] = blend_coef * highlight + (1.0 - blend_coef) * default
113 |
--------------------------------------------------------------------------------
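
A quick numeric illustration (not repository code) of the blend used by `_set_reward_colors` above: raising the reward to the fourth power keeps geoms close to their default colour until the reward approaches 1. The RGBA values below are placeholders.

```python
import numpy as np

default = np.array([0.7, 0.7, 0.7, 1.0])     # example default rgba
highlight = np.array([0.0, 1.0, 0.0, 1.0])   # example highlight rgba

for reward in (0.25, 0.5, 0.9, 1.0):
    blend_coef = reward ** 4                 # same exponent as in _set_reward_colors
    rgba = blend_coef * highlight + (1.0 - blend_coef) * default
    print(reward, np.round(rgba, 3))
```
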
/local_dm_control_suite/cartpole.xml:
--------------------------------------------------------------------------------
[MuJoCo model XML for the cartpole domain; markup not preserved in this dump]
--------------------------------------------------------------------------------
/local_dm_control_suite/cheetah.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Cheetah Domain."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import collections
23 |
24 | from dm_control import mujoco
25 | from dm_control.rl import control
26 | from local_dm_control_suite import base
27 | from local_dm_control_suite import common
28 | from dm_control.utils import containers
29 | from dm_control.utils import rewards
30 |
31 |
32 | # How long the simulation will run, in seconds.
33 | _DEFAULT_TIME_LIMIT = 10
34 |
35 | # Running speed above which reward is 1.
36 | _RUN_SPEED = 10
37 |
38 | SUITE = containers.TaggedTasks()
39 |
40 |
41 | def get_model_and_assets():
42 | """Returns a tuple containing the model XML string and a dict of assets."""
43 | return common.read_model('cheetah.xml'), common.ASSETS
44 |
45 |
46 | @SUITE.add('benchmarking')
47 | def run(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
48 | """Returns the run task."""
49 | physics = Physics.from_xml_string(*get_model_and_assets())
50 | task = Cheetah(random=random)
51 | environment_kwargs = environment_kwargs or {}
52 | return control.Environment(physics, task, time_limit=time_limit,
53 | **environment_kwargs)
54 |
55 |
56 | class Physics(mujoco.Physics):
57 | """Physics simulation with additional features for the Cheetah domain."""
58 |
59 | def speed(self):
60 | """Returns the horizontal speed of the Cheetah."""
61 | return self.named.data.sensordata['torso_subtreelinvel'][0]
62 |
63 |
64 | class Cheetah(base.Task):
65 | """A `Task` to train a running Cheetah."""
66 |
67 | def initialize_episode(self, physics):
68 | """Sets the state of the environment at the start of each episode."""
69 | # The indexing below assumes that all joints have a single DOF.
70 | assert physics.model.nq == physics.model.njnt
71 | is_limited = physics.model.jnt_limited == 1
72 | lower, upper = physics.model.jnt_range[is_limited].T
73 | physics.data.qpos[is_limited] = self.random.uniform(lower, upper)
74 |
75 | # Stabilize the model before the actual simulation.
76 | for _ in range(200):
77 | physics.step()
78 |
79 | physics.data.time = 0
80 | self._timeout_progress = 0
81 | super(Cheetah, self).initialize_episode(physics)
82 |
83 | def get_observation(self, physics):
84 | """Returns an observation of the state, ignoring horizontal position."""
85 | obs = collections.OrderedDict()
86 | # Ignores horizontal position to maintain translational invariance.
87 | obs['position'] = physics.data.qpos[1:].copy()
88 | obs['velocity'] = physics.velocity()
89 | return obs
90 |
91 | def get_reward(self, physics):
92 | """Returns a reward to the agent."""
93 | return rewards.tolerance(physics.speed(),
94 | bounds=(_RUN_SPEED, float('inf')),
95 | margin=_RUN_SPEED,
96 | value_at_margin=0,
97 | sigmoid='linear')
98 |
--------------------------------------------------------------------------------
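
A brief illustration (assumes `dm_control` is installed, as the suite files above require) of the `rewards.tolerance` call in `Cheetah.get_reward`: with a linear sigmoid, zero value at the margin, and bounds starting at `_RUN_SPEED = 10`, the reward should ramp roughly linearly from 0 at standstill to 1 once the speed reaches 10.

```python
from dm_control.utils import rewards

for speed in (0.0, 2.5, 5.0, 10.0, 12.0):
    r = rewards.tolerance(speed,
                          bounds=(10, float('inf')),
                          margin=10,
                          value_at_margin=0,
                          sigmoid='linear')
    print(speed, r)   # roughly 0.0, 0.25, 0.5, 1.0, 1.0 under the linear sigmoid
```
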
/local_dm_control_suite/cheetah.xml:
--------------------------------------------------------------------------------
[MuJoCo model XML for the cheetah domain; markup not preserved in this dump]
--------------------------------------------------------------------------------
/local_dm_control_suite/common/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Functions to manage the common assets for domains."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import os
23 | from dm_control.utils import io as resources
24 |
25 | _SUITE_DIR = os.path.dirname(os.path.dirname(__file__))
26 | _FILENAMES = [
27 | "./common/materials.xml",
28 | "./common/materials_white_floor.xml",
29 | "./common/skybox.xml",
30 | "./common/visual.xml",
31 | ]
32 |
33 | ASSETS = {filename: resources.GetResource(os.path.join(_SUITE_DIR, filename))
34 | for filename in _FILENAMES}
35 |
36 |
37 | def read_model(model_filename):
38 | """Reads a model XML file and returns its contents as a string."""
39 | return resources.GetResource(os.path.join(_SUITE_DIR, model_filename))
40 |
--------------------------------------------------------------------------------
/local_dm_control_suite/common/materials.xml:
--------------------------------------------------------------------------------
[Shared material/texture asset definitions; markup not preserved in this dump]
--------------------------------------------------------------------------------
/local_dm_control_suite/common/materials_white_floor.xml:
--------------------------------------------------------------------------------
[Shared material asset definitions (white-floor variant); markup not preserved in this dump]
--------------------------------------------------------------------------------
/local_dm_control_suite/common/skybox.xml:
--------------------------------------------------------------------------------
[Skybox texture asset definition; markup not preserved in this dump]
--------------------------------------------------------------------------------
/local_dm_control_suite/common/visual.xml:
--------------------------------------------------------------------------------
[Shared visual/rendering settings; markup not preserved in this dump]
--------------------------------------------------------------------------------
/local_dm_control_suite/demos/mocap_demo.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Demonstration of amc parsing for CMU mocap database.
17 |
18 | To run the demo, supply a path to a `.amc` file:
19 |
20 | python mocap_demo --filename='path/to/mocap.amc'
21 |
22 | CMU motion capture clips are available at mocap.cs.cmu.edu
23 | """
24 |
25 | from __future__ import absolute_import
26 | from __future__ import division
27 | from __future__ import print_function
28 |
29 | import time
30 | # Internal dependencies.
31 |
32 | from absl import app
33 | from absl import flags
34 |
35 | from local_dm_control_suite import humanoid_CMU
36 | from dm_control.suite.utils import parse_amc
37 |
38 | import matplotlib.pyplot as plt
39 | import numpy as np
40 |
41 | FLAGS = flags.FLAGS
42 | flags.DEFINE_string('filename', None, 'amc file to be converted.')
43 | flags.DEFINE_integer('max_num_frames', 90,
44 | 'Maximum number of frames for plotting/playback')
45 |
46 |
47 | def main(unused_argv):
48 | env = humanoid_CMU.stand()
49 |
50 | # Parse and convert specified clip.
51 | converted = parse_amc.convert(FLAGS.filename,
52 | env.physics, env.control_timestep())
53 |
54 | max_frame = min(FLAGS.max_num_frames, converted.qpos.shape[1] - 1)
55 |
56 | width = 480
57 | height = 480
58 | video = np.zeros((max_frame, height, 2 * width, 3), dtype=np.uint8)
59 |
60 | for i in range(max_frame):
61 | p_i = converted.qpos[:, i]
62 | with env.physics.reset_context():
63 | env.physics.data.qpos[:] = p_i
64 | video[i] = np.hstack([env.physics.render(height, width, camera_id=0),
65 | env.physics.render(height, width, camera_id=1)])
66 |
67 | tic = time.time()
68 | for i in range(max_frame):
69 | if i == 0:
70 | img = plt.imshow(video[i])
71 | else:
72 | img.set_data(video[i])
73 | toc = time.time()
74 | clock_dt = toc - tic
75 | tic = time.time()
76 | # Real-time playback not always possible as clock_dt > .03
77 | plt.pause(max(0.01, 0.03 - clock_dt)) # Need min display time > 0.0.
78 | plt.draw()
79 | plt.waitforbuttonpress()
80 |
81 |
82 | if __name__ == '__main__':
83 | flags.mark_flag_as_required('filename')
84 | app.run(main)
85 |
--------------------------------------------------------------------------------
/local_dm_control_suite/demos/zeros.amc:
--------------------------------------------------------------------------------
1 | #DUMMY AMC for testing
2 | :FULLY-SPECIFIED
3 | :DEGREES
4 | 1
5 | root 0 0 0 0 0 0
6 | lowerback 0 0 0
7 | upperback 0 0 0
8 | thorax 0 0 0
9 | lowerneck 0 0 0
10 | upperneck 0 0 0
11 | head 0 0 0
12 | rclavicle 0 0
13 | rhumerus 0 0 0
14 | rradius 0
15 | rwrist 0
16 | rhand 0 0
17 | rfingers 0
18 | rthumb 0 0
19 | lclavicle 0 0
20 | lhumerus 0 0 0
21 | lradius 0
22 | lwrist 0
23 | lhand 0 0
24 | lfingers 0
25 | lthumb 0 0
26 | rfemur 0 0 0
27 | rtibia 0
28 | rfoot 0 0
29 | rtoes 0
30 | lfemur 0 0 0
31 | ltibia 0
32 | lfoot 0 0
33 | ltoes 0
34 | 2
35 | root 0 0 0 0 0 0
36 | lowerback 0 0 0
37 | upperback 0 0 0
38 | thorax 0 0 0
39 | lowerneck 0 0 0
40 | upperneck 0 0 0
41 | head 0 0 0
42 | rclavicle 0 0
43 | rhumerus 0 0 0
44 | rradius 0
45 | rwrist 0
46 | rhand 0 0
47 | rfingers 0
48 | rthumb 0 0
49 | lclavicle 0 0
50 | lhumerus 0 0 0
51 | lradius 0
52 | lwrist 0
53 | lhand 0 0
54 | lfingers 0
55 | lthumb 0 0
56 | rfemur 0 0 0
57 | rtibia 0
58 | rfoot 0 0
59 | rtoes 0
60 | lfemur 0 0 0
61 | ltibia 0
62 | lfoot 0 0
63 | ltoes 0
64 | 3
65 | root 0 0 0 0 0 0
66 | lowerback 0 0 0
67 | upperback 0 0 0
68 | thorax 0 0 0
69 | lowerneck 0 0 0
70 | upperneck 0 0 0
71 | head 0 0 0
72 | rclavicle 0 0
73 | rhumerus 0 0 0
74 | rradius 0
75 | rwrist 0
76 | rhand 0 0
77 | rfingers 0
78 | rthumb 0 0
79 | lclavicle 0 0
80 | lhumerus 0 0 0
81 | lradius 0
82 | lwrist 0
83 | lhand 0 0
84 | lfingers 0
85 | lthumb 0 0
86 | rfemur 0 0 0
87 | rtibia 0
88 | rfoot 0 0
89 | rtoes 0
90 | lfemur 0 0 0
91 | ltibia 0
92 | lfoot 0 0
93 | ltoes 0
94 | 4
95 | root 0 0 0 0 0 0
96 | lowerback 0 0 0
97 | upperback 0 0 0
98 | thorax 0 0 0
99 | lowerneck 0 0 0
100 | upperneck 0 0 0
101 | head 0 0 0
102 | rclavicle 0 0
103 | rhumerus 0 0 0
104 | rradius 0
105 | rwrist 0
106 | rhand 0 0
107 | rfingers 0
108 | rthumb 0 0
109 | lclavicle 0 0
110 | lhumerus 0 0 0
111 | lradius 0
112 | lwrist 0
113 | lhand 0 0
114 | lfingers 0
115 | lthumb 0 0
116 | rfemur 0 0 0
117 | rtibia 0
118 | rfoot 0 0
119 | rtoes 0
120 | lfemur 0 0 0
121 | ltibia 0
122 | lfoot 0 0
123 | ltoes 0
124 | 5
125 | root 0 0 0 0 0 0
126 | lowerback 0 0 0
127 | upperback 0 0 0
128 | thorax 0 0 0
129 | lowerneck 0 0 0
130 | upperneck 0 0 0
131 | head 0 0 0
132 | rclavicle 0 0
133 | rhumerus 0 0 0
134 | rradius 0
135 | rwrist 0
136 | rhand 0 0
137 | rfingers 0
138 | rthumb 0 0
139 | lclavicle 0 0
140 | lhumerus 0 0 0
141 | lradius 0
142 | lwrist 0
143 | lhand 0 0
144 | lfingers 0
145 | lthumb 0 0
146 | rfemur 0 0 0
147 | rtibia 0
148 | rfoot 0 0
149 | rtoes 0
150 | lfemur 0 0 0
151 | ltibia 0
152 | lfoot 0 0
153 | ltoes 0
154 | 6
155 | root 0 0 0 0 0 0
156 | lowerback 0 0 0
157 | upperback 0 0 0
158 | thorax 0 0 0
159 | lowerneck 0 0 0
160 | upperneck 0 0 0
161 | head 0 0 0
162 | rclavicle 0 0
163 | rhumerus 0 0 0
164 | rradius 0
165 | rwrist 0
166 | rhand 0 0
167 | rfingers 0
168 | rthumb 0 0
169 | lclavicle 0 0
170 | lhumerus 0 0 0
171 | lradius 0
172 | lwrist 0
173 | lhand 0 0
174 | lfingers 0
175 | lthumb 0 0
176 | rfemur 0 0 0
177 | rtibia 0
178 | rfoot 0 0
179 | rtoes 0
180 | lfemur 0 0 0
181 | ltibia 0
182 | lfoot 0 0
183 | ltoes 0
184 | 7
185 | root 0 0 0 0 0 0
186 | lowerback 0 0 0
187 | upperback 0 0 0
188 | thorax 0 0 0
189 | lowerneck 0 0 0
190 | upperneck 0 0 0
191 | head 0 0 0
192 | rclavicle 0 0
193 | rhumerus 0 0 0
194 | rradius 0
195 | rwrist 0
196 | rhand 0 0
197 | rfingers 0
198 | rthumb 0 0
199 | lclavicle 0 0
200 | lhumerus 0 0 0
201 | lradius 0
202 | lwrist 0
203 | lhand 0 0
204 | lfingers 0
205 | lthumb 0 0
206 | rfemur 0 0 0
207 | rtibia 0
208 | rfoot 0 0
209 | rtoes 0
210 | lfemur 0 0 0
211 | ltibia 0
212 | lfoot 0 0
213 | ltoes 0
214 |
--------------------------------------------------------------------------------
/local_dm_control_suite/explore.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Control suite environments explorer."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from absl import app
22 | from absl import flags
23 | from dm_control import suite
24 | from dm_control.suite.wrappers import action_noise
25 | from six.moves import input
26 |
27 | from dm_control import viewer
28 |
29 |
30 | _ALL_NAMES = ['.'.join(domain_task) for domain_task in suite.ALL_TASKS]
31 |
32 | flags.DEFINE_enum('environment_name', None, _ALL_NAMES,
33 | 'Optional \'domain_name.task_name\' pair specifying the '
34 | 'environment to load. If unspecified a prompt will appear to '
35 | 'select one.')
36 | flags.DEFINE_bool('timeout', True, 'Whether episodes should have a time limit.')
37 | flags.DEFINE_bool('visualize_reward', True,
38 | 'Whether to vary the colors of geoms according to the '
39 | 'current reward value.')
40 | flags.DEFINE_float('action_noise', 0.,
41 | 'Standard deviation of Gaussian noise to apply to actions, '
42 | 'expressed as a fraction of the max-min range for each '
43 | 'action dimension. Defaults to 0, i.e. no noise.')
44 | FLAGS = flags.FLAGS
45 |
46 |
47 | def prompt_environment_name(prompt, values):
48 | environment_name = None
49 | while not environment_name:
50 | environment_name = input(prompt)
51 | if not environment_name or values.index(environment_name) < 0:
52 | print('"%s" is not a valid environment name.' % environment_name)
53 | environment_name = None
54 | return environment_name
55 |
56 |
57 | def main(argv):
58 | del argv
59 | environment_name = FLAGS.environment_name
60 | if environment_name is None:
61 | print('\n '.join(['Available environments:'] + _ALL_NAMES))
62 | environment_name = prompt_environment_name(
63 | 'Please select an environment name: ', _ALL_NAMES)
64 |
65 | index = _ALL_NAMES.index(environment_name)
66 | domain_name, task_name = suite.ALL_TASKS[index]
67 |
68 | task_kwargs = {}
69 | if not FLAGS.timeout:
70 | task_kwargs['time_limit'] = float('inf')
71 |
72 | def loader():
73 | env = suite.load(
74 | domain_name=domain_name, task_name=task_name, task_kwargs=task_kwargs)
75 | env.task.visualize_reward = FLAGS.visualize_reward
76 | if FLAGS.action_noise > 0:
77 | env = action_noise.Wrapper(env, scale=FLAGS.action_noise)
78 | return env
79 |
80 | viewer.launch(loader)
81 |
82 |
83 | if __name__ == '__main__':
84 | app.run(main)
85 |
--------------------------------------------------------------------------------
/local_dm_control_suite/finger.xml:
--------------------------------------------------------------------------------
[MuJoCo model XML for the finger domain; markup not preserved in this dump]
--------------------------------------------------------------------------------
/local_dm_control_suite/fish.xml:
--------------------------------------------------------------------------------
[MuJoCo model XML for the fish domain; markup not preserved in this dump]
--------------------------------------------------------------------------------
/local_dm_control_suite/hopper.xml:
--------------------------------------------------------------------------------
[MuJoCo model XML for the hopper domain; markup not preserved in this dump]
--------------------------------------------------------------------------------
/local_dm_control_suite/lqr.xml:
--------------------------------------------------------------------------------
[MuJoCo model XML for the LQR domain; markup not preserved in this dump]
--------------------------------------------------------------------------------
/local_dm_control_suite/lqr_solver.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | r"""Optimal policy for LQR levels.
17 |
18 | LQR control problem is described in
19 | https://en.wikipedia.org/wiki/Linear-quadratic_regulator#Infinite-horizon.2C_discrete-time_LQR
20 | """
21 |
22 | from __future__ import absolute_import
23 | from __future__ import division
24 | from __future__ import print_function
25 |
26 | from absl import logging
27 | from dm_control.mujoco import wrapper
28 | import numpy as np
29 | from six.moves import range
30 |
31 | try:
32 | import scipy.linalg as sp # pylint: disable=g-import-not-at-top
33 | except ImportError:
34 | sp = None
35 |
36 |
37 | def _solve_dare(a, b, q, r):
38 | """Solves the Discrete-time Algebraic Riccati Equation (DARE) by iteration.
39 |
40 | Algebraic Riccati Equation:
41 | ```none
42 | P_{t-1} = Q + A' * P_{t} * A -
43 | A' * P_{t} * B * (R + B' * P_{t} * B)^{-1} * B' * P_{t} * A
44 | ```
45 |
46 | Args:
47 | a: A 2 dimensional numpy array, transition matrix A.
48 | b: A 2 dimensional numpy array, control matrix B.
49 | q: A 2 dimensional numpy array, symmetric positive definite cost matrix.
50 | r: A 2 dimensional numpy array, symmetric positive definite cost matrix
51 |
52 | Returns:
53 | A numpy array, a real symmetric matrix P which is the solution to DARE.
54 |
55 | Raises:
56 | RuntimeError: If the computed P matrix is not symmetric and
57 | positive-definite.
58 | """
59 | p = np.eye(len(a))
60 | for _ in range(1000000):
61 | a_p = a.T.dot(p) # A' * P_t
62 | a_p_b = np.dot(a_p, b) # A' * P_t * B
63 | # Algebraic Riccati Equation.
64 | p_next = q + np.dot(a_p, a) - a_p_b.dot(
65 | np.linalg.solve(b.T.dot(p.dot(b)) + r, a_p_b.T))
66 | p_next += p_next.T
67 | p_next *= .5
68 | if np.abs(p - p_next).max() < 1e-12:
69 | break
70 | p = p_next
71 | else:
72 | logging.warning('DARE solver did not converge')
73 | try:
74 | # Check that the result is symmetric and positive-definite.
75 | np.linalg.cholesky(p_next)
76 | except np.linalg.LinAlgError:
77 | raise RuntimeError('ARE solver failed: P matrix is not symmetric and '
78 | 'positive-definite.')
79 | return p_next
80 |
81 |
82 | def solve(env):
83 | """Returns the optimal value and policy for LQR problem.
84 |
85 | Args:
86 | env: An instance of `control.EnvironmentV2` with LQR level.
87 |
88 | Returns:
89 | p: A numpy array, the Hessian of the optimal total cost-to-go (value
90 | function at state x) is V(x) = .5 * x' * p * x.
91 | k: A numpy array which gives the optimal linear policy u = k * x.
92 | beta: The maximum eigenvalue of (a + b * k). Under optimal policy, at
93 | timestep n the state tends to 0 like beta^n.
94 |
95 | Raises:
96 | RuntimeError: If the controlled system is unstable.
97 | """
98 | n = env.physics.model.nq # number of DoFs
99 | m = env.physics.model.nu # number of controls
100 |
101 | # Compute the mass matrix.
102 | mass = np.zeros((n, n))
103 | wrapper.mjbindings.mjlib.mj_fullM(env.physics.model.ptr, mass,
104 | env.physics.data.qM)
105 |
106 | # Compute input matrices a, b, q and r to the DARE solvers.
107 | # State transition matrix a.
108 | stiffness = np.diag(env.physics.model.jnt_stiffness.ravel())
109 | damping = np.diag(env.physics.model.dof_damping.ravel())
110 | dt = env.physics.model.opt.timestep
111 |
112 | j = np.linalg.solve(-mass, np.hstack((stiffness, damping)))
113 | a = np.eye(2 * n) + dt * np.vstack(
114 | (dt * j + np.hstack((np.zeros((n, n)), np.eye(n))), j))
115 |
116 | # Control transition matrix b.
117 | b = env.physics.data.actuator_moment.T
118 | bc = np.linalg.solve(mass, b)
119 | b = dt * np.vstack((dt * bc, bc))
120 |
121 | # State cost Hessian q.
122 | q = np.diag(np.hstack([np.ones(n), np.zeros(n)]))
123 |
124 | # Control cost Hessian r.
125 | r = env.task.control_cost_coef * np.eye(m)
126 |
127 | if sp:
128 | # Use scipy's faster DARE solver if available.
129 | solve_dare = sp.solve_discrete_are
130 | else:
131 | # Otherwise fall back on a slower internal implementation.
132 | solve_dare = _solve_dare
133 |
134 | # Solve the discrete algebraic Riccati equation.
135 | p = solve_dare(a, b, q, r)
136 | k = -np.linalg.solve(b.T.dot(p.dot(b)) + r, b.T.dot(p.dot(a)))
137 |
138 | # Under optimal policy, state tends to 0 like beta^n_timesteps
139 | beta = np.abs(np.linalg.eigvals(a + b.dot(k))).max()
140 | if beta >= 1.0:
141 | raise RuntimeError('Controlled system is unstable.')
142 | return p, k, beta
143 |
--------------------------------------------------------------------------------
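
A toy check (assuming SciPy is available, as `solve()` above prefers) of the discrete algebraic Riccati relation that `_solve_dare` iterates and that `solve()` hands to `scipy.linalg.solve_discrete_are` when possible. The scalar system matrices are purely illustrative.

```python
import numpy as np
import scipy.linalg as sp

# Illustrative scalar system: x_{t+1} = a x_t + b u_t with unit state/control costs.
a = np.array([[1.01]])
b = np.array([[0.5]])
q = np.array([[1.0]])
r = np.array([[1.0]])

p = sp.solve_discrete_are(a, b, q, r)
# Optimal gain as computed in solve(): u = k x
k = -np.linalg.solve(b.T @ p @ b + r, b.T @ p @ a)

# P should satisfy the DARE fixed point used by _solve_dare.
residual = q + a.T @ p @ a - a.T @ p @ b @ np.linalg.solve(b.T @ p @ b + r, b.T @ p @ a) - p
print(np.abs(residual).max())                     # ~0
beta = np.abs(np.linalg.eigvals(a + b @ k)).max()
print(beta)                                       # < 1, i.e. the closed loop is stable
```
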
/local_dm_control_suite/pendulum.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Pendulum domain."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import collections
23 |
24 | from dm_control import mujoco
25 | from dm_control.rl import control
26 | from local_dm_control_suite import base
27 | from local_dm_control_suite import common
28 | from dm_control.utils import containers
29 | from dm_control.utils import rewards
30 | import numpy as np
31 |
32 |
33 | _DEFAULT_TIME_LIMIT = 20
34 | _ANGLE_BOUND = 8
35 | _COSINE_BOUND = np.cos(np.deg2rad(_ANGLE_BOUND))
36 | SUITE = containers.TaggedTasks()
37 |
38 |
39 | def get_model_and_assets():
40 | """Returns a tuple containing the model XML string and a dict of assets."""
41 | return common.read_model('pendulum.xml'), common.ASSETS
42 |
43 |
44 | @SUITE.add('benchmarking')
45 | def swingup(time_limit=_DEFAULT_TIME_LIMIT, random=None,
46 | environment_kwargs=None):
47 |   """Returns the pendulum swingup task."""
48 | physics = Physics.from_xml_string(*get_model_and_assets())
49 | task = SwingUp(random=random)
50 | environment_kwargs = environment_kwargs or {}
51 | return control.Environment(
52 | physics, task, time_limit=time_limit, **environment_kwargs)
53 |
54 |
55 | class Physics(mujoco.Physics):
56 | """Physics simulation with additional features for the Pendulum domain."""
57 |
58 | def pole_vertical(self):
59 | """Returns vertical (z) component of pole frame."""
60 | return self.named.data.xmat['pole', 'zz']
61 |
62 | def angular_velocity(self):
63 | """Returns the angular velocity of the pole."""
64 | return self.named.data.qvel['hinge'].copy()
65 |
66 | def pole_orientation(self):
67 | """Returns both horizontal and vertical components of pole frame."""
68 | return self.named.data.xmat['pole', ['zz', 'xz']]
69 |
70 |
71 | class SwingUp(base.Task):
72 | """A Pendulum `Task` to swing up and balance the pole."""
73 |
74 | def __init__(self, random=None):
75 | """Initialize an instance of `Pendulum`.
76 |
77 | Args:
78 | random: Optional, either a `numpy.random.RandomState` instance, an
79 | integer seed for creating a new `RandomState`, or None to select a seed
80 | automatically (default).
81 | """
82 | super(SwingUp, self).__init__(random=random)
83 |
84 | def initialize_episode(self, physics):
85 | """Sets the state of the environment at the start of each episode.
86 |
87 | Pole is set to a random angle between [-pi, pi).
88 |
89 | Args:
90 | physics: An instance of `Physics`.
91 |
92 | """
93 | physics.named.data.qpos['hinge'] = self.random.uniform(-np.pi, np.pi)
94 | super(SwingUp, self).initialize_episode(physics)
95 |
96 | def get_observation(self, physics):
97 | """Returns an observation.
98 |
99 |     Observations are the pole orientation and the angular velocity of the
100 |     hinge joint.
101 |
102 |     Args:
103 |       physics: An instance of `Physics` (Pendulum physics).
104 |
105 |     Returns:
106 |       A `dict` of observations.
107 | """
108 | obs = collections.OrderedDict()
109 | obs['orientation'] = physics.pole_orientation()
110 | obs['velocity'] = physics.angular_velocity()
111 | return obs
112 |
113 | def get_reward(self, physics):
114 | return rewards.tolerance(physics.pole_vertical(), (_COSINE_BOUND, 1))
115 |
--------------------------------------------------------------------------------
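Because this is a vendored copy of the suite, the task constructors above are called directly rather than through `dm_control.suite.load`. A minimal usage sketch, assuming MuJoCo is installed and the repository root is on `PYTHONPATH`:

```python
import numpy as np
from local_dm_control_suite import pendulum

env = pendulum.swingup()
spec = env.action_spec()

time_step = env.reset()
while not time_step.last():
    # Random torque on the single hinge actuator.
    action = np.random.uniform(spec.minimum, spec.maximum, size=spec.shape)
    time_step = env.step(action)
    # Observation keys follow SwingUp.get_observation: 'orientation', 'velocity'.
```
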
/local_dm_control_suite/pendulum.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/local_dm_control_suite/point_mass.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/local_dm_control_suite/reacher.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Reacher domain."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import collections
23 |
24 | from dm_control import mujoco
25 | from dm_control.rl import control
26 | from local_dm_control_suite import base
27 | from local_dm_control_suite import common
28 | from dm_control.suite.utils import randomizers
29 | from dm_control.utils import containers
30 | from dm_control.utils import rewards
31 | import numpy as np
32 |
33 | SUITE = containers.TaggedTasks()
34 | _DEFAULT_TIME_LIMIT = 20
35 | _BIG_TARGET = .05
36 | _SMALL_TARGET = .015
37 |
38 |
39 | def get_model_and_assets():
40 | """Returns a tuple containing the model XML string and a dict of assets."""
41 | return common.read_model('reacher.xml'), common.ASSETS
42 |
43 |
44 | @SUITE.add('benchmarking', 'easy')
45 | def easy(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
46 | """Returns reacher with sparse reward with 5e-2 tol and randomized target."""
47 | physics = Physics.from_xml_string(*get_model_and_assets())
48 | task = Reacher(target_size=_BIG_TARGET, random=random)
49 | environment_kwargs = environment_kwargs or {}
50 | return control.Environment(
51 | physics, task, time_limit=time_limit, **environment_kwargs)
52 |
53 |
54 | @SUITE.add('benchmarking')
55 | def hard(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
56 | """Returns reacher with sparse reward with 1e-2 tol and randomized target."""
57 | physics = Physics.from_xml_string(*get_model_and_assets())
58 | task = Reacher(target_size=_SMALL_TARGET, random=random)
59 | environment_kwargs = environment_kwargs or {}
60 | return control.Environment(
61 | physics, task, time_limit=time_limit, **environment_kwargs)
62 |
63 |
64 | class Physics(mujoco.Physics):
65 | """Physics simulation with additional features for the Reacher domain."""
66 |
67 | def finger_to_target(self):
68 |     """Returns the vector from the finger to the target in global coordinates."""
69 | return (self.named.data.geom_xpos['target', :2] -
70 | self.named.data.geom_xpos['finger', :2])
71 |
72 | def finger_to_target_dist(self):
73 |     """Returns the distance between the finger and the target."""
74 | return np.linalg.norm(self.finger_to_target())
75 |
76 |
77 | class Reacher(base.Task):
78 | """A reacher `Task` to reach the target."""
79 |
80 | def __init__(self, target_size, random=None):
81 | """Initialize an instance of `Reacher`.
82 |
83 | Args:
84 | target_size: A `float`, tolerance to determine whether finger reached the
85 | target.
86 | random: Optional, either a `numpy.random.RandomState` instance, an
87 | integer seed for creating a new `RandomState`, or None to select a seed
88 | automatically (default).
89 | """
90 | self._target_size = target_size
91 | super(Reacher, self).__init__(random=random)
92 |
93 | def initialize_episode(self, physics):
94 | """Sets the state of the environment at the start of each episode."""
95 | physics.named.model.geom_size['target', 0] = self._target_size
96 | randomizers.randomize_limited_and_rotational_joints(physics, self.random)
97 |
98 | # Randomize target position
99 | angle = self.random.uniform(0, 2 * np.pi)
100 | radius = self.random.uniform(.05, .20)
101 | physics.named.model.geom_pos['target', 'x'] = radius * np.sin(angle)
102 | physics.named.model.geom_pos['target', 'y'] = radius * np.cos(angle)
103 |
104 | super(Reacher, self).initialize_episode(physics)
105 |
106 | def get_observation(self, physics):
107 | """Returns an observation of the state and the target position."""
108 | obs = collections.OrderedDict()
109 | obs['position'] = physics.position()
110 | obs['to_target'] = physics.finger_to_target()
111 | obs['velocity'] = physics.velocity()
112 | return obs
113 |
114 | def get_reward(self, physics):
115 | radii = physics.named.model.geom_size[['target', 'finger'], 0].sum()
116 | return rewards.tolerance(physics.finger_to_target_dist(), (0, radii))
117 |
--------------------------------------------------------------------------------
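Note that `get_reward` above calls `rewards.tolerance` with its default `margin=0`, so the reward is a binary indicator: 1 whenever the finger-to-target distance is at most the sum of the target and finger radii, 0 otherwise. A small illustrative check with made-up radii:

```python
from dm_control.utils import rewards

radii = 0.05 + 0.01  # hypothetical target radius + finger radius
print(rewards.tolerance(0.03, bounds=(0, radii)))  # 1.0: within the combined radii
print(rewards.tolerance(0.10, bounds=(0, radii)))  # 0.0: outside, and margin=0 gives no shaping
```
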
/local_dm_control_suite/reacher.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/local_dm_control_suite/swimmer.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/local_dm_control_suite/tests/loader_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Tests for the dm_control.suite loader."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | # Internal dependencies.
23 |
24 | from absl.testing import absltest
25 |
26 | from dm_control import suite
27 | from dm_control.rl import control
28 |
29 |
30 | class LoaderTest(absltest.TestCase):
31 |
32 | def test_load_without_kwargs(self):
33 | env = suite.load('cartpole', 'swingup')
34 | self.assertIsInstance(env, control.Environment)
35 |
36 | def test_load_with_kwargs(self):
37 | env = suite.load('cartpole', 'swingup',
38 | task_kwargs={'time_limit': 40, 'random': 99})
39 | self.assertIsInstance(env, control.Environment)
40 |
41 |
42 | class LoaderConstantsTest(absltest.TestCase):
43 |
44 | def testSuiteConstants(self):
45 | self.assertNotEmpty(suite.BENCHMARKING)
46 | self.assertNotEmpty(suite.EASY)
47 | self.assertNotEmpty(suite.HARD)
48 | self.assertNotEmpty(suite.EXTRA)
49 |
50 |
51 | if __name__ == '__main__':
52 | absltest.main()
53 |
--------------------------------------------------------------------------------
/local_dm_control_suite/tests/lqr_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Tests specific to the LQR domain."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import math
23 | import unittest
24 |
25 | # Internal dependencies.
26 | from absl import logging
27 |
28 | from absl.testing import absltest
29 | from absl.testing import parameterized
30 |
31 | from local_dm_control_suite import lqr
32 | from local_dm_control_suite import lqr_solver
33 |
34 | import numpy as np
35 | from six.moves import range
36 |
37 |
38 | class LqrTest(parameterized.TestCase):
39 |
40 | @parameterized.named_parameters(
41 | ('lqr_2_1', lqr.lqr_2_1),
42 | ('lqr_6_2', lqr.lqr_6_2))
43 | def test_lqr_optimal_policy(self, make_env):
44 | env = make_env()
45 | p, k, beta = lqr_solver.solve(env)
46 | self.assertPolicyisOptimal(env, p, k, beta)
47 |
48 | @parameterized.named_parameters(
49 | ('lqr_2_1', lqr.lqr_2_1),
50 | ('lqr_6_2', lqr.lqr_6_2))
51 | @unittest.skipUnless(
52 | condition=lqr_solver.sp,
53 | reason='scipy is not available, so non-scipy DARE solver is the default.')
54 | def test_lqr_optimal_policy_no_scipy(self, make_env):
55 | env = make_env()
56 | old_sp = lqr_solver.sp
57 | try:
58 | lqr_solver.sp = None # Force the solver to use the non-scipy code path.
59 | p, k, beta = lqr_solver.solve(env)
60 | finally:
61 | lqr_solver.sp = old_sp
62 | self.assertPolicyisOptimal(env, p, k, beta)
63 |
64 | def assertPolicyisOptimal(self, env, p, k, beta):
65 | tolerance = 1e-3
66 | n_steps = int(math.ceil(math.log10(tolerance) / math.log10(beta)))
67 | logging.info('%d timesteps for %g convergence.', n_steps, tolerance)
68 | total_loss = 0.0
69 |
70 | timestep = env.reset()
71 | initial_state = np.hstack((timestep.observation['position'],
72 | timestep.observation['velocity']))
73 | logging.info('Measuring total cost over %d steps.', n_steps)
74 | for _ in range(n_steps):
75 | x = np.hstack((timestep.observation['position'],
76 | timestep.observation['velocity']))
77 | # u = k*x is the optimal policy
78 | u = k.dot(x)
79 | total_loss += 1 - (timestep.reward or 0.0)
80 | timestep = env.step(u)
81 |
82 | logging.info('Analytical expected total cost is .5*x^T*p*x.')
83 | expected_loss = .5 * initial_state.T.dot(p).dot(initial_state)
84 | logging.info('Comparing measured and predicted costs.')
85 | np.testing.assert_allclose(expected_loss, total_loss, rtol=tolerance)
86 |
87 | if __name__ == '__main__':
88 | absltest.main()
89 |
--------------------------------------------------------------------------------
/local_dm_control_suite/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Utility functions used in the control suite."""
17 |
--------------------------------------------------------------------------------
/local_dm_control_suite/utils/parse_amc_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Tests for parse_amc utility."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import os
23 |
24 | # Internal dependencies.
25 |
26 | from absl.testing import absltest
27 | from local_dm_control_suite import humanoid_CMU
28 | from dm_control.suite.utils import parse_amc
29 |
30 | from dm_control.utils import io as resources
31 |
32 | _TEST_AMC_PATH = resources.GetResourceFilename(
33 | os.path.join(os.path.dirname(__file__), '../demos/zeros.amc'))
34 |
35 |
36 | class ParseAMCTest(absltest.TestCase):
37 |
38 | def test_sizes_of_parsed_data(self):
39 |
40 | # Instantiate the humanoid environment.
41 | env = humanoid_CMU.stand()
42 |
43 | # Parse and convert specified clip.
44 | converted = parse_amc.convert(
45 | _TEST_AMC_PATH, env.physics, env.control_timestep())
46 |
47 | self.assertEqual(converted.qpos.shape[0], 63)
48 | self.assertEqual(converted.qvel.shape[0], 62)
49 | self.assertEqual(converted.time.shape[0], converted.qpos.shape[1])
50 | self.assertEqual(converted.qpos.shape[1],
51 | converted.qvel.shape[1] + 1)
52 |
53 | # Parse and convert specified clip -- WITH SMALLER TIMESTEP
54 | converted2 = parse_amc.convert(
55 | _TEST_AMC_PATH, env.physics, 0.5 * env.control_timestep())
56 |
57 | self.assertEqual(converted2.qpos.shape[0], 63)
58 | self.assertEqual(converted2.qvel.shape[0], 62)
59 | self.assertEqual(converted2.time.shape[0], converted2.qpos.shape[1])
60 |     self.assertEqual(converted2.qpos.shape[1],
61 |                      converted2.qvel.shape[1] + 1)
62 |
63 | # Compare sizes of parsed objects for different timesteps
64 | self.assertEqual(converted.qpos.shape[1] * 2, converted2.qpos.shape[1])
65 |
66 |
67 | if __name__ == '__main__':
68 | absltest.main()
69 |
--------------------------------------------------------------------------------
/local_dm_control_suite/utils/randomizers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Randomization functions."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from dm_control.mujoco.wrapper import mjbindings
23 | import numpy as np
24 | from six.moves import range
25 |
26 |
27 | def random_limited_quaternion(random, limit):
28 | """Generates a random quaternion limited to the specified rotations."""
29 | axis = random.randn(3)
30 | axis /= np.linalg.norm(axis)
31 | angle = random.rand() * limit
32 |
33 | quaternion = np.zeros(4)
34 | mjbindings.mjlib.mju_axisAngle2Quat(quaternion, axis, angle)
35 |
36 | return quaternion
37 |
38 |
39 | def randomize_limited_and_rotational_joints(physics, random=None):
40 | """Randomizes the positions of joints defined in the physics body.
41 |
42 | The following randomization rules apply:
43 | - Bounded joints (hinges or sliders) are sampled uniformly in the bounds.
44 |   - Unbounded hinges are sampled uniformly in [-pi, pi].
45 | - Quaternions for unlimited free joints and ball joints are sampled
46 | uniformly on the unit 3-sphere.
47 | - Quaternions for limited ball joints are sampled uniformly on a sector
48 | of the unit 3-sphere.
49 | - The linear degrees of freedom of free joints are not randomized.
50 |
51 | Args:
52 | physics: Instance of 'Physics' class that holds a loaded model.
53 | random: Optional instance of 'np.random.RandomState'. Defaults to the global
54 | NumPy random state.
55 | """
56 | random = random or np.random
57 |
58 | hinge = mjbindings.enums.mjtJoint.mjJNT_HINGE
59 | slide = mjbindings.enums.mjtJoint.mjJNT_SLIDE
60 | ball = mjbindings.enums.mjtJoint.mjJNT_BALL
61 | free = mjbindings.enums.mjtJoint.mjJNT_FREE
62 |
63 | qpos = physics.named.data.qpos
64 |
65 | for joint_id in range(physics.model.njnt):
66 | joint_name = physics.model.id2name(joint_id, 'joint')
67 | joint_type = physics.model.jnt_type[joint_id]
68 | is_limited = physics.model.jnt_limited[joint_id]
69 | range_min, range_max = physics.model.jnt_range[joint_id]
70 |
71 | if is_limited:
72 | if joint_type == hinge or joint_type == slide:
73 | qpos[joint_name] = random.uniform(range_min, range_max)
74 |
75 | elif joint_type == ball:
76 | qpos[joint_name] = random_limited_quaternion(random, range_max)
77 |
78 | else:
79 | if joint_type == hinge:
80 | qpos[joint_name] = random.uniform(-np.pi, np.pi)
81 |
82 | elif joint_type == ball:
83 | quat = random.randn(4)
84 | quat /= np.linalg.norm(quat)
85 | qpos[joint_name] = quat
86 |
87 | elif joint_type == free:
88 | quat = random.rand(4)
89 | quat /= np.linalg.norm(quat)
90 | qpos[joint_name][3:] = quat
91 |
92 |
--------------------------------------------------------------------------------
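Tasks normally call this helper from `initialize_episode` (as `reacher.py` does above), but it can also be applied directly to a loaded `Physics` instance. A minimal sketch, assuming the repository root is importable; `after_reset()` is used here, as in `ball_in_cup.py`, to recompute derived quantities after writing `qpos`:

```python
import numpy as np
from local_dm_control_suite import reacher
from local_dm_control_suite.utils import randomizers

env = reacher.easy()
rng = np.random.RandomState(42)

randomizers.randomize_limited_and_rotational_joints(env.physics, rng)
env.physics.after_reset()  # Recompute positions/contacts for the new qpos.
print(env.physics.data.qpos)
```
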
/local_dm_control_suite/walker.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/local_dm_control_suite/wrappers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Environment wrappers used to extend or modify environment behaviour."""
17 |
--------------------------------------------------------------------------------
/local_dm_control_suite/wrappers/action_noise.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Wrapper for control suite environments that adds Gaussian noise to actions."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import dm_env
23 | import numpy as np
24 |
25 |
26 | _BOUNDS_MUST_BE_FINITE = (
27 | 'All bounds in `env.action_spec()` must be finite, got: {action_spec}')
28 |
29 |
30 | class Wrapper(dm_env.Environment):
31 | """Wraps a control environment and adds Gaussian noise to actions."""
32 |
33 | def __init__(self, env, scale=0.01):
34 | """Initializes a new action noise Wrapper.
35 |
36 | Args:
37 | env: The control suite environment to wrap.
38 | scale: The standard deviation of the noise, expressed as a fraction
39 | of the max-min range for each action dimension.
40 |
41 | Raises:
42 | ValueError: If any of the action dimensions of the wrapped environment are
43 | unbounded.
44 | """
45 | action_spec = env.action_spec()
46 | if not (np.all(np.isfinite(action_spec.minimum)) and
47 | np.all(np.isfinite(action_spec.maximum))):
48 | raise ValueError(_BOUNDS_MUST_BE_FINITE.format(action_spec=action_spec))
49 | self._minimum = action_spec.minimum
50 | self._maximum = action_spec.maximum
51 | self._noise_std = scale * (action_spec.maximum - action_spec.minimum)
52 | self._env = env
53 |
54 | def step(self, action):
55 | noisy_action = action + self._env.task.random.normal(scale=self._noise_std)
56 | # Clip the noisy actions in place so that they fall within the bounds
57 | # specified by the `action_spec`. Note that MuJoCo implicitly clips out-of-
58 | # bounds control inputs, but we also clip here in case the actions do not
59 | # correspond directly to MuJoCo actuators, or if there are other wrapper
60 | # layers that expect the actions to be within bounds.
61 | np.clip(noisy_action, self._minimum, self._maximum, out=noisy_action)
62 | return self._env.step(noisy_action)
63 |
64 | def reset(self):
65 | return self._env.reset()
66 |
67 | def observation_spec(self):
68 | return self._env.observation_spec()
69 |
70 | def action_spec(self):
71 | return self._env.action_spec()
72 |
73 | def __getattr__(self, name):
74 | return getattr(self._env, name)
75 |
--------------------------------------------------------------------------------
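A usage sketch: wrapping a suite environment so that every action is perturbed by zero-mean Gaussian noise whose standard deviation is 1% of each action dimension's range, then clipped back to the action bounds inside `step()`:

```python
import numpy as np
from dm_control import suite
from local_dm_control_suite.wrappers import action_noise

env = action_noise.Wrapper(suite.load('cartpole', 'swingup'), scale=0.01)
spec = env.action_spec()           # Delegated to the wrapped environment.

time_step = env.reset()
action = np.zeros(spec.shape)      # Nominal action; noise is added inside step().
time_step = env.step(action)
```
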
/local_dm_control_suite/wrappers/pixels.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Wrapper that adds pixel observations to a control environment."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import collections.abc
23 |
24 | import dm_env
25 | from dm_env import specs
26 |
27 | STATE_KEY = 'state'
28 |
29 |
30 | class Wrapper(dm_env.Environment):
31 | """Wraps a control environment and adds a rendered pixel observation."""
32 |
33 | def __init__(self, env, pixels_only=True, render_kwargs=None,
34 | observation_key='pixels'):
35 | """Initializes a new pixel Wrapper.
36 |
37 | Args:
38 | env: The environment to wrap.
39 | pixels_only: If True (default), the original set of 'state' observations
40 | returned by the wrapped environment will be discarded, and the
41 | `OrderedDict` of observations will only contain pixels. If False, the
42 | `OrderedDict` will contain the original observations as well as the
43 | pixel observations.
44 | render_kwargs: Optional `dict` containing keyword arguments passed to the
45 | `mujoco.Physics.render` method.
46 | observation_key: Optional custom string specifying the pixel observation's
47 | key in the `OrderedDict` of observations. Defaults to 'pixels'.
48 |
49 | Raises:
50 | ValueError: If `env`'s observation spec is not compatible with the
51 | wrapper. Supported formats are a single array, or a dict of arrays.
52 | ValueError: If `env`'s observation already contains the specified
53 | `observation_key`.
54 | """
55 | if render_kwargs is None:
56 | render_kwargs = {}
57 |
58 | wrapped_observation_spec = env.observation_spec()
59 |
60 | if isinstance(wrapped_observation_spec, specs.Array):
61 | self._observation_is_dict = False
62 | invalid_keys = set([STATE_KEY])
63 |     elif isinstance(wrapped_observation_spec, collections.abc.MutableMapping):
64 | self._observation_is_dict = True
65 | invalid_keys = set(wrapped_observation_spec.keys())
66 | else:
67 | raise ValueError('Unsupported observation spec structure.')
68 |
69 | if not pixels_only and observation_key in invalid_keys:
70 | raise ValueError('Duplicate or reserved observation key {!r}.'
71 | .format(observation_key))
72 |
73 | if pixels_only:
74 | self._observation_spec = collections.OrderedDict()
75 | elif self._observation_is_dict:
76 | self._observation_spec = wrapped_observation_spec.copy()
77 | else:
78 | self._observation_spec = collections.OrderedDict()
79 | self._observation_spec[STATE_KEY] = wrapped_observation_spec
80 |
81 | # Extend observation spec.
82 | pixels = env.physics.render(**render_kwargs)
83 | pixels_spec = specs.Array(
84 | shape=pixels.shape, dtype=pixels.dtype, name=observation_key)
85 | self._observation_spec[observation_key] = pixels_spec
86 |
87 | self._env = env
88 | self._pixels_only = pixels_only
89 | self._render_kwargs = render_kwargs
90 | self._observation_key = observation_key
91 |
92 | def reset(self):
93 | time_step = self._env.reset()
94 | return self._add_pixel_observation(time_step)
95 |
96 | def step(self, action):
97 | time_step = self._env.step(action)
98 | return self._add_pixel_observation(time_step)
99 |
100 | def observation_spec(self):
101 | return self._observation_spec
102 |
103 | def action_spec(self):
104 | return self._env.action_spec()
105 |
106 | def _add_pixel_observation(self, time_step):
107 | if self._pixels_only:
108 | observation = collections.OrderedDict()
109 | elif self._observation_is_dict:
110 | observation = type(time_step.observation)(time_step.observation)
111 | else:
112 | observation = collections.OrderedDict()
113 | observation[STATE_KEY] = time_step.observation
114 |
115 | pixels = self._env.physics.render(**self._render_kwargs)
116 | observation[self._observation_key] = pixels
117 | return time_step._replace(observation=observation)
118 |
119 | def __getattr__(self, name):
120 | return getattr(self._env, name)
121 |
--------------------------------------------------------------------------------
/local_dm_control_suite/wrappers/pixels_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Tests for the pixel wrapper."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import collections
23 |
24 | # Internal dependencies.
25 | from absl.testing import absltest
26 | from absl.testing import parameterized
27 | from local_dm_control_suite import cartpole
28 | from dm_control.suite.wrappers import pixels
29 | import dm_env
30 | from dm_env import specs
31 |
32 | import numpy as np
33 |
34 |
35 | class FakePhysics(object):
36 |
37 | def render(self, *args, **kwargs):
38 | del args
39 | del kwargs
40 | return np.zeros((4, 5, 3), dtype=np.uint8)
41 |
42 |
43 | class FakeArrayObservationEnvironment(dm_env.Environment):
44 |
45 | def __init__(self):
46 | self.physics = FakePhysics()
47 |
48 | def reset(self):
49 | return dm_env.restart(np.zeros((2,)))
50 |
51 | def step(self, action):
52 | del action
53 | return dm_env.transition(0.0, np.zeros((2,)))
54 |
55 | def action_spec(self):
56 | pass
57 |
58 | def observation_spec(self):
59 |     return specs.Array(shape=(2,), dtype=np.float64)
60 |
61 |
62 | class PixelsTest(parameterized.TestCase):
63 |
64 | @parameterized.parameters(True, False)
65 | def test_dict_observation(self, pixels_only):
66 | pixel_key = 'rgb'
67 |
68 | env = cartpole.swingup()
69 |
70 | # Make sure we are testing the right environment for the test.
71 | observation_spec = env.observation_spec()
72 | self.assertIsInstance(observation_spec, collections.OrderedDict)
73 |
74 | width = 320
75 | height = 240
76 |
77 | # The wrapper should only add one observation.
78 | wrapped = pixels.Wrapper(env,
79 | observation_key=pixel_key,
80 | pixels_only=pixels_only,
81 | render_kwargs={'width': width, 'height': height})
82 |
83 | wrapped_observation_spec = wrapped.observation_spec()
84 | self.assertIsInstance(wrapped_observation_spec, collections.OrderedDict)
85 |
86 | if pixels_only:
87 | self.assertLen(wrapped_observation_spec, 1)
88 | self.assertEqual([pixel_key], list(wrapped_observation_spec.keys()))
89 | else:
90 | expected_length = len(observation_spec) + 1
91 | self.assertLen(wrapped_observation_spec, expected_length)
92 | expected_keys = list(observation_spec.keys()) + [pixel_key]
93 | self.assertEqual(expected_keys, list(wrapped_observation_spec.keys()))
94 |
95 | # Check that the added spec item is consistent with the added observation.
96 | time_step = wrapped.reset()
97 | rgb_observation = time_step.observation[pixel_key]
98 | wrapped_observation_spec[pixel_key].validate(rgb_observation)
99 |
100 | self.assertEqual(rgb_observation.shape, (height, width, 3))
101 | self.assertEqual(rgb_observation.dtype, np.uint8)
102 |
103 | @parameterized.parameters(True, False)
104 | def test_single_array_observation(self, pixels_only):
105 | pixel_key = 'depth'
106 |
107 | env = FakeArrayObservationEnvironment()
108 | observation_spec = env.observation_spec()
109 | self.assertIsInstance(observation_spec, specs.Array)
110 |
111 | wrapped = pixels.Wrapper(env, observation_key=pixel_key,
112 | pixels_only=pixels_only)
113 | wrapped_observation_spec = wrapped.observation_spec()
114 | self.assertIsInstance(wrapped_observation_spec, collections.OrderedDict)
115 |
116 | if pixels_only:
117 | self.assertLen(wrapped_observation_spec, 1)
118 | self.assertEqual([pixel_key], list(wrapped_observation_spec.keys()))
119 | else:
120 | self.assertLen(wrapped_observation_spec, 2)
121 | self.assertEqual([pixels.STATE_KEY, pixel_key],
122 | list(wrapped_observation_spec.keys()))
123 |
124 | time_step = wrapped.reset()
125 |
126 | depth_observation = time_step.observation[pixel_key]
127 | wrapped_observation_spec[pixel_key].validate(depth_observation)
128 |
129 | self.assertEqual(depth_observation.shape, (4, 5, 3))
130 | self.assertEqual(depth_observation.dtype, np.uint8)
131 |
132 | if __name__ == '__main__':
133 | absltest.main()
134 |
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/README.md:
--------------------------------------------------------------------------------
1 | # DeepMind Control Suite.
2 |
3 | This submodule contains the domains and tasks described in the
4 | [DeepMind Control Suite tech report](https://arxiv.org/abs/1801.00690).
5 |
6 | ## Quickstart
7 |
8 | ```python
9 | from dm_control import suite
10 | import numpy as np
11 |
12 | # Load one task:
13 | env = suite.load(domain_name="cartpole", task_name="swingup")
14 |
15 | # Iterate over a task set:
16 | for domain_name, task_name in suite.BENCHMARKING:
17 | env = suite.load(domain_name, task_name)
18 |
19 | # Step through an episode and print out reward, discount and observation.
20 | action_spec = env.action_spec()
21 | time_step = env.reset()
22 | while not time_step.last():
23 | action = np.random.uniform(action_spec.minimum,
24 | action_spec.maximum,
25 | size=action_spec.shape)
26 | time_step = env.step(action)
27 | print(time_step.reward, time_step.discount, time_step.observation)
28 | ```
29 |
30 | ## Illustration video
31 |
32 | Below is a video montage of solved Control Suite tasks, with reward
33 | visualisation enabled.
34 |
35 | [Video montage of solved Control Suite tasks](https://www.youtube.com/watch?v=rAai4QzcYbs)
36 |
37 |
38 | ### Quadruped domain [April 2019]
39 |
40 | Roughly based on the 'ant' model introduced by [Schulman et al. 2015](https://arxiv.org/abs/1506.02438). Main modifications to the body are:
41 |
42 | - 4 DoFs per leg, 1 constraining tendon.
43 | - 3 actuators per leg: 'yaw', 'lift', 'extend'.
44 | - Filtered position actuators with timescale of 100ms.
45 | - Sensors include an IMU, force/torque sensors, and rangefinders.
46 |
47 | Four tasks:
48 |
49 | - `walk` and `run`: self-right the body then move forward at a desired speed.
50 | - `escape`: escape a bowl-shaped random terrain (uses rangefinders).
51 | - `fetch`: go to a moving ball and bring it to a target.
52 |
53 | All behaviors in the video below were trained with [Abdolmaleki et al's
54 | MPO](https://arxiv.org/abs/1806.06920).
55 |
56 | [Quadruped behaviours video](https://www.youtube.com/watch?v=RhRLjbb7pBE)
57 |
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/acrobot.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Acrobot domain."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import collections
23 |
24 | from dm_control import mujoco
25 | from dm_control.rl import control
26 | from local_dm_control_suite import base
27 | from local_dm_control_suite import common
28 | from dm_control.utils import containers
29 | from dm_control.utils import rewards
30 | import numpy as np
31 |
32 | _DEFAULT_TIME_LIMIT = 10
33 | SUITE = containers.TaggedTasks()
34 |
35 |
36 | def get_model_and_assets():
37 | """Returns a tuple containing the model XML string and a dict of assets."""
38 | return common.read_model('acrobot.xml'), common.ASSETS
39 |
40 |
41 | @SUITE.add('benchmarking')
42 | def swingup(time_limit=_DEFAULT_TIME_LIMIT, random=None,
43 | environment_kwargs=None):
44 | """Returns Acrobot balance task."""
45 | physics = Physics.from_xml_string(*get_model_and_assets())
46 | task = Balance(sparse=False, random=random)
47 | environment_kwargs = environment_kwargs or {}
48 | return control.Environment(
49 | physics, task, time_limit=time_limit, **environment_kwargs)
50 |
51 |
52 | @SUITE.add('benchmarking')
53 | def swingup_sparse(time_limit=_DEFAULT_TIME_LIMIT, random=None,
54 | environment_kwargs=None):
55 | """Returns Acrobot sparse balance."""
56 | physics = Physics.from_xml_string(*get_model_and_assets())
57 | task = Balance(sparse=True, random=random)
58 | environment_kwargs = environment_kwargs or {}
59 | return control.Environment(
60 | physics, task, time_limit=time_limit, **environment_kwargs)
61 |
62 |
63 | class Physics(mujoco.Physics):
64 | """Physics simulation with additional features for the Acrobot domain."""
65 |
66 | def horizontal(self):
67 | """Returns horizontal (x) component of body frame z-axes."""
68 | return self.named.data.xmat[['upper_arm', 'lower_arm'], 'xz']
69 |
70 | def vertical(self):
71 | """Returns vertical (z) component of body frame z-axes."""
72 | return self.named.data.xmat[['upper_arm', 'lower_arm'], 'zz']
73 |
74 | def to_target(self):
75 | """Returns the distance from the tip to the target."""
76 | tip_to_target = (self.named.data.site_xpos['target'] -
77 | self.named.data.site_xpos['tip'])
78 | return np.linalg.norm(tip_to_target)
79 |
80 | def orientations(self):
81 | """Returns the sines and cosines of the pole angles."""
82 | return np.concatenate((self.horizontal(), self.vertical()))
83 |
84 |
85 | class Balance(base.Task):
86 | """An Acrobot `Task` to swing up and balance the pole."""
87 |
88 | def __init__(self, sparse, random=None):
89 | """Initializes an instance of `Balance`.
90 |
91 | Args:
92 | sparse: A `bool` specifying whether to use a sparse (indicator) reward.
93 | random: Optional, either a `numpy.random.RandomState` instance, an
94 | integer seed for creating a new `RandomState`, or None to select a seed
95 | automatically (default).
96 | """
97 | self._sparse = sparse
98 | super(Balance, self).__init__(random=random)
99 |
100 | def initialize_episode(self, physics):
101 | """Sets the state of the environment at the start of each episode.
102 |
103 | Shoulder and elbow are set to a random position between [-pi, pi).
104 |
105 | Args:
106 | physics: An instance of `Physics`.
107 | """
108 | physics.named.data.qpos[
109 | ['shoulder', 'elbow']] = self.random.uniform(-np.pi, np.pi, 2)
110 | super(Balance, self).initialize_episode(physics)
111 |
112 | def get_observation(self, physics):
113 | """Returns an observation of pole orientation and angular velocities."""
114 | obs = collections.OrderedDict()
115 | obs['orientations'] = physics.orientations()
116 | obs['velocity'] = physics.velocity()
117 | return obs
118 |
119 | def _get_reward(self, physics, sparse):
120 | target_radius = physics.named.model.site_size['target', 0]
121 | return rewards.tolerance(physics.to_target(),
122 | bounds=(0, target_radius),
123 | margin=0 if sparse else 1)
124 |
125 | def get_reward(self, physics):
126 | """Returns a sparse or a smooth reward, as specified in the constructor."""
127 | return self._get_reward(physics, sparse=self._sparse)
128 |
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/acrobot.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/ball_in_cup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Ball-in-Cup Domain."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import collections
23 |
24 | from dm_control import mujoco
25 | from dm_control.rl import control
26 | from local_dm_control_suite import base
27 | from local_dm_control_suite import common
28 | from dm_control.utils import containers
29 |
30 | _DEFAULT_TIME_LIMIT = 20 # (seconds)
31 | _CONTROL_TIMESTEP = .02 # (seconds)
32 |
33 |
34 | SUITE = containers.TaggedTasks()
35 |
36 |
37 | def get_model_and_assets():
38 | """Returns a tuple containing the model XML string and a dict of assets."""
39 | return common.read_model('ball_in_cup.xml'), common.ASSETS
40 |
41 |
42 | @SUITE.add('benchmarking', 'easy')
43 | def catch(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
44 | """Returns the Ball-in-Cup task."""
45 | physics = Physics.from_xml_string(*get_model_and_assets())
46 | task = BallInCup(random=random)
47 | environment_kwargs = environment_kwargs or {}
48 | return control.Environment(
49 | physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
50 | **environment_kwargs)
51 |
52 |
53 | class Physics(mujoco.Physics):
54 | """Physics with additional features for the Ball-in-Cup domain."""
55 |
56 | def ball_to_target(self):
57 | """Returns the vector from the ball to the target."""
58 | target = self.named.data.site_xpos['target', ['x', 'z']]
59 | ball = self.named.data.xpos['ball', ['x', 'z']]
60 | return target - ball
61 |
62 | def in_target(self):
63 | """Returns 1 if the ball is in the target, 0 otherwise."""
64 | ball_to_target = abs(self.ball_to_target())
65 | target_size = self.named.model.site_size['target', [0, 2]]
66 | ball_size = self.named.model.geom_size['ball', 0]
67 | return float(all(ball_to_target < target_size - ball_size))
68 |
69 |
70 | class BallInCup(base.Task):
71 | """The Ball-in-Cup task. Put the ball in the cup."""
72 |
73 | def initialize_episode(self, physics):
74 | """Sets the state of the environment at the start of each episode.
75 |
76 | Args:
77 | physics: An instance of `Physics`.
78 |
79 | """
80 | # Find a collision-free random initial position of the ball.
81 | penetrating = True
82 | while penetrating:
83 | # Assign a random ball position.
84 | physics.named.data.qpos['ball_x'] = self.random.uniform(-.2, .2)
85 | physics.named.data.qpos['ball_z'] = self.random.uniform(.2, .5)
86 | # Check for collisions.
87 | physics.after_reset()
88 | penetrating = physics.data.ncon > 0
89 | super(BallInCup, self).initialize_episode(physics)
90 |
91 | def get_observation(self, physics):
92 | """Returns an observation of the state."""
93 | obs = collections.OrderedDict()
94 | obs['position'] = physics.position()
95 | obs['velocity'] = physics.velocity()
96 | return obs
97 |
98 | def get_reward(self, physics):
99 | """Returns a sparse reward."""
100 | return physics.in_target()
101 |
--------------------------------------------------------------------------------
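The episode initialization above is a small rejection-sampling loop: propose a random ball position, call `after_reset()` so MuJoCo recomputes contacts, and retry while any contact (`ncon > 0`) indicates penetration. The same pattern, written as a standalone sketch with a hypothetical attempt cap added for safety:

```python
def sample_collision_free_ball(physics, random, max_attempts=100):
    """Rejection-samples a ball pose until MuJoCo reports no contacts."""
    for _ in range(max_attempts):
        physics.named.data.qpos['ball_x'] = random.uniform(-.2, .2)
        physics.named.data.qpos['ball_z'] = random.uniform(.2, .5)
        physics.after_reset()          # Recompute contacts for the proposed pose.
        if physics.data.ncon == 0:     # No penetrating geom pairs: accept.
            return
    raise RuntimeError('No collision-free ball position found.')
```
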
/local_dm_control_suite_off_center/ball_in_cup.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Base class for tasks in the Control Suite."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from dm_control import mujoco
23 | from dm_control.rl import control
24 |
25 | import numpy as np
26 |
27 |
28 | class Task(control.Task):
29 | """Base class for tasks in the Control Suite.
30 |
31 | Actions are mapped directly to the states of MuJoCo actuators: each element of
32 | the action array is used to set the control input for a single actuator. The
33 | ordering of the actuators is the same as in the corresponding MJCF XML file.
34 |
35 | Attributes:
36 | random: A `numpy.random.RandomState` instance. This should be used to
37 | generate all random variables associated with the task, such as random
38 | starting states, observation noise* etc.
39 |
40 | *If sensor noise is enabled in the MuJoCo model then this will be generated
41 | using MuJoCo's internal RNG, which has its own independent state.
42 | """
43 |
44 | def __init__(self, random=None):
45 | """Initializes a new continuous control task.
46 |
47 | Args:
48 | random: Optional, either a `numpy.random.RandomState` instance, an integer
49 | seed for creating a new `RandomState`, or None to select a seed
50 | automatically (default).
51 | """
52 | if not isinstance(random, np.random.RandomState):
53 | random = np.random.RandomState(random)
54 | self._random = random
55 | self._visualize_reward = False
56 |
57 | @property
58 | def random(self):
59 | """Task-specific `numpy.random.RandomState` instance."""
60 | return self._random
61 |
62 | def action_spec(self, physics):
63 | """Returns a `BoundedArraySpec` matching the `physics` actuators."""
64 | return mujoco.action_spec(physics)
65 |
66 | def initialize_episode(self, physics):
67 | """Resets geom colors to their defaults after starting a new episode.
68 |
69 | Subclasses of `base.Task` must delegate to this method after performing
70 | their own initialization.
71 |
72 | Args:
73 | physics: An instance of `mujoco.Physics`.
74 | """
75 | self.after_step(physics)
76 |
77 | def before_step(self, action, physics):
78 | """Sets the control signal for the actuators to values in `action`."""
79 | # Support legacy internal code.
80 | action = getattr(action, "continuous_actions", action)
81 | physics.set_control(action)
82 |
83 | def after_step(self, physics):
84 | """Modifies colors according to the reward."""
85 | if self._visualize_reward:
86 | reward = np.clip(self.get_reward(physics), 0.0, 1.0)
87 | _set_reward_colors(physics, reward)
88 |
89 | @property
90 | def visualize_reward(self):
91 | return self._visualize_reward
92 |
93 | @visualize_reward.setter
94 | def visualize_reward(self, value):
95 | if not isinstance(value, bool):
96 | raise ValueError("Expected a boolean, got {}.".format(type(value)))
97 | self._visualize_reward = value
98 |
99 |
100 | _MATERIALS = ["self", "effector", "target"]
101 | _DEFAULT = [name + "_default" for name in _MATERIALS]
102 | _HIGHLIGHT = [name + "_highlight" for name in _MATERIALS]
103 |
104 |
105 | def _set_reward_colors(physics, reward):
106 | """Sets the highlight, effector and target colors according to the reward."""
107 | assert 0.0 <= reward <= 1.0
108 | colors = physics.named.model.mat_rgba
109 | default = colors[_DEFAULT]
110 | highlight = colors[_HIGHLIGHT]
111 | blend_coef = reward ** 4 # Better color distinction near high rewards.
112 | colors[_MATERIALS] = blend_coef * highlight + (1.0 - blend_coef) * default
113 |
--------------------------------------------------------------------------------
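Reward visualisation is opt-in: setting `visualize_reward = True` on a task makes `after_step` blend the 'self', 'effector' and 'target' materials towards their highlight colours, and the quartic `reward ** 4` blend coefficient keeps the highlight faint except near maximal reward. A minimal sketch of enabling it on a loaded environment, assuming the vendored suite's base task mirrors the one shown here (`cheetah.run` is just an example; any task works):

```python
import numpy as np
from local_dm_control_suite import cheetah

env = cheetah.run()
env.task.visualize_reward = True   # Geom colours now track the per-step reward.

time_step = env.reset()
zero_action = np.zeros(env.action_spec().shape)
time_step = env.step(zero_action)  # after_step() re-colours materials from the reward.
```
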
/local_dm_control_suite_off_center/cartpole.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/cheetah.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Cheetah Domain."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import collections
23 |
24 | from dm_control import mujoco
25 | from dm_control.rl import control
26 | from local_dm_control_suite import base
27 | from local_dm_control_suite import common
28 | from dm_control.utils import containers
29 | from dm_control.utils import rewards
30 |
31 |
32 | # How long the simulation will run, in seconds.
33 | _DEFAULT_TIME_LIMIT = 10
34 |
35 | # Running speed above which reward is 1.
36 | _RUN_SPEED = 10
37 |
38 | SUITE = containers.TaggedTasks()
39 |
40 |
41 | def get_model_and_assets():
42 | """Returns a tuple containing the model XML string and a dict of assets."""
43 | return common.read_model('cheetah.xml'), common.ASSETS
44 |
45 |
46 | @SUITE.add('benchmarking')
47 | def run(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
48 | """Returns the run task."""
49 | physics = Physics.from_xml_string(*get_model_and_assets())
50 | task = Cheetah(random=random)
51 | environment_kwargs = environment_kwargs or {}
52 | return control.Environment(physics, task, time_limit=time_limit,
53 | **environment_kwargs)
54 |
55 |
56 | class Physics(mujoco.Physics):
57 | """Physics simulation with additional features for the Cheetah domain."""
58 |
59 | def speed(self):
60 | """Returns the horizontal speed of the Cheetah."""
61 | return self.named.data.sensordata['torso_subtreelinvel'][0]
62 |
63 |
64 | class Cheetah(base.Task):
65 | """A `Task` to train a running Cheetah."""
66 |
67 | def initialize_episode(self, physics):
68 | """Sets the state of the environment at the start of each episode."""
69 | # The indexing below assumes that all joints have a single DOF.
70 | assert physics.model.nq == physics.model.njnt
71 | is_limited = physics.model.jnt_limited == 1
72 | lower, upper = physics.model.jnt_range[is_limited].T
73 | physics.data.qpos[is_limited] = self.random.uniform(lower, upper)
74 |
75 | # Stabilize the model before the actual simulation.
76 | for _ in range(200):
77 | physics.step()
78 |
79 | physics.data.time = 0
80 | self._timeout_progress = 0
81 | super(Cheetah, self).initialize_episode(physics)
82 |
83 | def get_observation(self, physics):
84 | """Returns an observation of the state, ignoring horizontal position."""
85 | obs = collections.OrderedDict()
86 | # Ignores horizontal position to maintain translational invariance.
87 | obs['position'] = physics.data.qpos[1:].copy()
88 | obs['velocity'] = physics.velocity()
89 | return obs
90 |
91 | def get_reward(self, physics):
92 | """Returns a reward to the agent."""
93 | return rewards.tolerance(physics.speed(),
94 | bounds=(_RUN_SPEED, float('inf')),
95 | margin=_RUN_SPEED,
96 | value_at_margin=0,
97 | sigmoid='linear')
98 |
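[Usage note, not part of the repository: a minimal sketch of loading the run task defined above and stepping it with zero actions, assuming the repository root is on PYTHONPATH and MuJoCo/dm_control are installed.]

    import numpy as np
    from local_dm_control_suite_off_center import cheetah

    env = cheetah.run()                # control.Environment built by run() above
    spec = env.action_spec()
    timestep = env.reset()
    while not timestep.last():
        # Zero torques, purely for illustration; a trained policy would go here.
        timestep = env.step(np.zeros(spec.shape))
    print('final reward:', timestep.reward)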
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/cheetah.xml:
--------------------------------------------------------------------------------
[XML content (74 lines) not preserved in this dump]
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/common/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Functions to manage the common assets for domains."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import os
23 | from dm_control.utils import io as resources
24 |
25 | _SUITE_DIR = os.path.dirname(os.path.dirname(__file__))
26 | _FILENAMES = [
27 | "./common/materials.xml",
28 | "./common/materials_white_floor.xml",
29 | "./common/skybox.xml",
30 | "./common/visual.xml",
31 | ]
32 |
33 | ASSETS = {filename: resources.GetResource(os.path.join(_SUITE_DIR, filename))
34 | for filename in _FILENAMES}
35 |
36 |
37 | def read_model(model_filename):
38 | """Reads a model XML file and returns its contents as a string."""
39 | return resources.GetResource(os.path.join(_SUITE_DIR, model_filename))
40 |
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/common/materials.xml:
--------------------------------------------------------------------------------
[XML content (24 lines) not preserved in this dump]
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/common/materials_white_floor.xml:
--------------------------------------------------------------------------------
[XML content (24 lines) not preserved in this dump]
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/common/skybox.xml:
--------------------------------------------------------------------------------
[XML content (7 lines) not preserved in this dump]
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/common/visual.xml:
--------------------------------------------------------------------------------
[XML content (8 lines) not preserved in this dump]
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/demos/mocap_demo.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Demonstration of amc parsing for CMU mocap database.
17 |
18 | To run the demo, supply a path to a `.amc` file:
19 |
20 |   python mocap_demo.py --filename='path/to/mocap.amc'
21 |
22 | CMU motion capture clips are available at mocap.cs.cmu.edu
23 | """
24 |
25 | from __future__ import absolute_import
26 | from __future__ import division
27 | from __future__ import print_function
28 |
29 | import time
30 | # Internal dependencies.
31 |
32 | from absl import app
33 | from absl import flags
34 |
35 | from local_dm_control_suite import humanoid_CMU
36 | from dm_control.suite.utils import parse_amc
37 |
38 | import matplotlib.pyplot as plt
39 | import numpy as np
40 |
41 | FLAGS = flags.FLAGS
42 | flags.DEFINE_string('filename', None, 'amc file to be converted.')
43 | flags.DEFINE_integer('max_num_frames', 90,
44 | 'Maximum number of frames for plotting/playback')
45 |
46 |
47 | def main(unused_argv):
48 | env = humanoid_CMU.stand()
49 |
50 | # Parse and convert specified clip.
51 | converted = parse_amc.convert(FLAGS.filename,
52 | env.physics, env.control_timestep())
53 |
54 | max_frame = min(FLAGS.max_num_frames, converted.qpos.shape[1] - 1)
55 |
56 | width = 480
57 | height = 480
58 | video = np.zeros((max_frame, height, 2 * width, 3), dtype=np.uint8)
59 |
60 | for i in range(max_frame):
61 | p_i = converted.qpos[:, i]
62 | with env.physics.reset_context():
63 | env.physics.data.qpos[:] = p_i
64 | video[i] = np.hstack([env.physics.render(height, width, camera_id=0),
65 | env.physics.render(height, width, camera_id=1)])
66 |
67 | tic = time.time()
68 | for i in range(max_frame):
69 | if i == 0:
70 | img = plt.imshow(video[i])
71 | else:
72 | img.set_data(video[i])
73 | toc = time.time()
74 | clock_dt = toc - tic
75 | tic = time.time()
76 | # Real-time playback not always possible as clock_dt > .03
77 | plt.pause(max(0.01, 0.03 - clock_dt)) # Need min display time > 0.0.
78 | plt.draw()
79 | plt.waitforbuttonpress()
80 |
81 |
82 | if __name__ == '__main__':
83 | flags.mark_flag_as_required('filename')
84 | app.run(main)
85 |
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/demos/zeros.amc:
--------------------------------------------------------------------------------
1 | #DUMMY AMC for testing
2 | :FULLY-SPECIFIED
3 | :DEGREES
4 | 1
5 | root 0 0 0 0 0 0
6 | lowerback 0 0 0
7 | upperback 0 0 0
8 | thorax 0 0 0
9 | lowerneck 0 0 0
10 | upperneck 0 0 0
11 | head 0 0 0
12 | rclavicle 0 0
13 | rhumerus 0 0 0
14 | rradius 0
15 | rwrist 0
16 | rhand 0 0
17 | rfingers 0
18 | rthumb 0 0
19 | lclavicle 0 0
20 | lhumerus 0 0 0
21 | lradius 0
22 | lwrist 0
23 | lhand 0 0
24 | lfingers 0
25 | lthumb 0 0
26 | rfemur 0 0 0
27 | rtibia 0
28 | rfoot 0 0
29 | rtoes 0
30 | lfemur 0 0 0
31 | ltibia 0
32 | lfoot 0 0
33 | ltoes 0
34 | 2
35 | root 0 0 0 0 0 0
36 | lowerback 0 0 0
37 | upperback 0 0 0
38 | thorax 0 0 0
39 | lowerneck 0 0 0
40 | upperneck 0 0 0
41 | head 0 0 0
42 | rclavicle 0 0
43 | rhumerus 0 0 0
44 | rradius 0
45 | rwrist 0
46 | rhand 0 0
47 | rfingers 0
48 | rthumb 0 0
49 | lclavicle 0 0
50 | lhumerus 0 0 0
51 | lradius 0
52 | lwrist 0
53 | lhand 0 0
54 | lfingers 0
55 | lthumb 0 0
56 | rfemur 0 0 0
57 | rtibia 0
58 | rfoot 0 0
59 | rtoes 0
60 | lfemur 0 0 0
61 | ltibia 0
62 | lfoot 0 0
63 | ltoes 0
64 | 3
65 | root 0 0 0 0 0 0
66 | lowerback 0 0 0
67 | upperback 0 0 0
68 | thorax 0 0 0
69 | lowerneck 0 0 0
70 | upperneck 0 0 0
71 | head 0 0 0
72 | rclavicle 0 0
73 | rhumerus 0 0 0
74 | rradius 0
75 | rwrist 0
76 | rhand 0 0
77 | rfingers 0
78 | rthumb 0 0
79 | lclavicle 0 0
80 | lhumerus 0 0 0
81 | lradius 0
82 | lwrist 0
83 | lhand 0 0
84 | lfingers 0
85 | lthumb 0 0
86 | rfemur 0 0 0
87 | rtibia 0
88 | rfoot 0 0
89 | rtoes 0
90 | lfemur 0 0 0
91 | ltibia 0
92 | lfoot 0 0
93 | ltoes 0
94 | 4
95 | root 0 0 0 0 0 0
96 | lowerback 0 0 0
97 | upperback 0 0 0
98 | thorax 0 0 0
99 | lowerneck 0 0 0
100 | upperneck 0 0 0
101 | head 0 0 0
102 | rclavicle 0 0
103 | rhumerus 0 0 0
104 | rradius 0
105 | rwrist 0
106 | rhand 0 0
107 | rfingers 0
108 | rthumb 0 0
109 | lclavicle 0 0
110 | lhumerus 0 0 0
111 | lradius 0
112 | lwrist 0
113 | lhand 0 0
114 | lfingers 0
115 | lthumb 0 0
116 | rfemur 0 0 0
117 | rtibia 0
118 | rfoot 0 0
119 | rtoes 0
120 | lfemur 0 0 0
121 | ltibia 0
122 | lfoot 0 0
123 | ltoes 0
124 | 5
125 | root 0 0 0 0 0 0
126 | lowerback 0 0 0
127 | upperback 0 0 0
128 | thorax 0 0 0
129 | lowerneck 0 0 0
130 | upperneck 0 0 0
131 | head 0 0 0
132 | rclavicle 0 0
133 | rhumerus 0 0 0
134 | rradius 0
135 | rwrist 0
136 | rhand 0 0
137 | rfingers 0
138 | rthumb 0 0
139 | lclavicle 0 0
140 | lhumerus 0 0 0
141 | lradius 0
142 | lwrist 0
143 | lhand 0 0
144 | lfingers 0
145 | lthumb 0 0
146 | rfemur 0 0 0
147 | rtibia 0
148 | rfoot 0 0
149 | rtoes 0
150 | lfemur 0 0 0
151 | ltibia 0
152 | lfoot 0 0
153 | ltoes 0
154 | 6
155 | root 0 0 0 0 0 0
156 | lowerback 0 0 0
157 | upperback 0 0 0
158 | thorax 0 0 0
159 | lowerneck 0 0 0
160 | upperneck 0 0 0
161 | head 0 0 0
162 | rclavicle 0 0
163 | rhumerus 0 0 0
164 | rradius 0
165 | rwrist 0
166 | rhand 0 0
167 | rfingers 0
168 | rthumb 0 0
169 | lclavicle 0 0
170 | lhumerus 0 0 0
171 | lradius 0
172 | lwrist 0
173 | lhand 0 0
174 | lfingers 0
175 | lthumb 0 0
176 | rfemur 0 0 0
177 | rtibia 0
178 | rfoot 0 0
179 | rtoes 0
180 | lfemur 0 0 0
181 | ltibia 0
182 | lfoot 0 0
183 | ltoes 0
184 | 7
185 | root 0 0 0 0 0 0
186 | lowerback 0 0 0
187 | upperback 0 0 0
188 | thorax 0 0 0
189 | lowerneck 0 0 0
190 | upperneck 0 0 0
191 | head 0 0 0
192 | rclavicle 0 0
193 | rhumerus 0 0 0
194 | rradius 0
195 | rwrist 0
196 | rhand 0 0
197 | rfingers 0
198 | rthumb 0 0
199 | lclavicle 0 0
200 | lhumerus 0 0 0
201 | lradius 0
202 | lwrist 0
203 | lhand 0 0
204 | lfingers 0
205 | lthumb 0 0
206 | rfemur 0 0 0
207 | rtibia 0
208 | rfoot 0 0
209 | rtoes 0
210 | lfemur 0 0 0
211 | ltibia 0
212 | lfoot 0 0
213 | ltoes 0
214 |
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/explore.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Control suite environments explorer."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from absl import app
22 | from absl import flags
23 | from dm_control import suite
24 | from dm_control.suite.wrappers import action_noise
25 | from six.moves import input
26 |
27 | from dm_control import viewer
28 |
29 |
30 | _ALL_NAMES = ['.'.join(domain_task) for domain_task in suite.ALL_TASKS]
31 |
32 | flags.DEFINE_enum('environment_name', None, _ALL_NAMES,
33 | 'Optional \'domain_name.task_name\' pair specifying the '
34 | 'environment to load. If unspecified a prompt will appear to '
35 | 'select one.')
36 | flags.DEFINE_bool('timeout', True, 'Whether episodes should have a time limit.')
37 | flags.DEFINE_bool('visualize_reward', True,
38 | 'Whether to vary the colors of geoms according to the '
39 | 'current reward value.')
40 | flags.DEFINE_float('action_noise', 0.,
41 | 'Standard deviation of Gaussian noise to apply to actions, '
42 | 'expressed as a fraction of the max-min range for each '
43 | 'action dimension. Defaults to 0, i.e. no noise.')
44 | FLAGS = flags.FLAGS
45 |
46 |
47 | def prompt_environment_name(prompt, values):
48 | environment_name = None
49 | while not environment_name:
50 | environment_name = input(prompt)
51 |     if not environment_name or environment_name not in values:
52 | print('"%s" is not a valid environment name.' % environment_name)
53 | environment_name = None
54 | return environment_name
55 |
56 |
57 | def main(argv):
58 | del argv
59 | environment_name = FLAGS.environment_name
60 | if environment_name is None:
61 | print('\n '.join(['Available environments:'] + _ALL_NAMES))
62 | environment_name = prompt_environment_name(
63 | 'Please select an environment name: ', _ALL_NAMES)
64 |
65 | index = _ALL_NAMES.index(environment_name)
66 | domain_name, task_name = suite.ALL_TASKS[index]
67 |
68 | task_kwargs = {}
69 | if not FLAGS.timeout:
70 | task_kwargs['time_limit'] = float('inf')
71 |
72 | def loader():
73 | env = suite.load(
74 | domain_name=domain_name, task_name=task_name, task_kwargs=task_kwargs)
75 | env.task.visualize_reward = FLAGS.visualize_reward
76 | if FLAGS.action_noise > 0:
77 | env = action_noise.Wrapper(env, scale=FLAGS.action_noise)
78 | return env
79 |
80 | viewer.launch(loader)
81 |
82 |
83 | if __name__ == '__main__':
84 | app.run(main)
85 |
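[Usage note, not part of the repository: the programmatic equivalent of this explorer for a single task; the cartpole.swingup choice and the noise scale are illustrative.]

    from dm_control import suite, viewer
    from dm_control.suite.wrappers import action_noise

    def loader():
        env = suite.load(domain_name='cartpole', task_name='swingup')
        return action_noise.Wrapper(env, scale=0.1)

    viewer.launch(loader)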
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/finger.xml:
--------------------------------------------------------------------------------
[XML content (73 lines) not preserved in this dump]
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/fish.xml:
--------------------------------------------------------------------------------
[XML content (86 lines) not preserved in this dump]
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/hopper.xml:
--------------------------------------------------------------------------------
[XML content (67 lines) not preserved in this dump]
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/lqr.xml:
--------------------------------------------------------------------------------
[XML content (27 lines) not preserved in this dump]
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/pendulum.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Pendulum domain."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import collections
23 |
24 | from dm_control import mujoco
25 | from dm_control.rl import control
26 | from local_dm_control_suite import base
27 | from local_dm_control_suite import common
28 | from dm_control.utils import containers
29 | from dm_control.utils import rewards
30 | import numpy as np
31 |
32 |
33 | _DEFAULT_TIME_LIMIT = 20
34 | _ANGLE_BOUND = 8
35 | _COSINE_BOUND = np.cos(np.deg2rad(_ANGLE_BOUND))
36 | SUITE = containers.TaggedTasks()
37 |
38 |
39 | def get_model_and_assets():
40 | """Returns a tuple containing the model XML string and a dict of assets."""
41 | return common.read_model('pendulum.xml'), common.ASSETS
42 |
43 |
44 | @SUITE.add('benchmarking')
45 | def swingup(time_limit=_DEFAULT_TIME_LIMIT, random=None,
46 | environment_kwargs=None):
47 |   """Returns pendulum swingup task."""
48 | physics = Physics.from_xml_string(*get_model_and_assets())
49 | task = SwingUp(random=random)
50 | environment_kwargs = environment_kwargs or {}
51 | return control.Environment(
52 | physics, task, time_limit=time_limit, **environment_kwargs)
53 |
54 |
55 | class Physics(mujoco.Physics):
56 | """Physics simulation with additional features for the Pendulum domain."""
57 |
58 | def pole_vertical(self):
59 | """Returns vertical (z) component of pole frame."""
60 | return self.named.data.xmat['pole', 'zz']
61 |
62 | def angular_velocity(self):
63 | """Returns the angular velocity of the pole."""
64 | return self.named.data.qvel['hinge'].copy()
65 |
66 | def pole_orientation(self):
67 | """Returns both horizontal and vertical components of pole frame."""
68 | return self.named.data.xmat['pole', ['zz', 'xz']]
69 |
70 |
71 | class SwingUp(base.Task):
72 | """A Pendulum `Task` to swing up and balance the pole."""
73 |
74 | def __init__(self, random=None):
75 | """Initialize an instance of `Pendulum`.
76 |
77 | Args:
78 | random: Optional, either a `numpy.random.RandomState` instance, an
79 | integer seed for creating a new `RandomState`, or None to select a seed
80 | automatically (default).
81 | """
82 | super(SwingUp, self).__init__(random=random)
83 |
84 | def initialize_episode(self, physics):
85 | """Sets the state of the environment at the start of each episode.
86 |
87 | Pole is set to a random angle between [-pi, pi).
88 |
89 | Args:
90 | physics: An instance of `Physics`.
91 |
92 | """
93 | physics.named.data.qpos['hinge'] = self.random.uniform(-np.pi, np.pi)
94 | super(SwingUp, self).initialize_episode(physics)
95 |
96 | def get_observation(self, physics):
97 | """Returns an observation.
98 |
99 |     Observations are states concatenating the pole orientation and the
100 |     angular velocity.
101 |
102 | Args:
103 | physics: An instance of `physics`, Pendulum physics.
104 |
105 | Returns:
106 | A `dict` of observation.
107 | """
108 | obs = collections.OrderedDict()
109 | obs['orientation'] = physics.pole_orientation()
110 | obs['velocity'] = physics.angular_velocity()
111 | return obs
112 |
113 | def get_reward(self, physics):
114 | return rewards.tolerance(physics.pole_vertical(), (_COSINE_BOUND, 1))
115 |
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/pendulum.xml:
--------------------------------------------------------------------------------
[XML content (27 lines) not preserved in this dump]
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/point_mass.xml:
--------------------------------------------------------------------------------
[XML content (50 lines) not preserved in this dump]
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/reacher.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Reacher domain."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import collections
23 |
24 | from dm_control import mujoco
25 | from dm_control.rl import control
26 | from local_dm_control_suite import base
27 | from local_dm_control_suite import common
28 | from dm_control.suite.utils import randomizers
29 | from dm_control.utils import containers
30 | from dm_control.utils import rewards
31 | import numpy as np
32 |
33 | SUITE = containers.TaggedTasks()
34 | _DEFAULT_TIME_LIMIT = 20
35 | _BIG_TARGET = .05
36 | _SMALL_TARGET = .015
37 |
38 |
39 | def get_model_and_assets():
40 | """Returns a tuple containing the model XML string and a dict of assets."""
41 | return common.read_model('reacher.xml'), common.ASSETS
42 |
43 |
44 | @SUITE.add('benchmarking', 'easy')
45 | def easy(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
46 | """Returns reacher with sparse reward with 5e-2 tol and randomized target."""
47 | physics = Physics.from_xml_string(*get_model_and_assets())
48 | task = Reacher(target_size=_BIG_TARGET, random=random)
49 | environment_kwargs = environment_kwargs or {}
50 | return control.Environment(
51 | physics, task, time_limit=time_limit, **environment_kwargs)
52 |
53 |
54 | @SUITE.add('benchmarking')
55 | def hard(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
56 | """Returns reacher with sparse reward with 1e-2 tol and randomized target."""
57 | physics = Physics.from_xml_string(*get_model_and_assets())
58 | task = Reacher(target_size=_SMALL_TARGET, random=random)
59 | environment_kwargs = environment_kwargs or {}
60 | return control.Environment(
61 | physics, task, time_limit=time_limit, **environment_kwargs)
62 |
63 |
64 | class Physics(mujoco.Physics):
65 | """Physics simulation with additional features for the Reacher domain."""
66 |
67 | def finger_to_target(self):
68 |     """Returns the vector from the finger to the target in global coordinates."""
69 | return (self.named.data.geom_xpos['target', :2] -
70 | self.named.data.geom_xpos['finger', :2])
71 |
72 | def finger_to_target_dist(self):
73 |     """Returns the distance between the finger and the target."""
74 | return np.linalg.norm(self.finger_to_target())
75 |
76 |
77 | class Reacher(base.Task):
78 | """A reacher `Task` to reach the target."""
79 |
80 | def __init__(self, target_size, random=None):
81 | """Initialize an instance of `Reacher`.
82 |
83 | Args:
84 | target_size: A `float`, tolerance to determine whether finger reached the
85 | target.
86 | random: Optional, either a `numpy.random.RandomState` instance, an
87 | integer seed for creating a new `RandomState`, or None to select a seed
88 | automatically (default).
89 | """
90 | self._target_size = target_size
91 | super(Reacher, self).__init__(random=random)
92 |
93 | def initialize_episode(self, physics):
94 | """Sets the state of the environment at the start of each episode."""
95 | physics.named.model.geom_size['target', 0] = self._target_size
96 | randomizers.randomize_limited_and_rotational_joints(physics, self.random)
97 |
98 | # Randomize target position
99 | angle = self.random.uniform(0, 2 * np.pi)
100 | radius = self.random.uniform(.05, .20)
101 | physics.named.model.geom_pos['target', 'x'] = radius * np.sin(angle)
102 | physics.named.model.geom_pos['target', 'y'] = radius * np.cos(angle)
103 |
104 | super(Reacher, self).initialize_episode(physics)
105 |
106 | def get_observation(self, physics):
107 | """Returns an observation of the state and the target position."""
108 | obs = collections.OrderedDict()
109 | obs['position'] = physics.position()
110 | obs['to_target'] = physics.finger_to_target()
111 | obs['velocity'] = physics.velocity()
112 | return obs
113 |
114 | def get_reward(self, physics):
115 | radii = physics.named.model.geom_size[['target', 'finger'], 0].sum()
116 | return rewards.tolerance(physics.finger_to_target_dist(), (0, radii))
117 |
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/reacher.xml:
--------------------------------------------------------------------------------
[XML content (48 lines) not preserved in this dump]
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/swimmer.xml:
--------------------------------------------------------------------------------
[XML content (58 lines) not preserved in this dump]
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/tests/loader_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Tests for the dm_control.suite loader."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | # Internal dependencies.
23 |
24 | from absl.testing import absltest
25 |
26 | from dm_control import suite
27 | from dm_control.rl import control
28 |
29 |
30 | class LoaderTest(absltest.TestCase):
31 |
32 | def test_load_without_kwargs(self):
33 | env = suite.load('cartpole', 'swingup')
34 | self.assertIsInstance(env, control.Environment)
35 |
36 | def test_load_with_kwargs(self):
37 | env = suite.load('cartpole', 'swingup',
38 | task_kwargs={'time_limit': 40, 'random': 99})
39 | self.assertIsInstance(env, control.Environment)
40 |
41 |
42 | class LoaderConstantsTest(absltest.TestCase):
43 |
44 | def testSuiteConstants(self):
45 | self.assertNotEmpty(suite.BENCHMARKING)
46 | self.assertNotEmpty(suite.EASY)
47 | self.assertNotEmpty(suite.HARD)
48 | self.assertNotEmpty(suite.EXTRA)
49 |
50 |
51 | if __name__ == '__main__':
52 | absltest.main()
53 |
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/tests/lqr_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Tests specific to the LQR domain."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import math
23 | import unittest
24 |
25 | # Internal dependencies.
26 | from absl import logging
27 |
28 | from absl.testing import absltest
29 | from absl.testing import parameterized
30 |
31 | from local_dm_control_suite import lqr
32 | from local_dm_control_suite import lqr_solver
33 |
34 | import numpy as np
35 | from six.moves import range
36 |
37 |
38 | class LqrTest(parameterized.TestCase):
39 |
40 | @parameterized.named_parameters(
41 | ('lqr_2_1', lqr.lqr_2_1),
42 | ('lqr_6_2', lqr.lqr_6_2))
43 | def test_lqr_optimal_policy(self, make_env):
44 | env = make_env()
45 | p, k, beta = lqr_solver.solve(env)
46 | self.assertPolicyisOptimal(env, p, k, beta)
47 |
48 | @parameterized.named_parameters(
49 | ('lqr_2_1', lqr.lqr_2_1),
50 | ('lqr_6_2', lqr.lqr_6_2))
51 | @unittest.skipUnless(
52 | condition=lqr_solver.sp,
53 | reason='scipy is not available, so non-scipy DARE solver is the default.')
54 | def test_lqr_optimal_policy_no_scipy(self, make_env):
55 | env = make_env()
56 | old_sp = lqr_solver.sp
57 | try:
58 | lqr_solver.sp = None # Force the solver to use the non-scipy code path.
59 | p, k, beta = lqr_solver.solve(env)
60 | finally:
61 | lqr_solver.sp = old_sp
62 | self.assertPolicyisOptimal(env, p, k, beta)
63 |
64 | def assertPolicyisOptimal(self, env, p, k, beta):
65 | tolerance = 1e-3
66 | n_steps = int(math.ceil(math.log10(tolerance) / math.log10(beta)))
67 | logging.info('%d timesteps for %g convergence.', n_steps, tolerance)
68 | total_loss = 0.0
69 |
70 | timestep = env.reset()
71 | initial_state = np.hstack((timestep.observation['position'],
72 | timestep.observation['velocity']))
73 | logging.info('Measuring total cost over %d steps.', n_steps)
74 | for _ in range(n_steps):
75 | x = np.hstack((timestep.observation['position'],
76 | timestep.observation['velocity']))
77 | # u = k*x is the optimal policy
78 | u = k.dot(x)
79 | total_loss += 1 - (timestep.reward or 0.0)
80 | timestep = env.step(u)
81 |
82 | logging.info('Analytical expected total cost is .5*x^T*p*x.')
83 | expected_loss = .5 * initial_state.T.dot(p).dot(initial_state)
84 | logging.info('Comparing measured and predicted costs.')
85 | np.testing.assert_allclose(expected_loss, total_loss, rtol=tolerance)
86 |
87 | if __name__ == '__main__':
88 | absltest.main()
89 |
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Utility functions used in the control suite."""
17 |
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/utils/parse_amc_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Tests for parse_amc utility."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import os
23 |
24 | # Internal dependencies.
25 |
26 | from absl.testing import absltest
27 | from local_dm_control_suite import humanoid_CMU
28 | from dm_control.suite.utils import parse_amc
29 |
30 | from dm_control.utils import io as resources
31 |
32 | _TEST_AMC_PATH = resources.GetResourceFilename(
33 | os.path.join(os.path.dirname(__file__), '../demos/zeros.amc'))
34 |
35 |
36 | class ParseAMCTest(absltest.TestCase):
37 |
38 | def test_sizes_of_parsed_data(self):
39 |
40 | # Instantiate the humanoid environment.
41 | env = humanoid_CMU.stand()
42 |
43 | # Parse and convert specified clip.
44 | converted = parse_amc.convert(
45 | _TEST_AMC_PATH, env.physics, env.control_timestep())
46 |
47 | self.assertEqual(converted.qpos.shape[0], 63)
48 | self.assertEqual(converted.qvel.shape[0], 62)
49 | self.assertEqual(converted.time.shape[0], converted.qpos.shape[1])
50 | self.assertEqual(converted.qpos.shape[1],
51 | converted.qvel.shape[1] + 1)
52 |
53 | # Parse and convert specified clip -- WITH SMALLER TIMESTEP
54 | converted2 = parse_amc.convert(
55 | _TEST_AMC_PATH, env.physics, 0.5 * env.control_timestep())
56 |
57 | self.assertEqual(converted2.qpos.shape[0], 63)
58 | self.assertEqual(converted2.qvel.shape[0], 62)
59 | self.assertEqual(converted2.time.shape[0], converted2.qpos.shape[1])
60 | self.assertEqual(converted.qpos.shape[1],
61 | converted.qvel.shape[1] + 1)
62 |
63 | # Compare sizes of parsed objects for different timesteps
64 | self.assertEqual(converted.qpos.shape[1] * 2, converted2.qpos.shape[1])
65 |
66 |
67 | if __name__ == '__main__':
68 | absltest.main()
69 |
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/utils/randomizers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Randomization functions."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from dm_control.mujoco.wrapper import mjbindings
23 | import numpy as np
24 | from six.moves import range
25 |
26 |
27 | def random_limited_quaternion(random, limit):
28 | """Generates a random quaternion limited to the specified rotations."""
29 | axis = random.randn(3)
30 | axis /= np.linalg.norm(axis)
31 | angle = random.rand() * limit
32 |
33 | quaternion = np.zeros(4)
34 | mjbindings.mjlib.mju_axisAngle2Quat(quaternion, axis, angle)
35 |
36 | return quaternion
37 |
38 |
39 | def randomize_limited_and_rotational_joints(physics, random=None):
40 | """Randomizes the positions of joints defined in the physics body.
41 |
42 | The following randomization rules apply:
43 | - Bounded joints (hinges or sliders) are sampled uniformly in the bounds.
44 |     - Unbounded hinges are sampled uniformly in [-pi, pi].
45 | - Quaternions for unlimited free joints and ball joints are sampled
46 | uniformly on the unit 3-sphere.
47 | - Quaternions for limited ball joints are sampled uniformly on a sector
48 | of the unit 3-sphere.
49 | - The linear degrees of freedom of free joints are not randomized.
50 |
51 | Args:
52 | physics: Instance of 'Physics' class that holds a loaded model.
53 | random: Optional instance of 'np.random.RandomState'. Defaults to the global
54 | NumPy random state.
55 | """
56 | random = random or np.random
57 |
58 | hinge = mjbindings.enums.mjtJoint.mjJNT_HINGE
59 | slide = mjbindings.enums.mjtJoint.mjJNT_SLIDE
60 | ball = mjbindings.enums.mjtJoint.mjJNT_BALL
61 | free = mjbindings.enums.mjtJoint.mjJNT_FREE
62 |
63 | qpos = physics.named.data.qpos
64 |
65 | for joint_id in range(physics.model.njnt):
66 | joint_name = physics.model.id2name(joint_id, 'joint')
67 | joint_type = physics.model.jnt_type[joint_id]
68 | is_limited = physics.model.jnt_limited[joint_id]
69 | range_min, range_max = physics.model.jnt_range[joint_id]
70 |
71 | if is_limited:
72 | if joint_type == hinge or joint_type == slide:
73 | qpos[joint_name] = random.uniform(range_min, range_max)
74 |
75 | elif joint_type == ball:
76 | qpos[joint_name] = random_limited_quaternion(random, range_max)
77 |
78 | else:
79 | if joint_type == hinge:
80 | qpos[joint_name] = random.uniform(-np.pi, np.pi)
81 |
82 | elif joint_type == ball:
83 | quat = random.randn(4)
84 | quat /= np.linalg.norm(quat)
85 | qpos[joint_name] = quat
86 |
87 | elif joint_type == free:
88 | quat = random.rand(4)
89 | quat /= np.linalg.norm(quat)
90 | qpos[joint_name][3:] = quat
91 |
92 |
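[Usage note, not part of the repository: a minimal sketch of applying the randomizer above to a freshly loaded model; the reacher domain and the fixed seed are illustrative.]

    import numpy as np
    from local_dm_control_suite_off_center import reacher
    from local_dm_control_suite_off_center.utils import randomizers

    physics = reacher.Physics.from_xml_string(*reacher.get_model_and_assets())
    rng = np.random.RandomState(0)
    with physics.reset_context():
        # Limited hinges/sliders are drawn inside their bounds, unlimited
        # hinges uniformly in [-pi, pi].
        randomizers.randomize_limited_and_rotational_joints(physics, rng)
    print(physics.data.qpos)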
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/walker.xml:
--------------------------------------------------------------------------------
[XML content (72 lines) not preserved in this dump]
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/wrappers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Environment wrappers used to extend or modify environment behaviour."""
17 |
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/wrappers/action_noise.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Wrapper for control suite environments that adds Gaussian noise to actions."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import dm_env
23 | import numpy as np
24 |
25 |
26 | _BOUNDS_MUST_BE_FINITE = (
27 | 'All bounds in `env.action_spec()` must be finite, got: {action_spec}')
28 |
29 |
30 | class Wrapper(dm_env.Environment):
31 | """Wraps a control environment and adds Gaussian noise to actions."""
32 |
33 | def __init__(self, env, scale=0.01):
34 | """Initializes a new action noise Wrapper.
35 |
36 | Args:
37 | env: The control suite environment to wrap.
38 | scale: The standard deviation of the noise, expressed as a fraction
39 | of the max-min range for each action dimension.
40 |
41 | Raises:
42 | ValueError: If any of the action dimensions of the wrapped environment are
43 | unbounded.
44 | """
45 | action_spec = env.action_spec()
46 | if not (np.all(np.isfinite(action_spec.minimum)) and
47 | np.all(np.isfinite(action_spec.maximum))):
48 | raise ValueError(_BOUNDS_MUST_BE_FINITE.format(action_spec=action_spec))
49 | self._minimum = action_spec.minimum
50 | self._maximum = action_spec.maximum
51 | self._noise_std = scale * (action_spec.maximum - action_spec.minimum)
52 | self._env = env
53 |
54 | def step(self, action):
55 | noisy_action = action + self._env.task.random.normal(scale=self._noise_std)
56 | # Clip the noisy actions in place so that they fall within the bounds
57 | # specified by the `action_spec`. Note that MuJoCo implicitly clips out-of-
58 | # bounds control inputs, but we also clip here in case the actions do not
59 | # correspond directly to MuJoCo actuators, or if there are other wrapper
60 | # layers that expect the actions to be within bounds.
61 | np.clip(noisy_action, self._minimum, self._maximum, out=noisy_action)
62 | return self._env.step(noisy_action)
63 |
64 | def reset(self):
65 | return self._env.reset()
66 |
67 | def observation_spec(self):
68 | return self._env.observation_spec()
69 |
70 | def action_spec(self):
71 | return self._env.action_spec()
72 |
73 | def __getattr__(self, name):
74 | return getattr(self._env, name)
75 |
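[Usage note, not part of the repository: a minimal sketch of wrapping a suite task with the noise wrapper above, assuming the local cartpole module keeps the upstream swingup task; the 1% noise scale is illustrative.]

    import numpy as np
    from local_dm_control_suite_off_center import cartpole
    from local_dm_control_suite_off_center.wrappers import action_noise

    env = action_noise.Wrapper(cartpole.swingup(), scale=0.01)
    timestep = env.reset()
    # Gaussian noise (1% of the action range) is added and clipped before stepping.
    timestep = env.step(np.zeros(env.action_spec().shape))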
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/wrappers/pixels.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Wrapper that adds pixel observations to a control environment."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import collections
23 |
24 | import dm_env
25 | from dm_env import specs
26 |
27 | STATE_KEY = 'state'
28 |
29 |
30 | class Wrapper(dm_env.Environment):
31 | """Wraps a control environment and adds a rendered pixel observation."""
32 |
33 | def __init__(self, env, pixels_only=True, render_kwargs=None,
34 | observation_key='pixels'):
35 | """Initializes a new pixel Wrapper.
36 |
37 | Args:
38 | env: The environment to wrap.
39 | pixels_only: If True (default), the original set of 'state' observations
40 | returned by the wrapped environment will be discarded, and the
41 | `OrderedDict` of observations will only contain pixels. If False, the
42 | `OrderedDict` will contain the original observations as well as the
43 | pixel observations.
44 | render_kwargs: Optional `dict` containing keyword arguments passed to the
45 | `mujoco.Physics.render` method.
46 | observation_key: Optional custom string specifying the pixel observation's
47 | key in the `OrderedDict` of observations. Defaults to 'pixels'.
48 |
49 | Raises:
50 | ValueError: If `env`'s observation spec is not compatible with the
51 | wrapper. Supported formats are a single array, or a dict of arrays.
52 | ValueError: If `env`'s observation already contains the specified
53 | `observation_key`.
54 | """
55 | if render_kwargs is None:
56 | render_kwargs = {}
57 |
58 | wrapped_observation_spec = env.observation_spec()
59 |
60 | if isinstance(wrapped_observation_spec, specs.Array):
61 | self._observation_is_dict = False
62 | invalid_keys = set([STATE_KEY])
63 | elif isinstance(wrapped_observation_spec, collections.MutableMapping):
64 | self._observation_is_dict = True
65 | invalid_keys = set(wrapped_observation_spec.keys())
66 | else:
67 | raise ValueError('Unsupported observation spec structure.')
68 |
69 | if not pixels_only and observation_key in invalid_keys:
70 | raise ValueError('Duplicate or reserved observation key {!r}.'
71 | .format(observation_key))
72 |
73 | if pixels_only:
74 | self._observation_spec = collections.OrderedDict()
75 | elif self._observation_is_dict:
76 | self._observation_spec = wrapped_observation_spec.copy()
77 | else:
78 | self._observation_spec = collections.OrderedDict()
79 | self._observation_spec[STATE_KEY] = wrapped_observation_spec
80 |
81 | # Extend observation spec.
82 | pixels = env.physics.render(**render_kwargs)
83 | pixels_spec = specs.Array(
84 | shape=pixels.shape, dtype=pixels.dtype, name=observation_key)
85 | self._observation_spec[observation_key] = pixels_spec
86 |
87 | self._env = env
88 | self._pixels_only = pixels_only
89 | self._render_kwargs = render_kwargs
90 | self._observation_key = observation_key
91 |
92 | def reset(self):
93 | time_step = self._env.reset()
94 | return self._add_pixel_observation(time_step)
95 |
96 | def step(self, action):
97 | time_step = self._env.step(action)
98 | return self._add_pixel_observation(time_step)
99 |
100 | def observation_spec(self):
101 | return self._observation_spec
102 |
103 | def action_spec(self):
104 | return self._env.action_spec()
105 |
106 | def _add_pixel_observation(self, time_step):
107 | if self._pixels_only:
108 | observation = collections.OrderedDict()
109 | elif self._observation_is_dict:
110 | observation = type(time_step.observation)(time_step.observation)
111 | else:
112 | observation = collections.OrderedDict()
113 | observation[STATE_KEY] = time_step.observation
114 |
115 | pixels = self._env.physics.render(**self._render_kwargs)
116 | observation[self._observation_key] = pixels
117 | return time_step._replace(observation=observation)
118 |
119 | def __getattr__(self, name):
120 | return getattr(self._env, name)
121 |
--------------------------------------------------------------------------------
/local_dm_control_suite_off_center/wrappers/pixels_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Tests for the pixel wrapper."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import collections
23 |
24 | # Internal dependencies.
25 | from absl.testing import absltest
26 | from absl.testing import parameterized
27 | from local_dm_control_suite import cartpole
28 | from dm_control.suite.wrappers import pixels
29 | import dm_env
30 | from dm_env import specs
31 |
32 | import numpy as np
33 |
34 |
35 | class FakePhysics(object):
36 |
37 | def render(self, *args, **kwargs):
38 | del args
39 | del kwargs
40 | return np.zeros((4, 5, 3), dtype=np.uint8)
41 |
42 |
43 | class FakeArrayObservationEnvironment(dm_env.Environment):
44 |
45 | def __init__(self):
46 | self.physics = FakePhysics()
47 |
48 | def reset(self):
49 | return dm_env.restart(np.zeros((2,)))
50 |
51 | def step(self, action):
52 | del action
53 | return dm_env.transition(0.0, np.zeros((2,)))
54 |
55 | def action_spec(self):
56 | pass
57 |
58 | def observation_spec(self):
59 | return specs.Array(shape=(2,), dtype=np.float)
60 |
61 |
62 | class PixelsTest(parameterized.TestCase):
63 |
64 | @parameterized.parameters(True, False)
65 | def test_dict_observation(self, pixels_only):
66 | pixel_key = 'rgb'
67 |
68 | env = cartpole.swingup()
69 |
70 | # Make sure we are testing the right environment for the test.
71 | observation_spec = env.observation_spec()
72 | self.assertIsInstance(observation_spec, collections.OrderedDict)
73 |
74 | width = 320
75 | height = 240
76 |
77 | # The wrapper should only add one observation.
78 | wrapped = pixels.Wrapper(env,
79 | observation_key=pixel_key,
80 | pixels_only=pixels_only,
81 | render_kwargs={'width': width, 'height': height})
82 |
83 | wrapped_observation_spec = wrapped.observation_spec()
84 | self.assertIsInstance(wrapped_observation_spec, collections.OrderedDict)
85 |
86 | if pixels_only:
87 | self.assertLen(wrapped_observation_spec, 1)
88 | self.assertEqual([pixel_key], list(wrapped_observation_spec.keys()))
89 | else:
90 | expected_length = len(observation_spec) + 1
91 | self.assertLen(wrapped_observation_spec, expected_length)
92 | expected_keys = list(observation_spec.keys()) + [pixel_key]
93 | self.assertEqual(expected_keys, list(wrapped_observation_spec.keys()))
94 |
95 | # Check that the added spec item is consistent with the added observation.
96 | time_step = wrapped.reset()
97 | rgb_observation = time_step.observation[pixel_key]
98 | wrapped_observation_spec[pixel_key].validate(rgb_observation)
99 |
100 | self.assertEqual(rgb_observation.shape, (height, width, 3))
101 | self.assertEqual(rgb_observation.dtype, np.uint8)
102 |
103 | @parameterized.parameters(True, False)
104 | def test_single_array_observation(self, pixels_only):
105 | pixel_key = 'depth'
106 |
107 | env = FakeArrayObservationEnvironment()
108 | observation_spec = env.observation_spec()
109 | self.assertIsInstance(observation_spec, specs.Array)
110 |
111 | wrapped = pixels.Wrapper(env, observation_key=pixel_key,
112 | pixels_only=pixels_only)
113 | wrapped_observation_spec = wrapped.observation_spec()
114 | self.assertIsInstance(wrapped_observation_spec, collections.OrderedDict)
115 |
116 | if pixels_only:
117 | self.assertLen(wrapped_observation_spec, 1)
118 | self.assertEqual([pixel_key], list(wrapped_observation_spec.keys()))
119 | else:
120 | self.assertLen(wrapped_observation_spec, 2)
121 | self.assertEqual([pixels.STATE_KEY, pixel_key],
122 | list(wrapped_observation_spec.keys()))
123 |
124 | time_step = wrapped.reset()
125 |
126 | depth_observation = time_step.observation[pixel_key]
127 | wrapped_observation_spec[pixel_key].validate(depth_observation)
128 |
129 | self.assertEqual(depth_observation.shape, (4, 5, 3))
130 | self.assertEqual(depth_observation.dtype, np.uint8)
131 |
132 | if __name__ == '__main__':
133 | absltest.main()
134 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.12.0
2 | cachetools==4.2.2
3 | certifi==2021.5.30
4 | cffi==1.15.0
5 | chardet==4.0.0
6 | cloudpickle==1.6.0
7 | cycler==0.10.0
8 | Cython==0.29.24
9 | dataclasses==0.8
10 | decorator==4.4.2
11 | fasteners==0.16.3
12 | flatbuffers==1.12
13 | future==0.18.2
14 | gast==0.2.2
15 | gin-config==0.4.0
16 | glfw==2.1.0
17 | google-auth==1.28.1
18 | google-auth-oauthlib==0.4.4
19 | google-pasta==0.2.0
20 | grpcio==1.32.0
21 | gym==0.18.0
22 | h5py==2.10.0
23 | idna==2.10
24 | imageio==2.9.0
25 | imageio-ffmpeg==0.4.2
26 | importlib-metadata==3.10.0
27 | kiwisolver==1.3.1
28 | labmaze==1.0.5
29 | lockfile==0.12.2
30 | Markdown==3.3.4
31 | matplotlib==3.3.4
32 | mujoco-py==2.0.2.13
33 | networkx==2.5.1
34 | numpy==1.19.1
35 | oauthlib==3.1.0
36 | opencv-python==4.5.1.48
37 | opt-einsum==3.3.0
38 | pandas==1.1.5
39 | Pillow>=8.3.2
40 | Pillow-SIMD==7.0.0.post3
41 | protobuf==3.17.1
42 | pyasn1==0.4.8
43 | pyasn1-modules==0.2.8
44 | pybullet==3.2.0
45 | pycparser==2.20
46 | pyglet==1.5.0
47 | PyOpenGL==3.1.5
48 | pyparsing==2.4.7
49 | python-dateutil==2.8.1
50 | pytz==2021.1
51 | PyWavelets==1.1.1
52 | PyYAML==6.0
53 | requests==2.25.1
54 | requests-oauthlib==1.3.0
55 | rsa==4.7.2
56 | scikit-image==0.17.2
57 | scipy==1.5.4
58 | seaborn==0.11.1
59 | six==1.16.0
60 | sortedcontainers==2.4.0
61 | tensorboard==2.0.2
62 | tensorboard-data-server==0.6.0
63 | tensorboard-plugin-wit==1.8.0
64 | termcolor==1.1.0
65 | tifffile==2020.9.3
66 | torch==1.8.1
67 | tqdm==4.60.0
68 | typing-extensions==3.7.4.3
69 | urllib3>=1.26.5
70 | Werkzeug==1.0.1
71 | wrapt==1.12.1
72 | zipp==3.4.1
73 |
--------------------------------------------------------------------------------
/scripts/run.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python train.py \
2 | --domain_name $3 \
3 | --task_name $4 --case $1 \
4 | --encoder_type pixel --work_dir ./tmp/test \
5 | --action_repeat 8 --num_eval_episodes 10 \
6 | --pre_transform_image_size 100 --image_size 84 --replay_buffer_capacity 100 \
7 | --frame_stack 3 --data_augs no_aug \
8 | --seed 23 --critic_lr 1e-3 --actor_lr 1e-3 --eval_freq 10000 \
9 | --batch_size 16 --num_train_steps 200000 --metric_loss \
10 | --init_steps 1000 \
11 | --resource_files './distractors/driving/*.mp4' --img_source 'video' --total_frames 50 \
12 | --horizon $2 --save_model
13 |
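[Usage note, not part of the repository: the four positional arguments map to $1 -> --case, $2 -> --horizon, $3 -> --domain_name and $4 -> --task_name, so an invocation looks like the line below; the case name and horizon value are illustrative.]

    bash scripts/run.sh sac 3 cheetah run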
--------------------------------------------------------------------------------
/video.py:
--------------------------------------------------------------------------------
1 | import imageio
2 | import os
3 | import numpy as np
4 |
5 |
6 | class VideoRecorder(object):
7 | def __init__(self, dir_name, height=256, width=256, camera_id=0, fps=30):
8 | self.dir_name = dir_name
9 | self.height = height
10 | self.width = width
11 | self.camera_id = camera_id
12 | self.fps = fps
13 | self.frames = []
14 |
15 | def init(self, enabled=True):
16 | self.frames = []
17 | self.enabled = self.dir_name is not None and enabled
18 |
19 | def record(self, env):
20 | if self.enabled:
21 | try:
22 | frame = env.render(
23 | mode='rgb_array',
24 | height=self.height,
25 | width=self.width,
26 | camera_id=self.camera_id
27 | )
28 |             except Exception:  # some env.render() signatures lack size/camera kwargs
29 | frame = env.render(
30 | mode='rgb_array',
31 | )
32 |
33 | self.frames.append(frame)
34 |
35 | def save(self, file_name):
36 | if self.enabled:
37 | path = os.path.join(self.dir_name, file_name)
38 | imageio.mimsave(path, self.frames, fps=self.fps)
39 |
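[Usage note, not part of the repository: a minimal sketch of recording one evaluation episode with the VideoRecorder above, assuming a gym-style environment (e.g. the dmc2gym wrapper in this repo) and a policy callable.]

    import os
    from video import VideoRecorder

    def evaluate_and_record(env, policy, out_dir='./tmp/video'):
        """Rolls out one episode with `policy` and saves it as a video file."""
        os.makedirs(out_dir, exist_ok=True)
        recorder = VideoRecorder(out_dir, camera_id=0, fps=30)
        recorder.init(enabled=True)
        obs, done = env.reset(), False
        while not done:
            obs, reward, done, info = env.step(policy(obs))
            recorder.record(env)   # grabs one rgb frame via env.render()
        recorder.save('episode_0.mp4')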
--------------------------------------------------------------------------------