├── LICENSE ├── README.md ├── dmc2gym ├── __init__.py └── wrappers.py └── setup.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Denis Yarats 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OpenAI Gym wrapper for the DeepMind Control Suite. 2 | A lightweight wrapper around the DeepMind Control Suite that provides the standard OpenAI Gym interface. The wrapper allows you to specify the following: 3 | * Reliable random seed initialization that will ensure deterministic behaviour. 4 | * Setting ```from_pixels=True``` converts proprioceptive observations into image-based ones. In addition, you can choose the image dimensions by setting ```height``` and ```width```. 
5 | * Action space normalization bound each action's coordinate into the ```[-1, 1]``` range. 6 | * Setting ```frame_skip``` argument lets to perform action repeat. 7 | 8 | 9 | ### Instalation 10 | ``` 11 | pip install git+git://github.com/denisyarats/dmc2gym.git 12 | ``` 13 | 14 | ### Usage 15 | ```python 16 | import dmc2gym 17 | 18 | env = dmc2gym.make(domain_name='point_mass', task_name='easy', seed=1) 19 | 20 | done = False 21 | obs = env.reset() 22 | while not done: 23 | action = env.action_space.sample() 24 | obs, reward, done, info = env.step(action) 25 | ``` 26 | -------------------------------------------------------------------------------- /dmc2gym/__init__.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym.envs.registration import register 3 | 4 | 5 | def make( 6 | domain_name, 7 | task_name, 8 | seed=1, 9 | visualize_reward=True, 10 | from_pixels=False, 11 | height=84, 12 | width=84, 13 | camera_id=0, 14 | frame_skip=1, 15 | episode_length=1000, 16 | environment_kwargs=None, 17 | time_limit=None, 18 | channels_first=True 19 | ): 20 | env_id = 'dmc_%s_%s_%s-v1' % (domain_name, task_name, seed) 21 | 22 | if from_pixels: 23 | assert not visualize_reward, 'cannot use visualize reward when learning from pixels' 24 | 25 | # shorten episode length 26 | max_episode_steps = (episode_length + frame_skip - 1) // frame_skip 27 | 28 | if not env_id in gym.envs.registry.env_specs: 29 | task_kwargs = {} 30 | if seed is not None: 31 | task_kwargs['random'] = seed 32 | if time_limit is not None: 33 | task_kwargs['time_limit'] = time_limit 34 | register( 35 | id=env_id, 36 | entry_point='dmc2gym.wrappers:DMCWrapper', 37 | kwargs=dict( 38 | domain_name=domain_name, 39 | task_name=task_name, 40 | task_kwargs=task_kwargs, 41 | environment_kwargs=environment_kwargs, 42 | visualize_reward=visualize_reward, 43 | from_pixels=from_pixels, 44 | height=height, 45 | width=width, 46 | camera_id=camera_id, 47 | 
frame_skip=frame_skip, 48 | channels_first=channels_first, 49 | ), 50 | max_episode_steps=max_episode_steps, 51 | ) 52 | return gym.make(env_id) 53 | -------------------------------------------------------------------------------- /dmc2gym/wrappers.py: -------------------------------------------------------------------------------- 1 | from gym import core, spaces 2 | from dm_control import suite 3 | from dm_env import specs 4 | import numpy as np 5 | 6 | 7 | def _spec_to_box(spec, dtype): 8 | def extract_min_max(s): 9 | assert s.dtype == np.float64 or s.dtype == np.float32 10 | dim = np.int(np.prod(s.shape)) 11 | if type(s) == specs.Array: 12 | bound = np.inf * np.ones(dim, dtype=np.float32) 13 | return -bound, bound 14 | elif type(s) == specs.BoundedArray: 15 | zeros = np.zeros(dim, dtype=np.float32) 16 | return s.minimum + zeros, s.maximum + zeros 17 | 18 | mins, maxs = [], [] 19 | for s in spec: 20 | mn, mx = extract_min_max(s) 21 | mins.append(mn) 22 | maxs.append(mx) 23 | low = np.concatenate(mins, axis=0).astype(dtype) 24 | high = np.concatenate(maxs, axis=0).astype(dtype) 25 | assert low.shape == high.shape 26 | return spaces.Box(low, high, dtype=dtype) 27 | 28 | 29 | def _flatten_obs(obs): 30 | obs_pieces = [] 31 | for v in obs.values(): 32 | flat = np.array([v]) if np.isscalar(v) else v.ravel() 33 | obs_pieces.append(flat) 34 | return np.concatenate(obs_pieces, axis=0) 35 | 36 | 37 | class DMCWrapper(core.Env): 38 | def __init__( 39 | self, 40 | domain_name, 41 | task_name, 42 | task_kwargs=None, 43 | visualize_reward={}, 44 | from_pixels=False, 45 | height=84, 46 | width=84, 47 | camera_id=0, 48 | frame_skip=1, 49 | environment_kwargs=None, 50 | channels_first=True 51 | ): 52 | assert 'random' in task_kwargs, 'please specify a seed, for deterministic behaviour' 53 | self._from_pixels = from_pixels 54 | self._height = height 55 | self._width = width 56 | self._camera_id = camera_id 57 | self._frame_skip = frame_skip 58 | self._channels_first = 
channels_first 59 | 60 | # create task 61 | self._env = suite.load( 62 | domain_name=domain_name, 63 | task_name=task_name, 64 | task_kwargs=task_kwargs, 65 | visualize_reward=visualize_reward, 66 | environment_kwargs=environment_kwargs 67 | ) 68 | 69 | # true and normalized action spaces 70 | self._true_action_space = _spec_to_box([self._env.action_spec()], np.float32) 71 | self._norm_action_space = spaces.Box( 72 | low=-1.0, 73 | high=1.0, 74 | shape=self._true_action_space.shape, 75 | dtype=np.float32 76 | ) 77 | 78 | # create observation space 79 | if from_pixels: 80 | shape = [3, height, width] if channels_first else [height, width, 3] 81 | self._observation_space = spaces.Box( 82 | low=0, high=255, shape=shape, dtype=np.uint8 83 | ) 84 | else: 85 | self._observation_space = _spec_to_box( 86 | self._env.observation_spec().values(), 87 | np.float64 88 | ) 89 | 90 | self._state_space = _spec_to_box( 91 | self._env.observation_spec().values(), 92 | np.float64 93 | ) 94 | 95 | self.current_state = None 96 | 97 | # set seed 98 | self.seed(seed=task_kwargs.get('random', 1)) 99 | 100 | def __getattr__(self, name): 101 | return getattr(self._env, name) 102 | 103 | def _get_obs(self, time_step): 104 | if self._from_pixels: 105 | obs = self.render( 106 | height=self._height, 107 | width=self._width, 108 | camera_id=self._camera_id 109 | ) 110 | if self._channels_first: 111 | obs = obs.transpose(2, 0, 1).copy() 112 | else: 113 | obs = _flatten_obs(time_step.observation) 114 | return obs 115 | 116 | def _convert_action(self, action): 117 | action = action.astype(np.float64) 118 | true_delta = self._true_action_space.high - self._true_action_space.low 119 | norm_delta = self._norm_action_space.high - self._norm_action_space.low 120 | action = (action - self._norm_action_space.low) / norm_delta 121 | action = action * true_delta + self._true_action_space.low 122 | action = action.astype(np.float32) 123 | return action 124 | 125 | @property 126 | def observation_space(self): 
127 | return self._observation_space 128 | 129 | @property 130 | def state_space(self): 131 | return self._state_space 132 | 133 | @property 134 | def action_space(self): 135 | return self._norm_action_space 136 | 137 | @property 138 | def reward_range(self): 139 | return 0, self._frame_skip 140 | 141 | def seed(self, seed): 142 | self._true_action_space.seed(seed) 143 | self._norm_action_space.seed(seed) 144 | self._observation_space.seed(seed) 145 | 146 | def step(self, action): 147 | assert self._norm_action_space.contains(action) 148 | action = self._convert_action(action) 149 | assert self._true_action_space.contains(action) 150 | reward = 0 151 | extra = {'internal_state': self._env.physics.get_state().copy()} 152 | 153 | for _ in range(self._frame_skip): 154 | time_step = self._env.step(action) 155 | reward += time_step.reward or 0 156 | done = time_step.last() 157 | if done: 158 | break 159 | obs = self._get_obs(time_step) 160 | self.current_state = _flatten_obs(time_step.observation) 161 | extra['discount'] = time_step.discount 162 | return obs, reward, done, extra 163 | 164 | def reset(self): 165 | time_step = self._env.reset() 166 | self.current_state = _flatten_obs(time_step.observation) 167 | obs = self._get_obs(time_step) 168 | return obs 169 | 170 | def render(self, mode='rgb_array', height=None, width=None, camera_id=0): 171 | assert mode == 'rgb_array', 'only support rgb_array mode, given %s' % mode 172 | height = height or self._height 173 | width = width or self._width 174 | camera_id = camera_id or self._camera_id 175 | return self._env.physics.render( 176 | height=height, width=width, camera_id=camera_id 177 | ) 178 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from setuptools import find_packages 4 | 5 | setup( 6 | name='dmc2gym', 7 | version='1.0.0', 8 | 
author='Denis Yarats', 9 | description=('a gym like wrapper for dm_control'), 10 | license='', 11 | keywords='gym dm_control openai deepmind', 12 | packages=find_packages(), 13 | install_requires=[ 14 | 'gym', 15 | 'dm_control', 16 | ], 17 | ) 18 | --------------------------------------------------------------------------------