├── .github └── ISSUE_TEMPLATE │ └── issue.md ├── .gitignore ├── LICENSE ├── README.md ├── dreamer.py ├── models.py ├── plotting.py ├── scores ├── baselines.json └── dreamer.json ├── tools.py └── wrappers.py /.github/ISSUE_TEMPLATE/issue.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Issue 3 | about: Ask a question, report a bug, or report any other issue. 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *.egg-info 4 | /dist 5 | /logdir 6 | MUJOCO_LOG.TXT 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020 Danijar Hafner 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dream to Control 2 | 3 | **NOTE:** Check out the code for [DreamerV2](https://github.com/danijar/dreamerv2), which supports both Atari and DMControl environments. 4 | 5 | Fast and simple implementation of the Dreamer agent in TensorFlow 2. 6 | 7 | 8 | 9 | If you find this code useful, please reference in your paper: 10 | 11 | ``` 12 | @article{hafner2019dreamer, 13 | title={Dream to Control: Learning Behaviors by Latent Imagination}, 14 | author={Hafner, Danijar and Lillicrap, Timothy and Ba, Jimmy and Norouzi, Mohammad}, 15 | journal={arXiv preprint arXiv:1912.01603}, 16 | year={2019} 17 | } 18 | ``` 19 | 20 | ## Method 21 | 22 | ![Dreamer](https://imgur.com/JrXC4rh.png) 23 | 24 | Dreamer learns a world model that predicts ahead in a compact feature space. 25 | From imagined feature sequences, it learns a policy and state-value function. 26 | The value gradients are backpropagated through the multi-step predictions to 27 | efficiently learn a long-horizon policy. 
28 | 29 | - [Project website][website] 30 | - [Research paper][paper] 31 | - [Official implementation][code] (TensorFlow 1) 32 | 33 | [website]: https://danijar.com/dreamer 34 | [paper]: https://arxiv.org/pdf/1912.01603.pdf 35 | [code]: https://github.com/google-research/dreamer 36 | 37 | ## Instructions 38 | 39 | Get dependencies: 40 | 41 | ``` 42 | pip3 install --user tensorflow-gpu==2.2.0 43 | pip3 install --user tensorflow_probability 44 | pip3 install --user git+git://github.com/deepmind/dm_control.git 45 | pip3 install --user pandas 46 | pip3 install --user matplotlib 47 | ``` 48 | 49 | Train the agent: 50 | 51 | ``` 52 | python3 dreamer.py --logdir ./logdir/dmc_walker_walk/dreamer/1 --task dmc_walker_walk 53 | ``` 54 | 55 | Generate plots: 56 | 57 | ``` 58 | python3 plotting.py --indir ./logdir --outdir ./plots --xaxis step --yaxis test/return --bins 3e4 59 | ``` 60 | 61 | Graphs and GIFs: 62 | 63 | ``` 64 | tensorboard --logdir ./logdir 65 | ``` 66 | -------------------------------------------------------------------------------- /dreamer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | import functools 4 | import json 5 | import os 6 | import pathlib 7 | import sys 8 | import time 9 | 10 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 11 | os.environ['MUJOCO_GL'] = 'egl' 12 | 13 | import numpy as np 14 | import tensorflow as tf 15 | from tensorflow.keras.mixed_precision import experimental as prec 16 | 17 | tf.get_logger().setLevel('ERROR') 18 | 19 | from tensorflow_probability import distributions as tfd 20 | 21 | sys.path.append(str(pathlib.Path(__file__).parent)) 22 | 23 | import models 24 | import tools 25 | import wrappers 26 | 27 | 28 | def define_config(): 29 | config = tools.AttrDict() 30 | # General. 
31 | config.logdir = pathlib.Path('.') # Output root: episodes/, metrics.jsonl and variables.pkl are written below it. 32 | config.seed = 0 33 | config.steps = 5e6 # Total env steps (agent steps x action_repeat); cast to int in main(). 34 | config.eval_every = 1e4 35 | config.log_every = 1e3 36 | config.log_scalars = True 37 | config.log_images = True 38 | config.gpu_growth = True 39 | config.precision = 16 40 | # Environment. 41 | config.task = 'dmc_walker_walk' 42 | config.envs = 1 43 | config.parallel = 'none' 44 | config.action_repeat = 2 45 | config.time_limit = 1000 46 | config.prefill = 5000 47 | config.eval_noise = 0.0 48 | config.clip_rewards = 'none' 49 | # Model. 50 | config.deter_size = 200 51 | config.stoch_size = 30 52 | config.num_units = 400 53 | config.dense_act = 'elu' 54 | config.cnn_act = 'relu' 55 | config.cnn_depth = 32 56 | config.pcont = False 57 | config.free_nats = 3.0 58 | config.kl_scale = 1.0 59 | config.pcont_scale = 10.0 60 | config.weight_decay = 0.0 61 | config.weight_decay_pattern = r'.*' 62 | # Training. 63 | config.batch_size = 50 64 | config.batch_length = 50 65 | config.train_every = 1000 66 | config.train_steps = 100 67 | config.pretrain = 100 68 | config.model_lr = 6e-4 69 | config.value_lr = 8e-5 70 | config.actor_lr = 8e-5 71 | config.grad_clip = 100.0 72 | config.dataset_balance = False 73 | # Behavior.
74 | config.discount = 0.99 75 | config.disclam = 0.95 76 | config.horizon = 15 77 | config.action_dist = 'tanh_normal' 78 | config.action_init_std = 5.0 79 | config.expl = 'additive_gaussian' 80 | config.expl_amount = 0.3 81 | config.expl_decay = 0.0 82 | config.expl_min = 0.0 83 | return config 84 | 85 | 86 | class Dreamer(tools.Module): # Agent holding the world model, actor, value function and their optimizers. 87 | 88 | def __init__(self, config, datadir, actspace, writer): 89 | self._c = config 90 | self._actspace = actspace 91 | self._actdim = actspace.n if hasattr(actspace, 'n') else actspace.shape[0] # .n for spaces that have it (presumably discrete), else the first action-shape dim. 92 | self._writer = writer 93 | self._random = np.random.RandomState(config.seed) 94 | with tf.device('cpu:0'): 95 | self._step = tf.Variable(count_steps(datadir, config), dtype=tf.int64) 96 | self._should_pretrain = tools.Once() 97 | self._should_train = tools.Every(config.train_every) 98 | self._should_log = tools.Every(config.log_every) 99 | self._last_log = None 100 | self._last_time = time.time() 101 | self._metrics = collections.defaultdict(tf.metrics.Mean) 102 | self._metrics['expl_amount'] # Create variable for checkpoint.
103 | self._float = prec.global_policy().compute_dtype 104 | self._strategy = tf.distribute.MirroredStrategy() 105 | with self._strategy.scope(): 106 | self._dataset = iter(self._strategy.experimental_distribute_dataset( 107 | load_dataset(datadir, self._c))) 108 | self._build_model() 109 | 110 | def __call__(self, obs, reset, state=None, training=True): # Zero the state of reset envs, train when due, act, and advance the step counter when training. 111 | step = self._step.numpy().item() 112 | tf.summary.experimental.set_step(step) 113 | if state is not None and reset.any(): 114 | mask = tf.cast(1 - reset, self._float)[:, None] 115 | state = tf.nest.map_structure(lambda x: x * mask, state) 116 | if self._should_train(step): 117 | log = self._should_log(step) 118 | n = self._c.pretrain if self._should_pretrain() else self._c.train_steps 119 | print(f'Training for {n} steps.') 120 | with self._strategy.scope(): 121 | for train_step in range(n): 122 | log_images = self._c.log_images and log and train_step == 0 123 | self.train(next(self._dataset), log_images) 124 | if log: 125 | self._write_summaries() 126 | action, state = self.policy(obs, state, training) 127 | if training: 128 | self._step.assign_add(len(reset) * self._c.action_repeat) 129 | return action, state 130 | 131 | @tf.function 132 | def policy(self, obs, state, training): # Update the latent from obs, then sample (training) or take the mode, plus exploration noise. 133 | if state is None: 134 | latent = self._dynamics.initial(len(obs['image'])) 135 | action = tf.zeros((len(obs['image']), self._actdim), self._float) 136 | else: 137 | latent, action = state 138 | embed = self._encode(preprocess(obs, self._c)) 139 | latent, _ = self._dynamics.obs_step(latent, action, embed) 140 | feat = self._dynamics.get_feat(latent) 141 | if training: 142 | action = self._actor(feat).sample() 143 | else: 144 | action = self._actor(feat).mode() 145 | action = self._exploration(action, training) 146 | state = (latent, action) 147 | return action, state 148 | 149 | def load(self, filename): # NOTE(review): calling _should_pretrain() presumably consumes the Once trigger so a restored agent skips pretraining — confirm. 150 | super().load(filename) 151 | self._should_pretrain() 152 | 153 | @tf.function() 154 | def train(self, data,
log_images=False): 155 | self._strategy.experimental_run_v2(self._train, args=(data, log_images)) 156 | 157 | def _train(self, data, log_images): 158 | with tf.GradientTape() as model_tape: 159 | embed = self._encode(data) 160 | post, prior = self._dynamics.observe(embed, data['action']) 161 | feat = self._dynamics.get_feat(post) 162 | image_pred = self._decode(feat) 163 | reward_pred = self._reward(feat) 164 | likes = tools.AttrDict() 165 | likes.image = tf.reduce_mean(image_pred.log_prob(data['image'])) 166 | likes.reward = tf.reduce_mean(reward_pred.log_prob(data['reward'])) 167 | if self._c.pcont: 168 | pcont_pred = self._pcont(feat) 169 | pcont_target = self._c.discount * data['discount'] 170 | likes.pcont = tf.reduce_mean(pcont_pred.log_prob(pcont_target)) 171 | likes.pcont *= self._c.pcont_scale 172 | prior_dist = self._dynamics.get_dist(prior) 173 | post_dist = self._dynamics.get_dist(post) 174 | div = tf.reduce_mean(tfd.kl_divergence(post_dist, prior_dist)) 175 | div = tf.maximum(div, self._c.free_nats) 176 | model_loss = self._c.kl_scale * div - sum(likes.values()) 177 | model_loss /= float(self._strategy.num_replicas_in_sync) 178 | 179 | with tf.GradientTape() as actor_tape: 180 | imag_feat = self._imagine_ahead(post) 181 | reward = self._reward(imag_feat).mode() 182 | if self._c.pcont: 183 | pcont = self._pcont(imag_feat).mean() 184 | else: 185 | pcont = self._c.discount * tf.ones_like(reward) 186 | value = self._value(imag_feat).mode() 187 | returns = tools.lambda_return( 188 | reward[:-1], value[:-1], pcont[:-1], 189 | bootstrap=value[-1], lambda_=self._c.disclam, axis=0) 190 | discount = tf.stop_gradient(tf.math.cumprod(tf.concat( 191 | [tf.ones_like(pcont[:1]), pcont[:-2]], 0), 0)) 192 | actor_loss = -tf.reduce_mean(discount * returns) 193 | actor_loss /= float(self._strategy.num_replicas_in_sync) 194 | 195 | with tf.GradientTape() as value_tape: 196 | value_pred = self._value(imag_feat)[:-1] 197 | target = tf.stop_gradient(returns) 198 | value_loss 
= -tf.reduce_mean(discount * value_pred.log_prob(target)) 199 | value_loss /= float(self._strategy.num_replicas_in_sync) 200 | 201 | model_norm = self._model_opt(model_tape, model_loss) 202 | actor_norm = self._actor_opt(actor_tape, actor_loss) 203 | value_norm = self._value_opt(value_tape, value_loss) 204 | 205 | if tf.distribute.get_replica_context().replica_id_in_sync_group == 0: # Only the first replica records summaries. 206 | if self._c.log_scalars: 207 | self._scalar_summaries( 208 | data, feat, prior_dist, post_dist, likes, div, 209 | model_loss, value_loss, actor_loss, model_norm, value_norm, 210 | actor_norm) 211 | if tf.equal(log_images, True): 212 | self._image_summaries(data, embed, image_pred) 213 | 214 | def _build_model(self): # Create all networks and one optimizer per loss, then run a warm-up train step. 215 | acts = dict( 216 | elu=tf.nn.elu, relu=tf.nn.relu, swish=tf.nn.swish, 217 | leaky_relu=tf.nn.leaky_relu) 218 | cnn_act = acts[self._c.cnn_act] 219 | act = acts[self._c.dense_act] 220 | self._encode = models.ConvEncoder(self._c.cnn_depth, cnn_act) 221 | self._dynamics = models.RSSM( 222 | self._c.stoch_size, self._c.deter_size, self._c.deter_size) # NOTE(review): hidden size reuses deter_size and dense_act is not passed, so the RSSM keeps its default activation — confirm intended. 223 | self._decode = models.ConvDecoder(self._c.cnn_depth, cnn_act) 224 | self._reward = models.DenseDecoder((), 2, self._c.num_units, act=act) 225 | if self._c.pcont: 226 | self._pcont = models.DenseDecoder( 227 | (), 3, self._c.num_units, 'binary', act=act) 228 | self._value = models.DenseDecoder((), 3, self._c.num_units, act=act) 229 | self._actor = models.ActionDecoder( 230 | self._actdim, 4, self._c.num_units, self._c.action_dist, 231 | init_std=self._c.action_init_std, act=act) 232 | model_modules = [self._encode, self._dynamics, self._decode, self._reward] 233 | if self._c.pcont: 234 | model_modules.append(self._pcont) 235 | Optimizer = functools.partial( 236 | tools.Adam, wd=self._c.weight_decay, clip=self._c.grad_clip, 237 | wdpattern=self._c.weight_decay_pattern) 238 | self._model_opt = Optimizer('model', model_modules, self._c.model_lr) 239 | self._value_opt = Optimizer('value', [self._value], self._c.value_lr)
240 | self._actor_opt = Optimizer('actor', [self._actor], self._c.actor_lr) 241 | # Do a train step to initialize all variables, including optimizer 242 | # statistics. Ideally, we would use batch size zero, but that doesn't work 243 | # in multi-GPU mode. 244 | self.train(next(self._dataset)) 245 | 246 | def _exploration(self, action, training): 247 | if training: 248 | amount = self._c.expl_amount 249 | if self._c.expl_decay: 250 | amount *= 0.5 ** (tf.cast(self._step, tf.float32) / self._c.expl_decay) 251 | if self._c.expl_min: 252 | amount = tf.maximum(self._c.expl_min, amount) 253 | self._metrics['expl_amount'].update_state(amount) 254 | elif self._c.eval_noise: 255 | amount = self._c.eval_noise 256 | else: 257 | return action 258 | if self._c.expl == 'additive_gaussian': 259 | return tf.clip_by_value(tfd.Normal(action, amount).sample(), -1, 1) 260 | if self._c.expl == 'completely_random': 261 | return tf.random.uniform(action.shape, -1, 1) 262 | if self._c.expl == 'epsilon_greedy': 263 | indices = tfd.Categorical(0 * action).sample() 264 | return tf.where( 265 | tf.random.uniform(action.shape[:1], 0, 1) < amount, 266 | tf.one_hot(indices, action.shape[-1], dtype=self._float), 267 | action) 268 | raise NotImplementedError(self._c.expl) 269 | 270 | def _imagine_ahead(self, post): 271 | if self._c.pcont: # Last step could be terminal. 
272 | post = {k: v[:, :-1] for k, v in post.items()} 273 | flatten = lambda x: tf.reshape(x, [-1] + list(x.shape[2:])) # Merge the first two dims so each posterior state becomes a rollout start. 274 | start = {k: flatten(v) for k, v in post.items()} 275 | policy = lambda state: self._actor( 276 | tf.stop_gradient(self._dynamics.get_feat(state))).sample() 277 | states = tools.static_scan( 278 | lambda prev, _: self._dynamics.img_step(prev, policy(prev)), 279 | tf.range(self._c.horizon), start) 280 | imag_feat = self._dynamics.get_feat(states) 281 | return imag_feat 282 | 283 | def _scalar_summaries( 284 | self, data, feat, prior_dist, post_dist, likes, div, 285 | model_loss, value_loss, actor_loss, model_norm, value_norm, 286 | actor_norm): 287 | self._metrics['model_grad_norm'].update_state(model_norm) 288 | self._metrics['value_grad_norm'].update_state(value_norm) 289 | self._metrics['actor_grad_norm'].update_state(actor_norm) 290 | self._metrics['prior_ent'].update_state(prior_dist.entropy()) 291 | self._metrics['post_ent'].update_state(post_dist.entropy()) 292 | for name, logprob in likes.items(): 293 | self._metrics[name + '_loss'].update_state(-logprob) 294 | self._metrics['div'].update_state(div) 295 | self._metrics['model_loss'].update_state(model_loss) 296 | self._metrics['value_loss'].update_state(value_loss) 297 | self._metrics['actor_loss'].update_state(actor_loss) 298 | self._metrics['action_ent'].update_state(self._actor(feat).entropy()) 299 | 300 | def _image_summaries(self, data, embed, image_pred): # Video of 6 sequences stacked as truth / model (5 reconstructed frames then open-loop prediction) / error. 301 | truth = data['image'][:6] + 0.5 302 | recon = image_pred.mode()[:6] 303 | init, _ = self._dynamics.observe(embed[:6, :5], data['action'][:6, :5]) 304 | init = {k: v[:, -1] for k, v in init.items()} 305 | prior = self._dynamics.imagine(data['action'][:6, 5:], init) 306 | openl = self._decode(self._dynamics.get_feat(prior)).mode() 307 | model = tf.concat([recon[:, :5] + 0.5, openl + 0.5], 1) 308 | error = (model - truth + 1) / 2 309 | openl = tf.concat([truth, model, error], 2) 310 | tools.graph_summary( 311 | self._writer,
tools.video_summary, 'agent/openl', openl) 312 | 313 | def _write_summaries(self): # Write metric means to metrics.jsonl, TensorBoard and stdout, then reset them. 314 | step = int(self._step.numpy()) 315 | metrics = [(k, float(v.result())) for k, v in self._metrics.items()] 316 | if self._last_log is not None: 317 | duration = time.time() - self._last_time 318 | self._last_time += duration 319 | metrics.append(('fps', (step - self._last_log) / duration)) 320 | self._last_log = step 321 | [m.reset_states() for m in self._metrics.values()] 322 | with (self._c.logdir / 'metrics.jsonl').open('a') as f: 323 | f.write(json.dumps({'step': step, **dict(metrics)}) + '\n') 324 | [tf.summary.scalar('agent/' + k, m) for k, m in metrics] 325 | print(f'[{step}]', ' / '.join(f'{k} {v:.1f}' for k, v in metrics)) 326 | self._writer.flush() 327 | 328 | 329 | def preprocess(obs, config): # Scale images to [-0.5, 0.5] (assumes pixels in [0, 255]) and optionally tanh-clip rewards. 330 | dtype = prec.global_policy().compute_dtype 331 | obs = obs.copy() 332 | with tf.device('cpu:0'): 333 | obs['image'] = tf.cast(obs['image'], dtype) / 255.0 - 0.5 334 | clip_rewards = dict(none=lambda x: x, tanh=tf.tanh)[config.clip_rewards] 335 | obs['reward'] = clip_rewards(obs['reward']) 336 | return obs 337 | 338 | 339 | def count_steps(datadir, config): 340 | return tools.count_episodes(datadir)[1] * config.action_repeat 341 | 342 | 343 | def load_dataset(directory, config): # Batched tf.data pipeline over stored episodes; dtypes/shapes come from the first episode. 344 | episode = next(tools.load_episodes(directory, 1)) 345 | types = {k: v.dtype for k, v in episode.items()} 346 | shapes = {k: (None,) + v.shape[1:] for k, v in episode.items()} 347 | generator = lambda: tools.load_episodes( 348 | directory, config.train_steps, config.batch_length, 349 | config.dataset_balance) 350 | dataset = tf.data.Dataset.from_generator(generator, types, shapes) 351 | dataset = dataset.batch(config.batch_size, drop_remainder=True) 352 | dataset = dataset.map(functools.partial(preprocess, config=config)) 353 | dataset = dataset.prefetch(10) 354 | return dataset 355 | 356 | 357 | def summarize_episode(episode, config, datadir, writer, prefix): # Log return/length of a finished episode to metrics.jsonl and TensorBoard (plus video for test). 358 | episodes, steps =
tools.count_episodes(datadir) 359 | length = (len(episode['reward']) - 1) * config.action_repeat 360 | ret = episode['reward'].sum() 361 | print(f'{prefix.title()} episode of length {length} with return {ret:.1f}.') 362 | metrics = [ 363 | (f'{prefix}/return', float(episode['reward'].sum())), 364 | (f'{prefix}/length', len(episode['reward']) - 1), 365 | (f'episodes', episodes)] 366 | step = count_steps(datadir, config) 367 | with (config.logdir / 'metrics.jsonl').open('a') as f: 368 | f.write(json.dumps(dict([('step', step)] + metrics)) + '\n') 369 | with writer.as_default(): # Env might run in a different thread. 370 | tf.summary.experimental.set_step(step) 371 | [tf.summary.scalar('sim/' + k, v) for k, v in metrics] 372 | if prefix == 'test': 373 | tools.video_summary(f'sim/{prefix}/video', episode['image'][None]) 374 | 375 | 376 | def make_env(config, writer, prefix, datadir, store): # Build the wrapped env for config.task ('dmc_*' or 'atari_*'); stores episodes only when `store`. 377 | suite, task = config.task.split('_', 1) 378 | if suite == 'dmc': 379 | env = wrappers.DeepMindControl(task) 380 | env = wrappers.ActionRepeat(env, config.action_repeat) 381 | env = wrappers.NormalizeActions(env) 382 | elif suite == 'atari': 383 | env = wrappers.Atari( 384 | task, config.action_repeat, (64, 64), grayscale=False, 385 | life_done=True, sticky_actions=True) 386 | env = wrappers.OneHotAction(env) 387 | else: 388 | raise NotImplementedError(suite) 389 | env = wrappers.TimeLimit(env, config.time_limit / config.action_repeat) 390 | callbacks = [] 391 | if store: 392 | callbacks.append(lambda ep: tools.save_episodes(datadir, [ep])) 393 | callbacks.append( 394 | lambda ep: summarize_episode(ep, config, datadir, writer, prefix)) 395 | env = wrappers.Collect(env, callbacks, config.precision) 396 | env = wrappers.RewardObs(env) 397 | return env 398 | 399 | 400 | def main(config): # Configure GPUs/precision, prefill with random episodes, then alternate evaluation and training collection. 401 | if config.gpu_growth: 402 | for gpu in tf.config.experimental.list_physical_devices('GPU'): 403 | tf.config.experimental.set_memory_growth(gpu, True) 404 | assert config.precision in (16,
32), config.precision 405 | if config.precision == 16: 406 | prec.set_policy(prec.Policy('mixed_float16')) 407 | config.steps = int(config.steps) 408 | config.logdir.mkdir(parents=True, exist_ok=True) 409 | print('Logdir', config.logdir) 410 | 411 | # Create environments. 412 | datadir = config.logdir / 'episodes' 413 | writer = tf.summary.create_file_writer( 414 | str(config.logdir), max_queue=1000, flush_millis=20000) 415 | writer.set_as_default() 416 | train_envs = [wrappers.Async(lambda: make_env( 417 | config, writer, 'train', datadir, store=True), config.parallel) 418 | for _ in range(config.envs)] 419 | test_envs = [wrappers.Async(lambda: make_env( 420 | config, writer, 'test', datadir, store=False), config.parallel) 421 | for _ in range(config.envs)] 422 | actspace = train_envs[0].action_space 423 | 424 | # Prefill dataset with random episodes. 425 | step = count_steps(datadir, config) 426 | prefill = max(0, config.prefill - step) 427 | print(f'Prefill dataset with {prefill} steps.') 428 | random_agent = lambda o, d, _: ([actspace.sample() for _ in d], None) 429 | tools.simulate(random_agent, train_envs, prefill / config.action_repeat) 430 | writer.flush() 431 | 432 | # Train and regularly evaluate the agent.
433 | step = count_steps(datadir, config) 434 | print(f'Simulating agent for {config.steps-step} steps.') 435 | agent = Dreamer(config, datadir, actspace, writer) 436 | if (config.logdir / 'variables.pkl').exists(): 437 | print('Load checkpoint.') 438 | agent.load(config.logdir / 'variables.pkl') 439 | state = None 440 | while step < config.steps: 441 | print('Start evaluation.') 442 | tools.simulate( 443 | functools.partial(agent, training=False), test_envs, episodes=1) 444 | writer.flush() 445 | print('Start collection.') 446 | steps = config.eval_every // config.action_repeat # eval_every counts env steps; simulate takes agent steps. 447 | state = tools.simulate(agent, train_envs, steps, state=state) 448 | step = count_steps(datadir, config) 449 | agent.save(config.logdir / 'variables.pkl') 450 | for env in train_envs + test_envs: 451 | env.close() 452 | 453 | 454 | if __name__ == '__main__': 455 | try: 456 | import colored_traceback 457 | colored_traceback.add_hook() 458 | except ImportError: 459 | pass 460 | parser = argparse.ArgumentParser() 461 | for key, value in define_config().items(): 462 | parser.add_argument(f'--{key}', type=tools.args_type(value), default=value) 463 | main(parser.parse_args()) 464 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tensorflow.keras import layers as tfkl 4 | from tensorflow_probability import distributions as tfd 5 | from tensorflow.keras.mixed_precision import experimental as prec 6 | 7 | import tools 8 | 9 | 10 | class RSSM(tools.Module): # Recurrent state-space model: GRU deterministic state plus diagonal-Gaussian stochastic state. 11 | 12 | def __init__(self, stoch=30, deter=200, hidden=200, act=tf.nn.elu): 13 | super().__init__() 14 | self._activation = act 15 | self._stoch_size = stoch 16 | self._deter_size = deter 17 | self._hidden_size = hidden 18 | self._cell = tfkl.GRUCell(self._deter_size) 19 | 20 | def initial(self, batch_size): 21 | dtype = prec.global_policy().compute_dtype 22
| return dict( 23 | mean=tf.zeros([batch_size, self._stoch_size], dtype), 24 | std=tf.zeros([batch_size, self._stoch_size], dtype), 25 | stoch=tf.zeros([batch_size, self._stoch_size], dtype), 26 | deter=self._cell.get_initial_state(None, batch_size, dtype)) 27 | 28 | @tf.function 29 | def observe(self, embed, action, state=None): # Posterior and prior states for every step; batch-major inputs are transposed to time-major for the scan. 30 | if state is None: 31 | state = self.initial(tf.shape(action)[0]) 32 | embed = tf.transpose(embed, [1, 0, 2]) 33 | action = tf.transpose(action, [1, 0, 2]) 34 | post, prior = tools.static_scan( 35 | lambda prev, inputs: self.obs_step(prev[0], *inputs), 36 | (action, embed), (state, state)) 37 | post = {k: tf.transpose(v, [1, 0, 2]) for k, v in post.items()} 38 | prior = {k: tf.transpose(v, [1, 0, 2]) for k, v in prior.items()} 39 | return post, prior 40 | 41 | @tf.function 42 | def imagine(self, action, state=None): # Roll the prior forward from `state` using actions only. 43 | if state is None: 44 | state = self.initial(tf.shape(action)[0]) 45 | assert isinstance(state, dict), state 46 | action = tf.transpose(action, [1, 0, 2]) 47 | prior = tools.static_scan(self.img_step, action, state) 48 | prior = {k: tf.transpose(v, [1, 0, 2]) for k, v in prior.items()} 49 | return prior 50 | 51 | def get_feat(self, state): 52 | return tf.concat([state['stoch'], state['deter']], -1) 53 | 54 | def get_dist(self, state): 55 | return tfd.MultivariateNormalDiag(state['mean'], state['std']) 56 | 57 | @tf.function 58 | def obs_step(self, prev_state, prev_action, embed): 59 | prior = self.img_step(prev_state, prev_action) 60 | x = tf.concat([prior['deter'], embed], -1) 61 | x = self.get('obs1', tfkl.Dense, self._hidden_size, self._activation)(x) 62 | x = self.get('obs2', tfkl.Dense, 2 * self._stoch_size, None)(x) 63 | mean, std = tf.split(x, 2, -1) 64 | std = tf.nn.softplus(std) + 0.1 # Positive std with a 0.1 floor. 65 | stoch = self.get_dist({'mean': mean, 'std': std}).sample() 66 | post = {'mean': mean, 'std': std, 'stoch': stoch, 'deter': prior['deter']} 67 | return post, prior 68 | 69 | @tf.function 70 | def img_step(self, prev_state,
prev_action): 71 | x = tf.concat([prev_state['stoch'], prev_action], -1) 72 | x = self.get('img1', tfkl.Dense, self._hidden_size, self._activation)(x) 73 | x, deter = self._cell(x, [prev_state['deter']]) 74 | deter = deter[0] # Keras wraps the state in a list. 75 | x = self.get('img2', tfkl.Dense, self._hidden_size, self._activation)(x) 76 | x = self.get('img3', tfkl.Dense, 2 * self._stoch_size, None)(x) 77 | mean, std = tf.split(x, 2, -1) 78 | std = tf.nn.softplus(std) + 0.1 79 | stoch = self.get_dist({'mean': mean, 'std': std}).sample() 80 | prior = {'mean': mean, 'std': std, 'stoch': stoch, 'deter': deter} 81 | return prior 82 | 83 | 84 | class ConvEncoder(tools.Module): # Four stride-2 4x4 convs, reshaped to 32*depth features per frame (assumes 64x64 inputs — TODO confirm). 85 | 86 | def __init__(self, depth=32, act=tf.nn.relu): 87 | self._act = act 88 | self._depth = depth 89 | 90 | def __call__(self, obs): 91 | kwargs = dict(strides=2, activation=self._act) 92 | x = tf.reshape(obs['image'], (-1,) + tuple(obs['image'].shape[-3:])) 93 | x = self.get('h1', tfkl.Conv2D, 1 * self._depth, 4, **kwargs)(x) 94 | x = self.get('h2', tfkl.Conv2D, 2 * self._depth, 4, **kwargs)(x) 95 | x = self.get('h3', tfkl.Conv2D, 4 * self._depth, 4, **kwargs)(x) 96 | x = self.get('h4', tfkl.Conv2D, 8 * self._depth, 4, **kwargs)(x) 97 | shape = tf.concat([tf.shape(obs['image'])[:-3], [32 * self._depth]], 0) 98 | return tf.reshape(x, shape) 99 | 100 | 101 | class ConvDecoder(tools.Module): # Dense -> 1x1 map -> transposed convs to `shape`; unit-variance Normal over pixels. 102 | 103 | def __init__(self, depth=32, act=tf.nn.relu, shape=(64, 64, 3)): 104 | self._act = act 105 | self._depth = depth 106 | self._shape = shape 107 | 108 | def __call__(self, features): 109 | kwargs = dict(strides=2, activation=self._act) 110 | x = self.get('h1', tfkl.Dense, 32 * self._depth, None)(features) 111 | x = tf.reshape(x, [-1, 1, 1, 32 * self._depth]) 112 | x = self.get('h2', tfkl.Conv2DTranspose, 4 * self._depth, 5, **kwargs)(x) 113 | x = self.get('h3', tfkl.Conv2DTranspose, 2 * self._depth, 5, **kwargs)(x) 114 | x = self.get('h4', tfkl.Conv2DTranspose, 1 * self._depth, 6, **kwargs)(x)
115 | x = self.get('h5', tfkl.Conv2DTranspose, self._shape[-1], 6, strides=2)(x) 116 | mean = tf.reshape(x, tf.concat([tf.shape(features)[:-1], self._shape], 0)) 117 | return tfd.Independent(tfd.Normal(mean, 1), len(self._shape)) 118 | 119 | 120 | class DenseDecoder(tools.Module): # MLP head: unit-variance Normal, or Bernoulli with x as its first (logits) argument. 121 | 122 | def __init__(self, shape, layers, units, dist='normal', act=tf.nn.elu): 123 | self._shape = shape 124 | self._layers = layers 125 | self._units = units 126 | self._dist = dist 127 | self._act = act 128 | 129 | def __call__(self, features): 130 | x = features 131 | for index in range(self._layers): 132 | x = self.get(f'h{index}', tfkl.Dense, self._units, self._act)(x) 133 | x = self.get(f'hout', tfkl.Dense, np.prod(self._shape))(x) 134 | x = tf.reshape(x, tf.concat([tf.shape(features)[:-1], self._shape], 0)) 135 | if self._dist == 'normal': 136 | return tfd.Independent(tfd.Normal(x, 1), len(self._shape)) 137 | if self._dist == 'binary': 138 | return tfd.Independent(tfd.Bernoulli(x), len(self._shape)) 139 | raise NotImplementedError(self._dist) 140 | 141 | 142 | class ActionDecoder(tools.Module): # MLP policy head: tanh-squashed Normal or one-hot categorical. 143 | 144 | def __init__( 145 | self, size, layers, units, dist='tanh_normal', act=tf.nn.elu, 146 | min_std=1e-4, init_std=5, mean_scale=5): 147 | self._size = size 148 | self._layers = layers 149 | self._units = units 150 | self._dist = dist 151 | self._act = act 152 | self._min_std = min_std 153 | self._init_std = init_std 154 | self._mean_scale = mean_scale 155 | 156 | def __call__(self, features): 157 | raw_init_std = np.log(np.exp(self._init_std) - 1) # Inverse softplus: softplus(0 + raw_init_std) == init_std. 158 | x = features 159 | for index in range(self._layers): 160 | x = self.get(f'h{index}', tfkl.Dense, self._units, self._act)(x) 161 | if self._dist == 'tanh_normal': 162 | # https://www.desmos.com/calculator/rcmcf5jwe7 163 | x = self.get(f'hout', tfkl.Dense, 2 * self._size)(x) 164 | mean, std = tf.split(x, 2, -1) 165 | mean = self._mean_scale * tf.tanh(mean / self._mean_scale) 166 | std = tf.nn.softplus(std + raw_init_std) +
self._min_std 167 | dist = tfd.Normal(mean, std) 168 | dist = tfd.TransformedDistribution(dist, tools.TanhBijector()) 169 | dist = tfd.Independent(dist, 1) 170 | dist = tools.SampleDist(dist) 171 | elif self._dist == 'onehot': 172 | x = self.get(f'hout', tfkl.Dense, self._size)(x) 173 | dist = tools.OneHotDist(x) 174 | else: 175 | raise NotImplementedError(dist) 176 | return dist 177 | -------------------------------------------------------------------------------- /plotting.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | import functools 4 | import json 5 | import multiprocessing as mp 6 | import pathlib 7 | import re 8 | import subprocess 9 | 10 | import matplotlib.pyplot as plt 11 | import matplotlib.ticker as ticker 12 | import numpy as np 13 | import pandas as pd 14 | 15 | # import matplotlib 16 | # matplotlib.rcParams['mathtext.fontset'] = 'stix' 17 | # matplotlib.rcParams['font.family'] = 'STIXGeneral' 18 | 19 | Run = collections.namedtuple('Run', 'task method seed xs ys color') 20 | 21 | PALETTE = 10 * ( 22 | '#377eb8', '#4daf4a', '#984ea3', '#e41a1c', '#ff7f00', '#a65628', 23 | '#f781bf', '#888888', '#a6cee3', '#b2df8a', '#cab2d6', '#fb9a99', 24 | '#fdbf6f') 25 | 26 | LEGEND = dict( 27 | fontsize='medium', numpoints=1, labelspacing=0, columnspacing=1.2, 28 | handlelength=1.5, handletextpad=0.5, ncol=4, loc='lower center') 29 | 30 | 31 | def find_keys(args): 32 | filename = next(args.indir[0].glob('**/*.jsonl')) 33 | keys = set() 34 | for line in filename.read_text().split('\n'): 35 | if line: 36 | keys |= json.loads(line).keys() 37 | print(f'Keys ({len(keys)}):', ', '.join(keys), flush=True) 38 | 39 | 40 | def load_runs(args): 41 | toload = [] 42 | for indir in args.indir: 43 | filenames = list(indir.glob('**/*.jsonl')) 44 | for filename in filenames: 45 | task, method, seed = filename.relative_to(indir).parts[:-1] 46 | if not any(p.search(task) for p in args.tasks): 47 | 
continue 48 | if not any(p.search(method) for p in args.methods): 49 | continue 50 | if method not in args.colors: 51 | args.colors[method] = args.palette[len(args.colors)] 52 | toload.append((filename, indir)) 53 | print(f'Loading {len(toload)} of {len(filenames)} runs...') # NOTE(review): `filenames` holds only the last indir's matches, so the total is wrong with several --indir values. 54 | jobs = [functools.partial(load_run, f, i, args) for f, i in toload] 55 | with mp.Pool(10) as pool: 56 | promises = [pool.apply_async(j) for j in jobs] 57 | runs = [p.get() for p in promises] 58 | runs = [r for r in runs if r is not None] 59 | return runs 60 | 61 | 62 | def load_run(filename, indir, args): # Read one metrics.jsonl into a Run; returns None if unreadable or missing the requested axes. 63 | task, method, seed = filename.relative_to(indir).parts[:-1] 64 | try: 65 | # Future pandas releases will support JSON files with NaN values. 66 | # df = pd.read_json(filename, lines=True) 67 | with filename.open() as f: 68 | df = pd.DataFrame([json.loads(l) for l in f.readlines()]) 69 | except ValueError as e: 70 | print('Invalid', filename.relative_to(indir), e) 71 | return 72 | try: 73 | df = df[[args.xaxis, args.yaxis]].dropna() 74 | except KeyError: 75 | return 76 | xs = df[args.xaxis].to_numpy() 77 | ys = df[args.yaxis].to_numpy() 78 | color = args.colors[method] 79 | return Run(task, method, seed, xs, ys, color) 80 | 81 | 82 | def load_baselines(args): # Load {task: {method: score}} JSON files as Runs with xs=None. 83 | runs = [] 84 | directory = pathlib.Path(__file__).parent / 'baselines' # NOTE(review): the repo tree lists scores/baselines.json, not a baselines/ directory — confirm this path. 85 | for filename in directory.glob('**/*.json'): 86 | for task, methods in json.loads(filename.read_text()).items(): 87 | for method, score in methods.items(): 88 | if not any(p.search(method) for p in args.baselines): 89 | continue 90 | if method not in args.colors: 91 | args.colors[method] = args.palette[len(args.colors)] 92 | color = args.colors[method] 93 | runs.append(Run(task, method, None, None, score, color)) 94 | return runs 95 | 96 | 97 | def stats(runs): # Print the tasks, methods, seeds and baselines that were loaded. 98 | baselines = sorted(set(r.method for r in runs if r.xs is None)) 99 | runs = [r for r in runs if r.xs is not None] 100 | tasks = sorted(set(r.task for r in runs)) 101 | methods =
sorted(set(r.method for r in runs)) 102 | seeds = sorted(set(r.seed for r in runs)) 103 | print('Loaded', len(runs), 'runs.') 104 | print(f'Tasks ({len(tasks)}):', ', '.join(tasks)) 105 | print(f'Methods ({len(methods)}):', ', '.join(methods)) 106 | print(f'Seeds ({len(seeds)}):', ', '.join(seeds)) 107 | print(f'Baselines ({len(baselines)}):', ', '.join(baselines)) 108 | 109 | 110 | def figure(runs, args): 111 | tasks = sorted(set(r.task for r in runs if r.xs is not None)) 112 | rows = int(np.ceil(len(tasks) / args.cols)) 113 | figsize = args.size[0] * args.cols, args.size[1] * rows 114 | fig, axes = plt.subplots(rows, args.cols, figsize=figsize) 115 | for task, ax in zip(tasks, axes.flatten()): 116 | relevant = [r for r in runs if r.task == task] 117 | plot(task, ax, relevant, args) 118 | if args.xlim: 119 | for ax in axes[:-1].flatten(): 120 | ax.xaxis.get_offset_text().set_visible(False) 121 | if args.xlabel: 122 | for ax in axes[-1]: 123 | ax.set_xlabel(args.xlabel) 124 | if args.ylabel: 125 | for ax in axes[:, 0]: 126 | ax.set_ylabel(args.ylabel) 127 | for ax in axes[len(tasks):]: 128 | ax.axis('off') 129 | legend(fig, args.labels, **LEGEND) 130 | return fig 131 | 132 | 133 | def plot(task, ax, runs, args): 134 | try: 135 | title = task.split('_', 1)[1].replace('_', ' ').title() 136 | except IndexError: 137 | title = task.title() 138 | ax.set_title(title) 139 | methods = [] 140 | methods += sorted(set(r.method for r in runs if r.xs is not None)) 141 | methods += sorted(set(r.method for r in runs if r.xs is None)) 142 | xlim = [+np.inf, -np.inf] 143 | for index, method in enumerate(methods): 144 | relevant = [r for r in runs if r.method == method] 145 | if not relevant: 146 | continue 147 | if any(r.xs is None for r in relevant): 148 | baseline(index, method, ax, relevant, args) 149 | else: 150 | if args.aggregate == 'std': 151 | xs, ys = curve_std(index, method, ax, relevant, args) 152 | elif args.aggregate == 'none': 153 | xs, ys = curve_individual(index, 
method, ax, relevant, args) 154 | else: 155 | raise NotImplementedError(args.aggregate) 156 | xlim = [min(xlim[0], xs.min()), max(xlim[1], xs.max())] 157 | ax.ticklabel_format(axis='x', style='sci', scilimits=(0, 0)) 158 | steps = [1, 2, 2.5, 5, 10] 159 | ax.xaxis.set_major_locator(ticker.MaxNLocator(args.xticks, steps=steps)) 160 | ax.yaxis.set_major_locator(ticker.MaxNLocator(args.yticks, steps=steps)) 161 | ax.set_xlim(args.xlim or xlim) 162 | if args.xlim: 163 | ticks = sorted({*ax.get_xticks(), *args.xlim}) 164 | ticks = [x for x in ticks if args.xlim[0] <= x <= args.xlim[1]] 165 | ax.set_xticks(ticks) 166 | if args.ylim: 167 | ax.set_ylim(args.ylim) 168 | ticks = sorted({*ax.get_yticks(), *args.ylim}) 169 | ticks = [x for x in ticks if args.ylim[0] <= x <= args.ylim[1]] 170 | ax.set_yticks(ticks) 171 | 172 | 173 | def curve_individual(index, method, ax, runs, args): 174 | if args.bins: 175 | for index, run in enumerate(runs): 176 | xs, ys = binning(run.xs, run.ys, args.bins, np.nanmean) 177 | runs[index] = run._replace(xs=xs, ys=ys) 178 | zorder = 10000 - 10 * index - 1 179 | for run in runs: 180 | ax.plot(run.xs, run.ys, label=method, color=run.color, zorder=zorder) 181 | return runs[0].xs, runs[0].ys 182 | 183 | 184 | def curve_std(index, method, ax, runs, args): 185 | if args.bins: 186 | for index, run in enumerate(runs): 187 | xs, ys = binning(run.xs, run.ys, args.bins, np.nanmean) 188 | runs[index] = run._replace(xs=xs, ys=ys) 189 | xs = np.concatenate([r.xs for r in runs]) 190 | ys = np.concatenate([r.ys for r in runs]) 191 | order = np.argsort(xs) 192 | xs, ys = xs[order], ys[order] 193 | color = runs[0].color 194 | if args.bins: 195 | reducer = lambda y: (np.nanmean(np.array(y)), np.nanstd(np.array(y))) 196 | xs, ys = binning(xs, ys, args.bins, reducer) 197 | ys, std = ys.T 198 | kw = dict(color=color, zorder=10000 - 10 * index, alpha=0.1, linewidths=0) 199 | ax.fill_between(xs, ys - std, ys + std, **kw) 200 | ax.plot(xs, ys, label=method, 
color=color, zorder=10000 - 10 * index - 1) 201 | return xs, ys 202 | 203 | 204 | def baseline(index, method, ax, runs, args): 205 | assert len(runs) == 1 and runs[0].xs is None 206 | y = np.mean(runs[0].ys) 207 | kw = dict(ls='--', color=runs[0].color, zorder=5000 - 10 * index - 1) 208 | ax.axhline(y, label=method, **kw) 209 | 210 | 211 | def binning(xs, ys, bins, reducer): 212 | binned_xs = np.arange(xs.min(), xs.max() + 1e-10, bins) 213 | binned_ys = [] 214 | for start, stop in zip([-np.inf] + list(binned_xs), binned_xs): 215 | left = (xs <= start).sum() 216 | right = (xs <= stop).sum() 217 | binned_ys.append(reducer(ys[left:right])) 218 | binned_ys = np.array(binned_ys) 219 | return binned_xs, binned_ys 220 | 221 | 222 | def legend(fig, mapping=None, **kwargs): 223 | entries = {} 224 | for ax in fig.axes: 225 | for handle, label in zip(*ax.get_legend_handles_labels()): 226 | if mapping and label in mapping: 227 | label = mapping[label] 228 | entries[label] = handle 229 | leg = fig.legend(entries.values(), entries.keys(), **kwargs) 230 | leg.get_frame().set_edgecolor('white') 231 | extent = leg.get_window_extent(fig.canvas.get_renderer()) 232 | extent = extent.transformed(fig.transFigure.inverted()) 233 | yloc, xloc = kwargs['loc'].split() 234 | y0 = dict(lower=extent.y1, center=0, upper=0)[yloc] 235 | y1 = dict(lower=1, center=1, upper=extent.y0)[yloc] 236 | x0 = dict(left=extent.x1, center=0, right=0)[xloc] 237 | x1 = dict(left=1, center=1, right=extent.x0)[xloc] 238 | fig.tight_layout(rect=[x0, y0, x1, y1], h_pad=0.5, w_pad=0.5) 239 | 240 | 241 | def save(fig, args): 242 | args.outdir.mkdir(parents=True, exist_ok=True) 243 | filename = args.outdir / 'curves.png' 244 | fig.savefig(filename, dpi=130) 245 | print('Saved to', filename) 246 | filename = args.outdir / 'curves.pdf' 247 | fig.savefig(filename) 248 | try: 249 | subprocess.call(['pdfcrop', str(filename), str(filename)]) 250 | except FileNotFoundError: 251 | pass # Install texlive-extra-utils. 
252 | 253 | 254 | def main(args): 255 | find_keys(args) 256 | runs = load_runs(args) + load_baselines(args) 257 | stats(runs) 258 | if not runs: 259 | print('Noting to plot.') 260 | return 261 | print('Plotting...') 262 | fig = figure(runs, args) 263 | save(fig, args) 264 | 265 | 266 | def parse_args(): 267 | boolean = lambda x: bool(['False', 'True'].index(x)) 268 | parser = argparse.ArgumentParser() 269 | parser.add_argument('--indir', nargs='+', type=pathlib.Path, required=True) 270 | parser.add_argument('--outdir', type=pathlib.Path, required=True) 271 | parser.add_argument('--subdir', type=boolean, default=True) 272 | parser.add_argument('--xaxis', type=str, required=True) 273 | parser.add_argument('--yaxis', type=str, required=True) 274 | parser.add_argument('--tasks', nargs='+', default=[r'.*']) 275 | parser.add_argument('--methods', nargs='+', default=[r'.*']) 276 | parser.add_argument('--baselines', nargs='+', default=[]) 277 | parser.add_argument('--bins', type=float, default=0) 278 | parser.add_argument('--aggregate', type=str, default='std') 279 | parser.add_argument('--size', nargs=2, type=float, default=[2.5, 2.3]) 280 | parser.add_argument('--cols', type=int, default=4) 281 | parser.add_argument('--xlim', nargs=2, type=float, default=None) 282 | parser.add_argument('--ylim', nargs=2, type=float, default=None) 283 | parser.add_argument('--xlabel', type=str, default=None) 284 | parser.add_argument('--ylabel', type=str, default=None) 285 | parser.add_argument('--xticks', type=int, default=6) 286 | parser.add_argument('--yticks', type=int, default=5) 287 | parser.add_argument('--labels', nargs='+', default=None) 288 | parser.add_argument('--palette', nargs='+', default=PALETTE) 289 | parser.add_argument('--colors', nargs='+', default={}) 290 | args = parser.parse_args() 291 | if args.subdir: 292 | args.outdir /= args.indir[0].stem 293 | args.indir = [d.expanduser() for d in args.indir] 294 | args.outdir = args.outdir.expanduser() 295 | if args.labels: 
296 | assert len(args.labels) % 2 == 0 297 | args.labels = {k: v for k, v in zip(args.labels[:-1], args.labels[1:])} 298 | if args.colors: 299 | assert len(args.colors) % 2 == 0 300 | args.colors = {k: v for k, v in zip(args.colors[:-1], args.colors[1:])} 301 | args.tasks = [re.compile(p) for p in args.tasks] 302 | args.methods = [re.compile(p) for p in args.methods] 303 | args.baselines = [re.compile(p) for p in args.baselines] 304 | args.palette = 10 * args.palette 305 | return args 306 | 307 | 308 | if __name__ == '__main__': 309 | main(parse_args()) 310 | -------------------------------------------------------------------------------- /scores/baselines.json: -------------------------------------------------------------------------------- 1 | {"dmc_acrobot_swingup": {"d4pg_100m": 91.7, "a3c_100m_proprio": 41.9}, "dmc_cartpole_balance": {"d4pg_100m": 992.8, "a3c_100m_proprio": 951.6}, "dmc_cartpole_swingup": {"d4pg_100m": 862.0, "planet_1e6": 821, "a3c_100m_proprio": 558.4}, "dmc_cartpole_balance_sparse": {"d4pg_100m": 1000.0, "a3c_100m_proprio": 857.4}, "dmc_cartpole_swingup_sparse": {"d4pg_100m": 482.0, "a3c_100m_proprio": 179.8}, "dmc_cheetah_run": {"slac_3e6": 880, "d4pg_100m": 523.8, "planet_1e6": 662, "a3c_100m_proprio": 213.9}, "dmc_cup_catch": {"slac_3e6": 970, "d4pg_100m": 980.5, "planet_1e6": 930, "a3c_100m_proprio": 104.7}, "dmc_finger_spin": {"slac_3e6": 950, "d4pg_100m": 985.7, "planet_1e6": 700, "a3c_100m_proprio": 129.4}, "dmc_finger_turn_easy": {"d4pg_100m": 971.4, "a3c_100m_proprio": 167.3}, "dmc_finger_turn_hard": {"d4pg_100m": 966.0, "a3c_100m_proprio": 88.7}, "dmc_hopper_hop": {"d4pg_100m": 242.0, "a3c_100m_proprio": 0.5}, "dmc_hopper_stand": {"d4pg_100m": 929.9, "a3c_100m_proprio": 27.9}, "dmc_reacher_easy": {"d4pg_100m": 967.4, "planet_1e6": 832, "a3c_100m_proprio": 95.6}, "dmc_reacher_hard": {"d4pg_100m": 957.1, "a3c_100m_proprio": 39.7}, "dmc_walker_stand": {"d4pg_100m": 985.2, "a3c_100m_proprio": 378.4}, "dmc_walker_walk": {"slac_3e6": 
840, "d4pg_100m": 968.3, "planet_1e6": 951, "a3c_100m_proprio": 311.0}, "dmc_walker_run": {"d4pg_100m": 567.2, "a3c_100m_proprio": 191.8}, "dmc_pendulum_swingup": {"d4pg_100m": 680.9, "a3c_100m_proprio": 48.6}} -------------------------------------------------------------------------------- /scores/dreamer.json: -------------------------------------------------------------------------------- 1 | [{"task": "dmc_cartpole_swingup", "method": "dreamer", "seed": "0", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [75.9375, 145.378125, 482.8, 712.7, 846.1, 855.85, 866.15, 850.3, 868.0, 859.25, 856.55, 817.95, 841.4, 833.2, 858.75, 841.6, 871.5, 858.85, 869.55, 854.0, 865.05, 874.05, 866.7, 878.05, 859.55, 862.45, 875.55, 876.4, 877.85, 871.65, 868.55, 840.4, 831.7, 857.25, 847.1, 866.65, 863.35, 865.65, 859.8, 872.2, 875.35, 867.1, 869.15, 875.0, 843.95, 875.9, 866.45, 872.25, 858.7, 865.6]}, {"task": "dmc_cartpole_swingup", "method": "dreamer", "seed": "1", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 
4805000.0, 4905000.0], "ys": [73.1875, 276.153125, 654.825, 849.85, 871.7, 863.85, 866.55, 858.25, 863.75, 862.15, 870.45, 770.35, 860.3, 860.7, 874.25, 872.55, 873.3, 856.05, 839.2, 871.6, 823.125, 871.85, 868.25, 867.15, 839.15, 839.75, 834.4, 848.5, 858.55, 843.15, 872.45, 871.1, 837.1, 872.0, 861.4, 863.8, 877.25, 875.5, 873.7, 807.525, 874.8, 873.4, 878.05, 879.15, 854.2, 873.7, 841.85, 858.9, 858.5, 863.75]}, {"task": "dmc_cartpole_swingup", "method": "dreamer", "seed": "4", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [76.0625, 125.9390625, 412.35, 685.15, 828.8, 851.5, 852.5, 844.3, 853.15, 852.9, 848.15, 832.65, 830.25, 846.65, 808.4, 849.95, 844.55, 859.15, 872.0, 866.2, 871.9, 811.6, 867.7, 870.45, 864.7, 857.8, 872.05, 877.1, 877.4, 873.9, 878.25, 876.8181818181819, 875.55, 877.75, 873.5, 862.15, 876.0, 880.15, 879.35, 880.1, 877.55, 872.4, 808.95, 877.5, 876.25, 875.75, 874.5, 872.0, 875.0, 871.8]}, {"task": "dmc_cartpole_swingup", "method": "dreamer", "seed": "3", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 
4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [74.625, 75.00625, 87.4640625, 185.225, 394.55, 573.7, 718.3, 790.9, 814.75, 838.1, 840.05, 803.2, 846.05, 826.85, 844.5, 859.1, 864.5, 850.15, 865.35, 858.95, 854.85, 845.5, 868.85, 830.25, 815.6, 792.375, 848.0, 859.25, 788.025, 790.9, 805.0, 797.05, 806.65, 784.45, 849.15, 868.65, 808.8, 869.95, 850.55, 867.75, 857.3, 868.7, 837.85, 873.5, 856.9, 868.75, 850.85, 835.2, 794.475, 869.9]}, {"task": "dmc_cartpole_swingup", "method": "dreamer", "seed": "2", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [73.4375, 74.3625, 74.575, 98.84375, 142.25, 247.2625, 529.7, 681.05, 767.3, 844.25, 861.5, 859.9, 862.35, 871.15, 866.65, 855.85, 873.15, 864.15, 864.1, 872.5, 871.4, 876.55, 861.3, 877.85, 851.15, 744.2, 836.7, 871.2, 870.8, 859.05, 843.7, 867.15, 852.05, 826.3, 760.925, 873.75, 862.85, 763.55, 855.65, 822.4, 864.6, 848.05, 866.8, 865.7, 869.0, 870.35, 851.2, 877.1, 867.5, 872.55]}, {"task": "dmc_quadruped_walk", "method": "dreamer", "seed": "0", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 
4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [29.65625, 61.43125, 75.18125, 108.6015625, 154.3375, 207.721875, 245.66875, 291.49375, 328.15, 374.05, 589.0, 597.8, 570.45, 608.0625, 607.0, 724.1, 664.65, 785.1, 793.2, 827.4, 804.7, 843.2, 777.65, 850.35, 844.625, 820.1, 901.15, 891.0, 914.95, 916.35, 898.4, 915.6, 912.55, 915.7, 910.55, 831.05, 936.55, 914.75, 916.05, 948.6, 953.95, 927.8, 890.05, 920.45, 919.05, 890.6, 935.15, 892.8, 920.55, 941.7]}, {"task": "dmc_quadruped_walk", "method": "dreamer", "seed": "1", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [15.8203125, 49.215625, 92.1296875, 66.915625, 107.5390625, 152.64375, 196.8875, 229.775, 258.4875, 320.55, 348.9875, 470.475, 400.85, 406.8375, 502.1125, 506.6625, 527.4, 552.6, 631.15, 697.625, 712.45, 679.25, 787.55, 836.15, 893.3, 869.45, 881.95, 876.65, 838.9, 897.25, 901.9, 864.4, 890.45, 870.85, 931.2, 903.5, 929.75, 897.45, 904.5, 840.2, 907.5, 936.65, 913.35, 925.7, 904.65, 944.5, 929.05, 918.05, 920.05, 921.9]}, {"task": "dmc_quadruped_walk", "method": "dreamer", "seed": "4", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 
3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [1.6181640625, 40.0, 89.7296875, 117.621875, 246.06875, 332.8625, 273.4375, 326.5875, 382.7, 346.8625, 434.525, 493.25, 468.775, 533.4, 640.95, 638.35, 737.75, 740.95, 856.6, 880.55, 904.75, 851.2, 861.7, 901.9, 893.3, 931.7, 928.85, 933.0, 953.1, 883.7, 908.8, 895.7, 929.95, 937.8, 932.45, 953.25, 912.9, 944.15, 952.4, 948.55, 950.5, 928.1, 935.8, 956.05, 940.8, 955.9, 965.9, 961.55, 941.6, 915.9]}, {"task": "dmc_quadruped_walk", "method": "dreamer", "seed": "3", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [8.671875, 57.34609375, 45.7125, 118.9453125, 117.5, 147.08125, 301.2125, 234.6375, 327.0625, 347.225, 441.525, 458.1, 436.85, 418.4125, 468.425, 595.45, 691.35, 781.85, 826.95, 841.0, 820.5, 853.75, 821.65, 876.05, 877.8, 884.65, 867.95, 885.3, 908.7, 901.35, 924.25, 947.2, 950.6, 936.0, 929.2, 941.35, 933.6, 957.55, 949.5, 907.5, 885.45, 920.05, 930.3, 864.6, 904.45, 928.65, 950.1, 829.2125, 884.05, 910.35]}, {"task": "dmc_quadruped_walk", "method": "dreamer", "seed": "2", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 
3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [33.6875, 87.2765625, 73.2640625, 140.58125, 200.525, 273.525, 332.925, 372.575, 341.75, 392.35, 415.2875, 483.15, 524.2, 557.95, 535.6, 761.775, 902.55, 732.325, 803.0, 810.45, 860.65, 816.6, 875.35, 830.35, 849.3, 806.65, 880.1, 926.85, 919.4, 911.95, 899.45, 925.1, 906.9, 904.15, 916.4, 950.55, 928.45, 926.1, 926.7, 916.3, 949.1, 934.25, 923.85, 949.5, 930.7, 894.95, 950.3, 935.85, 928.1, 942.0]}, {"task": "dmc_hopper_stand", "method": "dreamer", "seed": "0", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [4.0, 5.369140625, 310.20625, 663.425, 755.8, 822.375, 868.225, 912.8, 910.1, 906.65, 909.65, 900.05, 867.5, 893.45, 926.45, 922.6, 934.85, 899.65, 900.35, 903.15, 888.6, 910.7, 930.0, 921.6, 855.55, 895.15, 904.9, 932.05, 954.5, 923.7, 936.25, 942.55, 934.5, 941.5, 935.85, 932.15, 939.2, 846.6, 929.75, 855.45, 940.1, 939.5, 849.8, 939.15, 934.05, 937.1, 940.95, 945.3, 945.15, 930.9]}, {"task": "dmc_hopper_stand", "method": "dreamer", "seed": "1", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 
2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 0.8048828125, 3.219091796875, 141.0046875, 317.725, 336.075, 472.675, 800.8, 776.0, 868.65, 914.9, 907.2, 824.65, 923.2, 927.2, 842.95, 944.8, 948.75, 933.8, 923.4, 936.1, 930.35, 904.3, 929.85, 926.65, 933.25, 930.9, 937.15, 932.9, 938.3, 944.25, 934.35, 837.35, 935.65, 915.75, 851.1, 926.45, 930.55, 932.25, 941.4, 930.5, 941.6, 944.55, 946.15, 941.4, 924.85, 934.85, 934.8, 936.75, 954.85]}, {"task": "dmc_hopper_stand", "method": "dreamer", "seed": "4", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 5.540625, 121.058984375, 544.375, 715.225, 725.975, 801.8, 917.7, 813.5, 874.35, 909.15, 923.0, 911.15, 932.6, 916.6, 932.85, 926.5, 921.95, 939.95, 934.6, 937.75, 934.0, 947.9, 936.35, 935.35, 947.05, 854.05, 951.05, 925.5, 919.8, 911.5, 933.35, 943.45, 938.25, 934.85, 940.55, 935.1, 917.5, 841.35, 935.25, 932.6, 924.85, 943.0, 936.55, 941.1, 947.45, 932.1, 952.1, 930.7, 937.25]}, {"task": "dmc_hopper_stand", "method": "dreamer", "seed": "3", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 
2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 1.03984375, 4.404296875, 170.3171875, 361.925, 511.125, 720.05, 796.3, 875.6, 862.2, 921.85, 927.2, 935.8, 930.65, 915.6, 924.2, 930.7, 928.85, 940.2, 926.0, 942.25, 847.85, 926.45, 948.95, 920.2, 945.95, 937.4, 942.8, 931.65, 944.65, 939.35, 940.15, 936.65, 943.2, 942.15, 950.65, 940.5, 936.25, 839.5, 940.75, 937.5, 939.25, 947.05, 905.9, 947.2, 935.35, 940.25, 848.8, 950.1, 928.55]}, {"task": "dmc_hopper_stand", "method": "dreamer", "seed": "2", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 0.63984375, 4.41796875, 5.7828125, 201.85, 652.725, 720.95, 814.4, 842.85, 915.9, 878.05, 895.5, 926.3, 932.6, 917.75, 932.5, 922.45, 896.3, 913.5, 934.85, 944.1, 925.3, 913.45, 938.45, 935.0, 936.45, 939.8, 944.65, 928.4, 948.25, 939.35, 949.5, 918.0, 850.95, 949.1, 921.9, 946.35, 916.4, 942.85, 947.8, 935.85, 949.05, 940.45, 942.85, 858.1, 955.2, 939.35, 946.5, 844.85, 949.6]}, {"task": "dmc_acrobot_swingup", "method": "dreamer", "seed": "0", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 
2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.05804443359375, 57.07297210693359, 84.29622802734374, 204.41171875, 292.3875, 425.675, 350.965625, 375.6, 402.9, 453.875, 442.2375, 500.1125, 422.3, 499.325, 371.425, 388.15, 436.6125, 489.5, 427.25, 461.7, 471.05, 432.0, 494.95, 434.025, 481.4, 453.675, 484.95, 470.275, 406.225, 420.9, 488.5, 419.275, 479.025, 460.8, 399.55, 383.1, 383.725, 389.275, 317.1, 396.9125, 456.225, 401.225, 450.425, 415.1875, 484.15, 430.4125, 368.25, 468.7, 423.325, 407.15]}, {"task": "dmc_acrobot_swingup", "method": "dreamer", "seed": "1", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0], "ys": [0.8515625, 42.020166015625, 79.67265625, 185.80859375, 277.95625, 359.375, 346.125, 432.0375, 380.3625, 399.95, 372.0125, 408.75, 485.15, 443.05, 468.575, 396.8, 453.9, 423.30625, 443.2, 556.05, 445.475, 437.1625, 510.15, 517.875, 484.425, 489.05, 490.775, 500.55, 427.675, 468.3625, 493.25, 536.2, 507.975, 525.475, 517.525, 478.275, 493.025, 410.3125, 411.675, 494.4875, 483.45]}, {"task": "dmc_acrobot_swingup", "method": "dreamer", "seed": "4", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 
2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [149.75, 22.683526611328126, 51.10404663085937, 145.3125, 145.73627319335938, 260.3875, 373.1875, 393.1875, 384.2375, 341.5625, 391.1, 442.7125, 421.5875, 476.85, 486.4, 438.2, 485.625, 490.0125, 443.425, 489.325, 467.25, 524.225, 464.6625, 492.825, 478.425, 484.825, 502.7, 471.425, 450.2, 498.875, 432.475, 505.425, 440.875, 524.1, 494.475, 550.75, 522.825, 433.875, 522.625, 485.5, 518.0, 530.825, 484.425, 527.875, 463.95, 551.225, 492.4, 494.875, 550.075, 497.525]}, {"task": "dmc_acrobot_swingup", "method": "dreamer", "seed": "3", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [112.125, 37.519219970703126, 100.0078125, 146.25474243164064, 315.9375, 303.519775390625, 367.9333435058594, 332.7125, 418.8, 411.9625, 414.275, 359.6, 419.0296875, 404.625, 465.7875, 439.725, 452.8625, 397.1, 458.0, 512.9, 539.8, 496.1, 480.0, 489.175, 492.0, 490.9, 495.85, 512.6625, 497.6, 504.35, 513.05, 518.8, 457.5125, 514.15, 529.4, 528.4772727272727, 561.55, 580.825, 539.425, 477.225, 527.8, 469.175, 495.625, 523.875, 476.575, 497.425, 443.125, 535.85, 390.225, 421.6]}, {"task": "dmc_acrobot_swingup", "method": "dreamer", "seed": "2", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 
605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0], "ys": [7.015625, 26.3669921875, 70.1416015625, 145.29375, 189.79319763183594, 333.8, 371.9, 315.840625, 404.65, 438.2125, 429.375, 411.625, 450.125, 413.6, 391.175, 388.4125, 431.325, 484.0, 468.15, 453.775, 493.4, 466.4, 497.8, 551.575, 502.95, 572.225, 487.8, 514.825, 404.3, 523.525, 518.0, 575.55, 552.425, 589.5, 508.525, 519.275]}, {"task": "dmc_cartpole_swingup_sparse", "method": "dreamer", "seed": "0", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 20.8, 428.5, 745.4, 736.3, 790.2, 713.4, 774.9, 784.3, 809.1, 780.5, 802.2, 742.3, 777.2, 795.4, 801.1, 815.8, 801.3, 808.5, 811.2, 793.8, 789.3, 785.1, 819.5, 801.5, 811.9, 786.6, 806.1, 790.1, 810.5, 820.2, 827.5, 817.9, 812.7, 809.7, 758.6, 812.2, 752.8, 711.7, 815.4, 737.2, 724.8, 742.0, 829.4, 817.7, 818.5, 816.8, 814.5, 807.8, 804.1]}, {"task": "dmc_cartpole_swingup_sparse", "method": "dreamer", "seed": "1", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 
2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 45.7, 494.7, 744.1, 818.8, 807.7, 824.8, 821.0, 793.7, 830.6, 803.3, 837.8, 800.0, 807.8, 785.3, 835.7, 816.0, 820.1, 811.2, 817.4, 789.1, 826.7, 825.8, 827.2, 711.3, 812.9, 838.2, 829.0, 778.5, 816.1, 823.5, 827.0, 811.2, 832.9, 823.6, 824.9, 836.8, 821.7, 772.1, 756.1, 813.9, 828.7, 785.3, 831.9, 816.9, 814.8, 798.0, 794.6, 817.1, 833.1]}, {"task": "dmc_cartpole_swingup_sparse", "method": "dreamer", "seed": "4", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 103.4, 430.5, 795.5, 757.3, 825.9, 786.6, 741.5, 815.9, 785.2, 820.9, 799.4, 803.8, 827.0, 827.0, 825.7, 820.2, 827.4, 819.1, 778.5, 622.5, 697.5, 803.5, 759.0, 817.9, 744.0, 746.4, 797.2, 698.2, 801.7, 704.1, 668.1, 750.4, 797.2, 740.9, 789.8, 700.5, 754.4, 641.8, 733.4, 761.2, 763.4, 759.6, 732.1, 746.4, 751.3, 784.4, 687.0, 773.3, 759.8]}, {"task": "dmc_cartpole_swingup_sparse", "method": "dreamer", "seed": "3", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 
3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 2.0, 139.5, 482.9, 711.1, 777.2, 808.0, 805.0, 795.4, 814.1, 767.5, 817.7, 795.9, 824.3, 816.4, 777.0, 812.5, 830.6, 828.3, 827.7, 830.2, 821.9, 829.3, 831.9, 821.6, 818.0, 815.8, 812.5, 819.1, 826.0, 823.3, 818.2, 818.8, 820.7, 811.2, 821.6, 750.8181818181819, 785.5, 805.0, 647.6, 542.9, 777.0, 783.6, 786.5, 804.7, 816.7, 793.5, 703.9, 791.4, 757.3]}, {"task": "dmc_cartpole_swingup_sparse", "method": "dreamer", "seed": "2", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0], "ys": [0.0, 54.3, 419.6, 775.3, 779.6, 797.4, 806.6, 734.0, 806.0, 817.0, 811.5, 794.5, 825.8, 810.2, 758.3, 822.8, 808.4, 827.3, 818.9, 828.8, 821.2, 829.3, 831.0, 812.3, 801.2, 802.1, 781.7, 783.9, 769.4, 746.7, 836.7, 765.4, 651.5, 788.1, 815.3, 694.2, 746.6]}, {"task": "dmc_cartpole_balance_sparse", "method": "dreamer", "seed": "0", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [9.0, 638.4, 983.8, 1000.0, 999.8, 998.5, 
1000.0, 1000.0, 1000.0, 997.2, 1000.0, 848.0, 982.3, 993.7, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 978.4, 972.4, 1000.0, 1000.0, 1000.0, 983.2, 1000.0, 1000.0, 1000.0, 993.8, 982.0, 832.0, 921.5, 911.6, 925.1, 896.2, 998.6, 806.5, 877.0, 884.9, 992.6, 998.6, 1000.0, 804.6, 962.0, 1000.0, 1000.0, 998.6, 912.5, 958.6, 952.9]}, {"task": "dmc_cartpole_balance_sparse", "method": "dreamer", "seed": "1", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0], "ys": [12.0, 557.3, 997.4, 999.5, 998.6, 1000.0, 965.6, 990.8, 989.2, 921.0, 1000.0, 1000.0, 1000.0, 993.6, 981.6, 1000.0, 985.4, 987.7, 918.7, 959.5, 999.5, 873.0, 932.9, 980.2, 945.3, 1000.0, 985.3, 976.4, 982.4, 665.6, 15.6, 14.8, 14.9, 14.4, 12.2, 14.1, 12.3, 12.727272727272727, 13.5, 13.2, 11.9, 12.6, 12.6, 14.0, 19.4]}, {"task": "dmc_cartpole_balance_sparse", "method": "dreamer", "seed": "4", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [13.0, 541.4, 998.6, 1000.0, 1000.0, 996.2, 1000.0, 998.5, 999.6, 1000.0, 1000.0, 1000.0, 934.7, 1000.0, 1000.0, 998.5, 1000.0, 1000.0, 1000.0, 1000.0, 
1000.0, 1000.0, 967.3, 869.1, 942.6, 998.4, 998.4, 966.9, 997.4, 904.7, 922.9, 991.4, 1000.0, 1000.0, 816.6, 942.6, 985.3, 992.9, 1000.0, 981.1, 837.0, 44.0, 33.7, 84.0, 624.8, 1000.0, 1000.0, 995.1, 1000.0, 1000.0]}, {"task": "dmc_cartpole_balance_sparse", "method": "dreamer", "seed": "3", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [13.0, 702.5, 997.5, 986.1, 998.4, 999.1, 993.5, 996.2, 998.6, 1000.0, 999.8, 923.4, 936.9, 999.8, 1000.0, 995.6, 997.3, 976.5, 1000.0, 967.9, 1000.0, 1000.0, 915.0, 1000.0, 994.7, 1000.0, 971.6, 968.0, 996.3, 915.7, 904.9, 908.5, 983.7272727272727, 1000.0, 1000.0, 1000.0, 998.9, 997.2, 999.1, 967.5, 965.9, 982.3, 989.3, 968.0, 999.4, 999.3, 999.9, 978.2, 1000.0, 333.5]}, {"task": "dmc_cartpole_balance_sparse", "method": "dreamer", "seed": "2", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0], "ys": [9.0, 617.7, 1000.0, 917.8, 990.6, 994.8, 949.0, 999.7, 1000.0, 992.3, 1000.0, 1000.0, 998.7, 998.6, 987.1, 1000.0, 1000.0, 999.3, 1000.0, 951.5, 951.5, 999.4, 972.0, 927.8, 948.1, 1000.0, 1000.0, 928.1, 973.6, 987.3, 1000.0, 1000.0, 981.8, 1000.0, 1000.0]}, {"task": "dmc_walker_run", "method": "dreamer", "seed": "0", 
"xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [4.03125, 90.6359375, 178.575, 235.9125, 293.75, 359.953125, 557.2, 605.8, 636.1, 687.3, 691.7, 741.65, 736.4, 755.55, 757.0, 756.7, 754.25, 778.15, 790.4, 787.35, 637.71796875, 771.6, 792.7, 666.35, 787.1, 812.25, 799.55, 799.05, 807.6, 792.75, 797.55, 809.9, 792.9, 793.85, 806.45, 810.2, 817.25, 813.05, 813.35, 820.1, 809.55, 809.15, 820.15, 806.85, 814.9, 818.5, 811.5, 823.2, 826.7, 813.4]}, {"task": "dmc_walker_run", "method": "dreamer", "seed": "1", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [19.84375, 42.9890625, 157.01875, 235.8875, 290.825, 420.675, 533.7, 611.95, 625.7, 640.6, 643.1, 714.25, 738.9, 748.85, 763.5, 751.5, 767.65, 695.5609375, 775.7, 790.65, 805.25, 809.85, 800.9, 812.8, 792.55, 811.0, 815.5, 807.3, 815.65, 806.0, 809.15, 823.0, 821.5, 824.15, 823.55, 821.7, 818.7, 823.8, 840.95, 829.7, 788.3, 822.25, 822.3, 835.1, 834.3, 821.85, 833.2, 825.45, 836.8, 829.7]}, {"task": "dmc_walker_run", "method": 
"dreamer", "seed": "4", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [14.2890625, 65.8890625, 167.85, 286.9125, 446.725, 587.4, 628.65, 652.3, 683.1, 728.9, 750.45, 779.5, 783.15, 783.5, 796.8, 777.05, 808.4, 804.9, 808.0, 798.35, 804.65, 811.95, 801.7, 797.6, 804.45, 808.95, 807.55, 813.95, 809.2, 830.55, 810.7, 822.5, 820.5, 825.7, 815.65, 816.35, 821.85, 817.8, 830.5, 829.0, 820.4, 822.4, 825.9, 820.0, 822.45, 814.65, 823.4, 818.0, 818.5, 823.55]}, {"task": "dmc_walker_run", "method": "dreamer", "seed": "3", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [16.8125, 112.1546875, 210.7125, 279.7375, 329.125, 419.725, 566.45, 609.425, 684.6, 729.6, 742.85, 764.95, 778.15, 761.3, 779.05, 799.15, 799.25, 799.0, 799.75, 806.8, 811.2, 811.55, 812.95, 813.25, 811.8, 814.15, 800.5, 811.65, 807.1, 807.4, 807.15, 813.85, 808.2, 808.0, 817.9, 807.8, 803.55, 815.9, 811.25, 826.2, 812.45, 809.95, 824.2, 822.3, 812.4, 817.5, 821.25, 817.15, 823.45, 826.25]}, {"task": "dmc_walker_run", 
"method": "dreamer", "seed": "2", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [44.1875, 101.54375, 222.6625, 355.5, 432.8, 529.725, 612.775, 688.7, 729.7, 749.65, 761.75, 774.65, 773.65, 776.85, 789.4, 774.4, 779.4, 780.25, 770.75, 789.1, 776.8, 781.6, 787.8, 794.9, 790.35, 796.1, 790.4, 794.5, 798.55, 797.1, 793.45, 804.45, 813.95, 796.8, 812.95, 809.0, 821.3, 808.15, 806.2, 803.5, 802.3, 820.9, 804.55, 809.25, 809.7, 819.6, 815.75, 803.25, 826.2, 821.4]}, {"task": "dmc_hopper_hop", "method": "dreamer", "seed": "0", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.006755828857421875, 0.1185028076171875, 0.6159820556640625, 31.586579513549804, 62.715625, 142.85, 181.9375, 200.2625, 205.625, 230.1875, 254.85, 259.2125, 259.9375, 264.125, 298.325, 299.05, 297.975, 308.65, 316.9, 303.1625, 315.875, 276.375, 323.375, 328.55, 321.8, 292.275, 328.375, 328.975, 292.05, 329.925, 324.875, 334.125, 303.375, 344.925, 337.725, 334.8, 349.275, 343.675, 350.275, 350.75, 357.85, 
345.375, 344.375, 345.0, 359.35, 352.15, 357.975, 348.15, 353.725, 339.575]}, {"task": "dmc_hopper_hop", "method": "dreamer", "seed": "1", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.00522613525390625, 0.0006381988525390625, 0.007147216796875, 2.3416015625, 20.80625, 78.796875, 116.71875, 175.9125, 192.0375, 202.95, 210.95, 213.6625, 233.575, 244.6625, 222.05, 261.8125, 282.2, 280.25, 276.7875, 283.35, 281.3, 280.775, 289.875, 299.975, 304.0, 294.0875, 302.85, 298.425, 265.675, 286.2375, 295.7, 295.25, 293.6125, 297.275, 315.275, 313.325, 308.925, 298.225, 317.475, 330.125, 327.475, 323.9, 326.9, 318.9, 319.25, 316.875, 317.1, 309.125, 320.15, 309.7]}, {"task": "dmc_hopper_hop", "method": "dreamer", "seed": "4", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 0.0, 1.310009765625, 25.66953125, 69.43125, 84.734375, 100.69375, 139.625, 180.4125, 206.25, 228.4, 248.65, 252.55, 276.575, 295.425, 306.0, 313.9, 279.0375, 316.4, 318.05, 321.45, 327.625, 337.875, 315.075, 
385.925, 440.45, 424.325, 472.575, 455.725, 492.475, 500.825, 452.95, 506.325, 490.575, 467.575, 516.95, 509.575, 530.9, 544.875, 556.65, 517.3875, 583.8, 575.4, 575.1, 612.85, 614.7, 592.275, 609.0, 606.45, 620.35]}, {"task": "dmc_hopper_hop", "method": "dreamer", "seed": "3", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 1.80775146484375, 47.721484375, 104.73125, 145.7125, 189.85, 215.4125, 241.675, 282.1375, 254.5125, 291.85, 304.4125, 297.3125, 311.9, 321.1, 332.225, 320.5, 317.175, 320.4, 305.75, 322.45, 312.875, 324.05, 329.6, 328.025, 336.725, 333.475, 338.85, 300.8, 301.4875, 322.05, 334.65, 340.475, 371.225, 337.825, 377.825, 379.075, 379.25, 369.525, 393.025, 379.025, 358.4, 341.95, 375.025, 409.5, 415.5, 490.3, 530.15, 520.1, 548.225]}, {"task": "dmc_hopper_hop", "method": "dreamer", "seed": "2", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0], "ys": [0.0, 0.00049285888671875, 0.61513671875, 5.3861328125, 49.38515625, 103.88125, 176.55, 214.875, 290.9375, 328.7, 342.175, 336.8, 273.4875, 367.45, 360.825, 382.475, 454.9, 474.075, 458.25, 489.75, 501.1, 498.5, 547.625, 549.4, 541.9, 468.675, 563.8, 595.2, 542.175, 574.4, 596.0, 590.3, 
609.85, 627.75]}, {"task": "dmc_walker_stand", "method": "dreamer", "seed": "0", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [357.5, 281.6125, 785.95, 876.55, 949.6, 954.95, 961.95, 941.45, 932.25, 945.65, 936.9, 906.75, 973.3, 958.25, 964.45, 937.25, 971.25, 968.5, 979.2, 973.35, 965.1, 973.1, 974.6, 982.65, 969.1, 979.9, 973.5, 985.6, 980.9, 976.95, 985.95, 988.35, 897.425, 976.15, 878.50625, 963.4, 943.15, 968.75, 970.45, 971.85, 970.25, 983.9, 904.125, 964.7, 956.05, 895.3375, 985.5, 984.3, 979.8, 903.525]}, {"task": "dmc_walker_stand", "method": "dreamer", "seed": "1", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [236.375, 252.7875, 700.95, 819.65, 918.35, 934.2, 965.55, 972.15, 962.3, 967.45, 966.0, 958.5, 925.75, 986.75, 970.4, 977.25, 893.75, 978.4, 980.4, 948.6, 939.9, 972.85, 958.1, 964.3, 939.3, 882.9625, 952.5, 971.8, 973.7, 964.2, 975.85, 939.65, 988.75, 966.85, 948.05, 938.2, 980.1, 963.3, 866.65, 922.9, 976.75, 977.25, 974.4, 973.2, 972.65, 980.55, 
991.45, 972.05, 975.9, 857.35]}, {"task": "dmc_walker_stand", "method": "dreamer", "seed": "4", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [159.75, 313.48125, 669.3, 918.1, 940.7, 958.75, 965.75, 956.7, 972.05, 974.15, 959.45, 974.3, 977.05, 970.7, 966.2, 909.45, 961.15, 984.65, 974.5, 962.25, 980.8, 969.15, 983.85, 985.85, 984.9, 980.1, 985.35, 986.6, 987.7, 964.05, 963.05, 982.05, 970.7, 983.25, 979.4, 982.35, 888.925, 980.05, 977.9, 981.4, 987.3, 977.1, 886.85625, 967.75, 903.1375, 886.675, 983.6, 975.45, 990.65, 968.8]}, {"task": "dmc_walker_stand", "method": "dreamer", "seed": "3", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [91.375, 448.071875, 853.85, 963.05, 949.95, 979.75, 969.6, 970.95, 979.8, 911.775, 955.5, 971.4, 968.25, 973.05, 978.1, 972.0, 974.15, 981.85, 971.85, 977.85, 962.95, 890.875, 977.9, 978.55, 974.4, 964.7, 983.75, 978.65, 973.1, 982.25, 985.6, 978.9, 971.35, 982.2, 973.55, 983.45, 985.55, 896.1, 984.6, 967.0, 982.65, 962.2, 971.4, 971.05, 
984.1, 974.05, 975.75, 979.65, 975.5, 977.35]}, {"task": "dmc_walker_stand", "method": "dreamer", "seed": "2", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [341.0, 279.89375, 605.15, 730.85, 870.1, 972.25, 971.1, 968.15, 962.4, 802.009375, 980.35, 979.85, 895.10625, 979.7, 982.35, 969.55, 977.3, 977.35, 978.8, 902.0625, 978.2, 976.3, 977.7, 974.0, 969.4, 980.65, 976.25, 991.95, 895.7875, 984.9, 968.25, 981.35, 968.55, 987.05, 976.75, 969.4, 982.1, 983.65, 979.6, 974.5, 975.1, 987.6, 991.4, 967.05, 983.15, 972.1, 978.75, 978.65, 987.8, 981.1]}, {"task": "dmc_cartpole_balance", "method": "dreamer", "seed": "0", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [135.75, 445.575, 847.05, 980.2, 988.6, 978.3, 978.15, 988.3, 990.6, 987.2, 992.4, 971.4, 964.2, 984.4, 984.0, 991.35, 994.9, 985.1, 981.55, 987.05, 998.4, 992.0, 997.25, 997.4, 942.225, 997.0, 996.6, 997.15, 997.5, 995.15, 994.7, 995.6, 996.3, 996.65, 997.1, 996.0, 910.05, 991.9, 994.95, 992.95, 997.15, 964.3, 996.9, 
996.85, 996.4, 991.8, 956.3, 616.725, 203.8375, 251.2]}, {"task": "dmc_cartpole_balance", "method": "dreamer", "seed": "1", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0], "ys": [136.125, 418.5625, 865.5, 971.4, 978.5, 988.85, 987.55, 994.35, 994.05, 996.95, 989.1, 994.2, 996.2, 992.2, 992.85, 994.6, 997.4, 959.35, 997.0, 964.05, 997.2, 996.4, 996.8, 998.0, 991.5, 997.15, 992.2, 997.85, 958.85, 992.35, 993.5, 997.25, 946.15, 994.8, 926.8, 996.7727272727273, 986.3, 997.6, 998.9, 998.2, 998.2, 911.45, 921.25]}, {"task": "dmc_cartpole_balance", "method": "dreamer", "seed": "4", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [134.625, 529.4375, 798.15, 905.45, 969.35, 990.6, 970.85, 993.0, 995.1, 976.9, 995.65, 993.25, 972.45, 991.0, 954.7, 997.85, 997.4, 998.35, 996.5, 997.2, 997.25, 994.9, 996.7, 995.8, 997.5, 994.6, 997.4, 902.95, 992.75, 994.85, 996.45, 996.0, 914.5, 942.975, 951.4, 998.2, 996.35, 997.05, 946.225, 998.35, 995.6, 875.675, 962.15, 996.2, 951.9, 996.55, 988.35, 928.525, 918.925, 907.25]}, {"task": "dmc_cartpole_balance", "method": "dreamer", 
"seed": "3", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [135.5, 474.825, 807.15, 939.3, 971.85, 972.75, 992.15, 992.15, 993.4, 992.9, 992.4, 997.45, 996.85, 996.0, 997.5, 997.8, 997.7, 996.7, 997.95, 949.3, 978.4, 997.65, 984.55, 997.45, 997.9, 998.45, 997.5, 998.0, 998.05, 998.5, 998.15, 997.95, 998.15, 998.95, 998.4, 998.75, 998.1, 997.9, 998.35, 997.3, 991.55, 801.45, 881.4, 997.1, 987.7, 991.4, 944.05, 997.75, 995.7, 892.6]}, {"task": "dmc_cartpole_balance", "method": "dreamer", "seed": "2", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [136.0, 493.15, 816.6, 981.85, 987.9, 978.35, 983.65, 994.0, 987.25, 990.3, 971.8, 989.9, 977.75, 990.85, 995.3, 963.9, 982.0, 937.45, 989.15, 997.65, 899.725, 981.9, 983.35, 987.7, 995.35, 995.0, 951.2, 991.0, 977.75, 994.2, 947.2, 944.45, 987.05, 989.75, 988.05, 995.0, 994.8, 997.75, 959.05, 994.35, 993.45, 996.05, 992.9, 991.9, 787.975, 280.25, 269.95, 250.9375, 234.125, 234.9125]}, {"task": "dmc_finger_turn_easy", "method": 
"dreamer", "seed": "0", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [1000.0, 299.8, 198.4, 387.2, 359.7, 240.4, 576.7, 731.4, 698.8, 735.3, 815.4, 854.7, 924.5, 935.0, 961.7, 966.9, 962.6, 958.8, 959.4, 951.4, 966.1, 945.7, 963.2, 958.1, 884.3, 971.9, 762.0, 925.9, 947.7, 942.4, 964.4, 965.2, 941.8, 970.5, 946.9, 971.5, 870.2, 961.3, 868.2, 861.7, 981.3, 963.6, 880.9, 957.3, 874.2, 868.9, 977.3, 976.6, 976.1, 985.4]}, {"task": "dmc_finger_turn_easy", "method": "dreamer", "seed": "1", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 100.0, 573.0, 368.3, 492.2, 421.5, 549.5, 728.3, 588.4, 731.0, 858.2, 961.4, 952.8, 935.9, 940.3, 857.0, 844.7, 983.0, 878.4, 856.7, 820.9, 889.2, 812.5, 817.9, 947.2, 873.7, 800.9, 884.4, 879.6, 889.2, 875.9, 711.8, 964.3, 750.2, 914.6, 847.5, 877.5, 820.5, 883.8, 873.4, 941.2, 963.0, 964.2, 975.1, 979.3, 956.7, 921.0, 882.1, 942.2, 943.5]}, {"task": "dmc_finger_turn_easy", "method": "dreamer", "seed": "4", "xs": [5000.0, 105000.0, 
205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [1000.0, 112.9, 185.8, 300.8, 288.4, 480.6, 810.4, 520.1, 341.4, 629.6, 757.6, 700.8, 649.9, 853.6, 861.1, 960.5, 863.8, 939.0, 956.2, 950.3, 889.7, 958.3, 863.3, 915.1, 968.5, 901.1, 942.2, 972.8, 869.1, 980.4, 849.4, 886.8, 953.6, 906.4, 843.7, 970.3, 779.8, 884.6, 851.8, 963.5, 886.9, 964.0, 787.5, 876.8, 957.7, 851.9, 974.5, 970.8, 952.5, 851.7]}, {"task": "dmc_finger_turn_easy", "method": "dreamer", "seed": "3", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 54.8, 215.2, 628.5, 587.3, 478.3, 343.7, 646.7, 623.4, 498.3, 668.4, 613.7, 838.5, 705.2, 712.2, 781.3, 618.5, 679.5, 725.6, 825.2, 682.1, 750.3, 757.8, 943.3, 825.4, 771.2, 685.6, 848.7, 664.5, 933.8, 903.4, 967.9, 777.7, 794.0, 932.2, 755.7, 947.2, 892.1, 923.1, 782.6, 960.7, 778.1, 821.3, 967.4, 946.4, 847.4, 929.7, 966.1, 960.7, 922.0]}, {"task": "dmc_finger_turn_easy", "method": "dreamer", "seed": "2", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 
705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 86.5, 182.0, 631.5, 169.5, 749.4, 420.9, 684.6, 855.1, 862.6, 806.0, 869.7, 938.3, 946.7, 848.3, 963.4, 950.3, 957.6, 875.4, 935.2, 934.5, 966.5, 970.0, 955.8, 964.3, 963.9, 944.0, 959.1, 898.3, 973.2, 943.6, 954.2, 958.6, 964.8, 974.5, 957.5, 972.6, 984.7, 940.8, 955.1, 963.7, 952.7, 949.7, 855.1, 959.1, 968.8, 965.4, 925.9, 948.1, 962.2]}, {"task": "dmc_quadruped_run", "method": "dreamer", "seed": "0", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [13.28125, 136.887255859375, 162.5984375, 288.771875, 203.784375, 308.55, 311.35, 357.3, 439.625, 423.1625, 500.225, 594.45, 607.35, 621.275, 740.525, 743.0, 836.15, 824.225, 725.725, 872.25, 864.9, 894.45, 912.55, 818.35, 857.6, 850.95, 877.25, 887.1, 916.45, 896.0, 904.1, 919.4, 900.75, 900.4, 907.05, 901.3, 904.35, 891.05, 915.15, 925.55, 911.6, 926.35, 940.2, 928.45, 946.3, 905.45, 920.6, 904.95, 913.7, 915.95]}, {"task": "dmc_quadruped_run", "method": "dreamer", "seed": "1", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 
605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0], "ys": [456.0, 68.16875, 109.1046875, 191.44375, 270.20625, 281.18125, 291.775, 323.8375, 343.775, 414.075, 398.0625, 367.5625, 430.0625, 428.975, 496.775, 674.5, 753.3, 766.7, 766.3, 789.85, 810.25, 794.725, 827.475, 893.0, 855.65, 851.55, 802.0, 877.85, 898.95, 920.8, 876.6, 934.5, 927.75, 894.75, 893.5, 933.1, 849.7]}, {"task": "dmc_quadruped_run", "method": "dreamer", "seed": "4", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.6904296875, 34.6921875, 100.825, 46.2890625, 103.65234375, 137.015625, 228.4625, 186.74375, 287.325, 311.6875, 270.3, 412.0125, 408.35, 417.125, 398.75, 361.175, 384.925, 382.55, 407.25, 455.0, 492.25, 575.875, 596.325, 679.25, 697.85, 754.95, 673.45, 800.5, 785.55, 809.575, 860.45, 881.25, 879.1, 875.35, 887.25, 898.2, 836.6, 880.0, 881.95, 871.4, 789.4, 843.05, 879.95, 884.85, 827.9, 850.55, 861.85, 864.65, 858.35, 814.0]}, {"task": "dmc_quadruped_run", "method": "dreamer", "seed": "3", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 
2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [229.25, 65.44375, 100.7203125, 114.68125, 144.25, 105.04521484375, 115.5140625, 153.052490234375, 159.315625, 244.465625, 360.75, 350.8, 323.05, 350.875, 380.65, 380.025, 376.9625, 373.625, 376.025, 437.175, 511.275, 599.575, 640.8, 720.3, 732.525, 791.2, 781.175, 784.5, 755.175, 849.8, 830.25, 838.25, 910.4, 863.3, 914.25, 920.1, 901.05, 875.15, 929.0, 924.65, 927.0, 906.3, 908.2, 917.45, 924.05, 891.0, 941.55, 920.05, 937.55, 938.35]}, {"task": "dmc_quadruped_run", "method": "dreamer", "seed": "2", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0], "ys": [18.828125, 33.2078125, 134.2328125, 88.4578125, 109.640625, 171.25, 238.425, 339.2375, 398.475, 349.125, 413.15, 368.475, 306.9, 367.025, 345.975, 417.025, 432.125, 412.425, 517.675, 529.025, 616.325, 757.0, 745.05, 886.75, 905.25, 840.95, 896.4, 897.95, 893.35, 771.325, 833.85, 896.75, 900.3]}, {"task": "dmc_cheetah_run", "method": "dreamer", "seed": "0", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 
4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.27783203125, 78.90673828125, 341.18125, 614.8, 605.725, 647.475, 625.325, 572.675, 638.75, 617.125, 747.35, 798.2, 828.55, 816.35, 878.65, 783.95, 804.55, 864.55, 842.95, 844.95, 849.65, 869.35, 845.15, 830.6, 870.3, 849.9, 878.6, 876.5, 876.0, 837.65, 831.9, 844.8, 872.2, 842.6, 856.8, 875.65, 875.65, 874.0, 871.1, 886.15, 870.4, 866.15, 848.65, 877.9, 862.75, 874.85, 861.65, 875.45, 868.55, 867.15]}, {"task": "dmc_cheetah_run", "method": "dreamer", "seed": "1", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [2.75390625, 223.03671875, 565.275, 601.25, 618.6875, 627.35625, 656.6625, 668.0125, 749.85, 791.5, 673.525, 738.3, 698.09375, 765.675, 837.8, 867.05, 834.15, 857.8, 825.05, 777.25, 818.35, 868.6, 891.05, 865.9, 842.75, 871.1, 870.85, 890.3, 865.0, 903.1, 901.15, 893.6, 866.6, 899.65, 854.55, 899.6, 872.55, 857.2, 890.9, 863.35, 873.1, 881.05, 875.0, 890.45, 864.6, 875.6, 875.95, 861.85, 886.65, 866.2]}, {"task": "dmc_cheetah_run", "method": "dreamer", "seed": "4", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 
3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [4.5546875, 91.2501953125, 245.590625, 581.175, 668.45, 661.85, 716.55, 725.8, 720.65, 760.6, 762.25, 798.95, 826.2, 816.45, 727.8, 809.7, 795.85, 883.6, 842.3, 851.15, 800.0, 840.1, 841.9, 835.05, 874.45, 875.8, 897.85, 887.9, 878.15, 855.975, 889.2, 858.55, 901.1, 902.05, 884.25, 880.15, 891.1, 904.4, 895.2, 840.85, 845.15, 873.0, 895.0, 874.4, 867.0, 895.2, 876.6, 892.45, 879.05, 896.3]}, {"task": "dmc_cheetah_run", "method": "dreamer", "seed": "3", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [3.908203125, 135.6875, 447.6, 613.2, 602.025, 651.825, 697.7, 687.55, 715.8, 705.2, 656.55, 698.1, 693.15, 691.9, 715.75, 697.55, 743.85, 767.1, 739.65, 774.2, 765.65, 774.8, 767.1, 757.15, 775.65, 755.15, 815.15, 834.95, 875.35, 864.05, 893.45, 845.45, 847.75, 870.1, 877.1, 883.4, 859.05, 874.35, 883.75, 899.75, 888.7, 881.2, 893.3, 894.6, 883.9, 879.35, 845.3, 875.6, 897.85, 829.975]}, {"task": "dmc_cheetah_run", "method": "dreamer", "seed": "2", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 
3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.450439453125, 109.3404296875, 325.46875, 463.09375, 572.4625, 561.65, 653.375, 653.85, 614.45, 702.0, 724.4, 703.5, 735.15, 735.2, 681.8, 763.95, 738.7, 732.2, 738.45, 781.2, 714.475, 793.55, 762.0, 810.6, 801.45, 826.0, 879.45, 854.15, 850.05, 872.0, 881.35, 878.85, 882.95, 880.15, 872.95, 877.5, 856.45, 838.45, 868.15, 885.0, 857.15, 903.3, 875.95, 858.3, 847.75, 877.5, 896.65, 886.5, 871.9, 888.6]}, {"task": "dmc_finger_spin", "method": "dreamer", "seed": "0", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 22.3, 332.9, 411.7, 403.3, 300.3, 360.6, 355.2, 399.3, 945.3, 627.5, 468.0, 430.7, 460.8, 409.7, 369.2, 257.0, 323.0, 453.9, 588.3, 593.0, 426.0, 449.6, 316.4, 550.1, 418.4, 454.8, 602.6, 553.4, 322.0, 399.7, 395.3, 333.4, 605.5, 622.6, 704.5, 652.0, 683.9, 192.0, 205.2, 307.7, 470.1, 501.3, 449.1, 397.9, 522.3, 377.6, 368.8, 356.8, 447.3]}, {"task": "dmc_finger_spin", "method": "dreamer", "seed": "1", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 
3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 0.1, 0.0, 0.0, 16.4, 271.2, 329.3, 321.1, 341.1, 334.0, 320.2, 327.8, 331.0, 336.5, 326.3, 326.5, 338.3, 323.7, 313.4, 282.3, 284.8, 190.0, 238.0, 274.7, 186.7, 139.9, 185.0, 79.2, 86.1, 144.6, 183.6, 138.2, 219.9, 209.7, 188.4, 220.2, 203.6, 181.6, 128.2, 120.4, 126.5, 150.1, 206.1, 120.7, 154.7, 197.7, 161.2, 181.3, 150.5, 167.4]}, {"task": "dmc_finger_spin", "method": "dreamer", "seed": "4", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 12.4, 238.8, 394.4, 406.6, 426.8, 535.0, 654.1, 668.6, 825.9, 862.4, 894.5, 675.5, 738.4, 843.3, 663.9, 807.9, 732.9, 699.3, 698.9, 971.7, 984.4, 977.8, 979.4, 977.8, 980.2, 984.2, 982.7, 981.3, 976.5, 978.7, 976.9, 976.4, 833.9, 914.4, 276.4, 530.6, 542.4, 636.0, 459.4, 544.2, 408.1, 563.9, 463.4, 464.0, 553.6, 380.7, 363.7, 430.8, 531.1]}, {"task": "dmc_finger_spin", "method": "dreamer", "seed": "3", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 
4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 2.5, 51.9, 353.8, 368.8, 367.9, 392.4, 376.2, 432.9, 362.1, 428.6, 446.2, 389.9, 308.9, 345.4, 339.1, 223.1, 294.1, 407.2, 436.0, 466.0, 500.3, 532.9, 490.6, 531.9, 566.6, 624.9, 616.2, 323.2, 571.0, 615.9, 618.6, 336.8, 211.3, 278.0, 268.1, 323.4, 330.4, 395.7, 453.3, 427.1, 565.0, 589.3, 592.3, 574.0, 514.2, 395.9, 454.7, 399.1, 457.1]}, {"task": "dmc_finger_spin", "method": "dreamer", "seed": "2", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 2.8, 0.5, 0.0, 186.9, 326.6, 389.5, 400.0, 399.0, 377.4, 334.2, 398.2, 391.0, 308.6, 374.9, 277.3, 384.7, 323.4, 285.6, 129.9, 0.3, 188.0, 97.2, 304.1, 249.2, 277.6, 216.1, 184.7, 234.4, 308.8, 281.5, 306.5, 343.2, 274.9, 251.0, 268.7, 340.9, 423.3, 404.3, 383.8, 97.9, 384.5, 337.7, 361.6, 325.9, 250.7, 439.0, 384.1, 358.6, 564.5]}, {"task": "dmc_walker_walk", "method": "dreamer", "seed": "0", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 
4805000.0, 4905000.0], "ys": [30.328125, 173.3328125, 592.075, 820.25, 903.3, 921.2, 946.8, 936.4, 947.0, 932.85, 963.25, 961.45, 874.340625, 956.35, 960.7, 958.95, 959.85, 967.95, 954.7, 965.9, 976.25, 952.1, 970.25, 962.0, 960.7, 965.9, 929.3, 961.7, 965.6, 966.8, 966.95, 956.15, 971.6, 930.25, 841.05, 848.75, 916.9, 966.5, 971.05, 959.75, 968.35, 968.3, 970.8, 959.5, 959.55, 879.7890625, 964.85, 967.45, 969.2, 979.15]}, {"task": "dmc_walker_walk", "method": "dreamer", "seed": "1", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [15.2578125, 81.34609375, 406.7625, 659.175, 781.95, 885.4, 929.6, 946.75, 859.496875, 955.45, 971.0, 961.0, 957.65, 935.6, 926.1, 918.7, 946.85, 925.6, 937.3, 915.8, 965.4, 876.9984375, 965.4, 953.7, 947.4, 969.15, 933.4, 862.9078125, 965.1, 970.9, 961.8, 961.8, 958.4, 972.25, 971.6, 897.05, 902.9, 936.9, 941.95, 945.55, 962.25, 965.5, 970.3, 971.85, 939.8, 857.8, 880.95, 899.8, 926.05, 890.25]}, {"task": "dmc_walker_walk", "method": "dreamer", "seed": "4", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 
4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [29.90625, 232.71875, 589.2, 823.25, 856.7, 909.3, 926.25, 956.35, 942.35, 957.35, 959.85, 967.3, 969.8, 970.7, 977.7, 959.5, 950.85, 969.2, 966.4, 953.3, 968.3, 970.15, 967.45, 966.25, 959.55, 970.25, 974.45, 970.15, 979.45, 956.5, 964.6, 964.0, 946.9, 960.6, 977.8, 970.85, 968.1, 967.9, 965.9, 960.0, 967.05, 886.2, 966.3, 964.3, 964.8, 969.3, 962.85, 968.2, 967.65, 971.8]}, {"task": "dmc_walker_walk", "method": "dreamer", "seed": "3", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [36.3125, 88.1796875, 480.825, 690.0, 739.6, 842.0, 887.7, 943.0, 940.3, 944.4, 940.5, 958.8, 961.35, 961.15, 964.7, 964.75, 974.9, 964.45, 966.35, 872.78125, 967.3, 964.4, 973.15, 971.9, 966.3, 970.5, 962.75, 963.4, 973.05, 966.4, 970.75, 968.15, 938.35, 973.35, 971.8, 961.55, 968.75, 972.05, 959.85, 923.6, 915.65, 865.6, 837.9, 873.65, 938.35, 976.45, 975.95, 972.45, 969.5, 965.75]}, {"task": "dmc_walker_walk", "method": "dreamer", "seed": "2", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0], "ys": [17.4375, 110.3671875, 473.0, 621.45, 767.35, 871.15, 934.8, 943.55, 957.0, 953.05, 943.35, 957.45, 965.15, 964.95, 963.5, 
948.6, 961.65, 969.9, 962.2, 971.3, 969.1, 970.9, 972.9, 975.5, 978.3, 962.1, 964.8, 970.45, 966.5454545454545, 980.15, 965.15, 966.35, 972.05]}, {"task": "dmc_reacher_easy", "method": "dreamer", "seed": "0", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 123.2, 414.5, 576.7, 765.1, 548.9, 483.3, 534.1, 281.5, 684.0, 485.7, 714.0, 579.8, 670.1, 840.9, 870.5, 942.4, 772.2, 686.7, 967.0, 877.2, 871.7, 855.1, 949.6, 877.6, 962.3, 971.5, 775.4, 871.8, 977.8, 974.8, 980.2, 955.9, 971.1, 979.5, 963.8, 978.4, 877.0, 977.3, 973.7, 939.8, 968.7, 887.7, 979.0, 935.9, 978.9, 954.3, 983.3, 971.1, 982.0]}, {"task": "dmc_reacher_easy", "method": "dreamer", "seed": "1", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0], "ys": [0.0, 56.3, 118.9, 196.5, 355.6, 586.4, 669.9, 599.3, 602.0, 578.7, 653.0, 679.2, 690.6, 873.9, 692.9, 687.7, 981.2, 777.1, 975.3, 972.6, 967.7, 849.5, 877.1, 881.4, 874.6, 880.2, 969.3, 777.6, 983.6, 933.9, 971.9, 976.5, 969.9, 980.4, 879.8, 973.6, 959.8, 972.2, 971.6, 974.5, 973.0, 970.6, 982.2, 976.3, 968.2]}, {"task": 
"dmc_reacher_easy", "method": "dreamer", "seed": "4", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [86.0, 149.1, 117.7, 480.5, 477.9, 316.9, 641.9, 449.3, 654.1, 676.4, 855.0, 805.4, 429.0, 790.1, 686.6, 729.2, 602.4, 739.6, 961.8, 747.0, 765.7, 969.6, 845.0, 873.4, 945.2, 949.7, 878.1, 973.3, 958.7, 778.2, 889.5, 979.1, 967.6, 872.0, 883.1, 970.1, 978.1, 785.9, 865.2, 974.6, 874.8, 975.3, 974.7, 880.4, 964.1, 867.5, 784.2, 972.0, 977.1, 975.8]}, {"task": "dmc_reacher_easy", "method": "dreamer", "seed": "3", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [393.0, 171.9, 434.9, 478.0, 451.8, 119.3, 544.6, 580.8, 745.1, 504.6, 506.9, 498.8, 640.5, 527.7, 498.2, 768.2, 645.7, 767.4, 875.3, 763.5, 758.3, 907.6, 960.1, 769.4, 942.5, 921.1, 940.5, 975.8, 894.9, 973.8, 974.7, 964.6, 969.1, 872.2, 978.7, 878.7272727272727, 967.5, 974.4, 875.7, 972.3, 974.9, 883.4, 971.3, 964.9, 879.9, 974.9, 879.4, 966.9, 961.1, 971.1]}, {"task": "dmc_reacher_easy", "method": "dreamer", "seed": 
"2", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0], "ys": [0.0, 123.5, 108.8, 390.9, 499.6, 576.1, 617.7, 388.1, 562.6, 865.0, 512.8, 602.2, 470.4, 578.9, 681.5, 411.8, 975.9, 579.7, 783.4, 582.6, 915.7, 860.9, 867.6, 970.3, 971.6, 947.4, 973.7, 975.0, 973.4, 879.9, 969.6, 859.6, 945.4, 974.6, 877.3, 895.0, 970.7, 783.9]}, {"task": "dmc_cup_catch", "method": "dreamer", "seed": "0", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 168.0, 395.9, 910.7, 913.4, 956.3, 960.4, 953.8, 968.0, 966.7, 968.2, 965.8, 973.0, 972.4, 971.7, 971.1, 969.7, 968.7, 977.9, 976.9, 962.8, 978.6, 972.8, 971.3, 969.9, 973.7, 980.4, 977.4, 975.8, 973.4, 960.5, 973.3, 973.0, 880.7, 971.6, 972.7, 942.1, 970.7, 964.0, 974.4, 973.0, 964.3, 872.0, 968.9, 845.0, 941.4, 935.9, 970.6, 972.3, 950.8]}, {"task": "dmc_cup_catch", "method": "dreamer", "seed": "1", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 
2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 0.0, 713.7, 938.0, 940.6, 967.8, 969.8, 972.3, 973.0, 971.3, 974.6, 967.9, 968.0, 971.5, 966.2, 968.2, 974.8, 973.0, 955.7, 966.6, 897.4, 970.0, 916.0, 975.4, 976.7, 963.8, 971.4, 976.1, 963.5, 974.2, 969.4, 974.0, 965.5, 970.8, 976.5, 974.7, 969.7, 968.7, 951.2, 964.8, 968.4, 969.6, 969.8, 980.1, 970.5, 974.1, 977.7, 959.1, 974.7, 968.1]}, {"task": "dmc_cup_catch", "method": "dreamer", "seed": "4", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 0.0, 668.0, 908.6, 962.4, 962.4, 957.7, 970.6, 968.1, 969.4, 963.5, 969.4, 969.3, 965.1, 968.8, 967.8, 968.3, 970.9, 969.8, 968.7, 964.3, 960.5, 970.6, 975.3, 971.4, 968.0, 970.0, 968.3, 962.9, 973.6, 943.1, 977.0, 964.3, 952.1, 975.4, 943.3, 956.8, 975.3, 961.2, 965.7, 965.8, 967.6, 972.1, 970.9, 957.9, 972.0, 947.9, 965.2, 973.8, 957.5]}, {"task": "dmc_cup_catch", "method": "dreamer", "seed": "3", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 
3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 0.0, 544.8, 799.4, 954.3, 966.3, 959.8, 964.6, 965.5, 973.3, 971.0, 975.8, 967.0, 976.6, 966.2, 966.3, 973.3, 973.8, 867.3, 978.1, 920.1, 967.3, 963.7, 970.1, 967.1, 962.2, 968.1, 974.7, 968.8, 968.4, 967.5, 971.1, 976.2, 967.9, 977.4, 980.1, 976.2, 969.3, 974.0, 970.0, 972.2, 966.4, 939.0, 972.5, 972.1, 954.3, 975.5, 968.0, 973.8, 968.4]}, {"task": "dmc_cup_catch", "method": "dreamer", "seed": "2", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 199.4, 191.9, 804.5, 893.7, 955.9, 951.6, 962.4, 960.4, 964.0, 958.2, 971.5, 970.3, 966.8, 961.7, 969.2, 975.4, 967.5, 969.8, 977.7, 971.7, 971.6, 971.6, 959.2, 972.2, 971.7, 967.5, 975.5, 961.4, 979.2, 963.3, 974.2, 973.0, 969.5, 976.7, 970.6, 972.1, 973.5, 967.5, 951.7, 974.6, 968.5, 974.1, 969.3, 968.8, 963.7, 970.6, 970.2, 963.6, 959.7]}, {"task": "dmc_finger_turn_hard", "method": "dreamer", "seed": "0", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 
3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 0.6, 145.0, 277.3, 225.3, 357.6, 428.1, 329.4, 441.2, 476.8, 839.1, 774.7, 742.7, 934.0, 845.3, 703.8, 723.2, 861.1, 898.2, 849.1, 977.6, 957.9, 954.6, 852.5, 964.5, 965.7, 874.7, 899.6, 971.0, 837.0, 935.8, 897.5, 966.0, 957.6, 864.6, 962.9, 960.2, 962.7, 970.3, 897.3, 858.6, 962.3, 792.4, 870.1, 778.4, 963.1, 967.0, 969.8, 812.7, 870.7]}, {"task": "dmc_finger_turn_hard", "method": "dreamer", "seed": "1", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [14.0, 98.2, 105.6, 276.1, 187.2, 254.1, 634.3, 194.3, 377.2, 889.2, 866.9, 624.5, 763.7, 752.5, 809.5, 564.6, 767.6, 887.1, 949.3, 917.3, 863.3, 910.0, 814.7, 931.2, 883.0, 869.0, 934.7, 626.4, 709.4, 965.5, 593.5, 642.1, 682.3, 821.6, 955.9, 951.4, 862.3, 964.2, 868.2, 948.5, 938.0, 960.0, 958.1, 945.2, 856.0, 966.1, 958.0, 675.2, 971.8, 862.3]}, {"task": "dmc_finger_turn_hard", "method": "dreamer", "seed": "4", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 
4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 0.0, 0.0, 77.3, 86.3, 167.0, 133.5, 443.7, 494.2, 328.9, 528.8, 502.4, 628.0, 858.4, 677.7, 853.6, 826.9, 922.3, 756.0, 844.5, 949.0, 811.4, 955.2, 886.3, 838.6, 873.8, 841.3, 665.1, 718.7, 736.6, 552.6, 872.0, 662.0, 909.4, 764.0, 709.4, 842.7, 964.0, 775.7, 976.8, 815.2, 929.8, 758.7, 940.7, 927.2, 821.7, 867.5, 914.4, 876.0, 909.6]}, {"task": "dmc_finger_turn_hard", "method": "dreamer", "seed": "3", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 88.9, 51.0, 176.1, 263.9, 44.7, 179.2, 612.5, 626.5, 552.9, 597.3, 764.1, 644.7, 823.0, 658.4, 863.5, 961.2, 864.7, 938.5, 851.0, 850.0, 765.8, 671.5, 917.1, 646.9, 922.4, 778.2, 873.6, 861.7, 826.8, 960.7, 869.9, 965.1, 879.9, 971.2, 873.8, 964.3, 816.8, 840.2, 936.0, 868.7, 673.0, 952.0, 869.0, 844.6, 844.5, 846.0, 600.5, 948.7, 927.8]}, {"task": "dmc_finger_turn_hard", "method": "dreamer", "seed": "2", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0], "ys": [0.0, 2.8, 17.9, 134.0, 294.3, 374.2, 243.3, 376.1, 453.7, 628.0, 797.6, 625.9, 758.8, 887.2, 950.1, 818.5, 572.2, 677.5, 697.3, 906.4, 897.5, 531.6, 
762.3, 949.4, 742.1, 902.3, 866.0, 882.8, 932.3, 886.5, 807.8, 925.2, 746.7, 962.3, 817.3, 943.6]}, {"task": "dmc_pendulum_swingup", "method": "dreamer", "seed": "0", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 492.0, 793.4, 858.0, 825.1, 850.6, 696.1, 865.6, 809.6, 772.0, 827.9, 838.1, 774.4, 821.2, 763.1, 807.7, 845.9, 774.2, 869.2, 840.6, 879.1, 827.6, 753.7, 770.3, 800.9, 751.7, 747.5, 807.6, 844.0, 815.6, 705.0, 802.8, 843.4, 716.5, 788.2, 847.7, 771.3, 834.7, 799.8, 829.1, 770.7, 852.6, 748.6, 823.1, 817.9, 725.0, 798.3, 810.3, 751.9, 757.0]}, {"task": "dmc_pendulum_swingup", "method": "dreamer", "seed": "1", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 443.6, 794.6, 875.8, 838.9, 830.6, 858.3, 859.7, 849.5, 822.0, 859.0, 860.7, 811.5, 748.7, 848.6, 721.1, 858.8, 821.5, 872.0, 809.4, 861.1, 811.2, 806.0, 663.1, 727.5, 783.6, 831.6, 816.1, 829.4, 844.1, 874.0, 837.1, 813.3, 861.3, 774.9, 866.1, 741.4, 755.7, 837.1, 809.8, 817.4, 754.1, 835.8, 866.1, 
818.2, 794.1, 817.9, 736.6, 832.7, 795.3]}, {"task": "dmc_pendulum_swingup", "method": "dreamer", "seed": "4", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 412.8, 795.0, 841.4, 868.3, 741.6, 811.6, 798.9, 800.4, 838.3, 816.8, 801.8, 769.7, 798.0, 843.1, 803.7, 785.7, 783.9, 834.6, 753.9, 675.8, 769.3, 674.8, 780.6, 823.7, 663.4, 789.9, 773.1, 827.7, 779.4, 778.2, 821.0, 849.9, 689.1, 816.4, 869.6, 840.2, 763.3, 820.9, 868.3, 787.7, 785.8, 642.6, 654.2, 796.0, 817.3, 872.2, 855.0, 812.3, 809.5]}, {"task": "dmc_pendulum_swingup", "method": "dreamer", "seed": "3", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [56.0, 82.5, 758.3, 864.4, 782.0, 701.6, 873.1, 823.3, 830.3, 830.6, 827.6, 799.2, 871.1, 821.8, 792.7, 825.2, 868.3, 857.0, 826.1, 872.2, 890.6, 777.6363636363636, 831.8888888888889, 867.6, 798.6, 796.9, 882.1818181818181, 799.6666666666666, 823.6, 721.2, 828.8, 796.1818181818181, 886.3, 805.0, 820.0, 793.5, 803.0, 826.5555555555555, 795.6, 828.2, 
842.9090909090909, 821.6, 843.0, 858.6, 791.6, 848.6, 855.4, 768.4, 721.6, 810.7]}, {"task": "dmc_pendulum_swingup", "method": "dreamer", "seed": "2", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 395.5, 729.4, 788.1, 782.3, 821.7, 751.9, 784.9, 869.8, 826.6, 809.4, 791.3, 825.7, 769.1, 872.0, 840.6, 740.7, 846.8, 834.6, 821.0, 786.4, 831.2, 851.7, 796.9, 824.6, 803.3, 805.7, 877.0, 836.0, 799.0, 867.5, 791.6, 840.3, 843.3, 820.8, 776.5, 800.7, 788.4, 822.8, 849.6, 825.9, 790.3, 803.6, 797.1, 815.7, 776.0, 806.7, 822.2, 822.0, 787.6]}, {"task": "dmc_reacher_hard", "method": "dreamer", "seed": "0", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 7.5, 42.3, 1.7, 5.0, 100.0, 189.7, 135.9, 95.6, 48.9, 0.7, 339.4, 184.6, 385.3, 618.1, 397.2, 494.8, 521.9, 552.2, 772.1, 459.4, 370.9, 442.4, 656.5, 579.9, 670.1, 683.6, 436.3, 677.0, 767.0, 777.4, 581.4, 873.5, 797.8, 778.7, 769.6, 770.0, 739.9, 677.1, 870.1, 672.2, 923.1, 868.6, 887.5, 875.2, 876.1, 901.5, 880.1, 712.3, 
963.8]}, {"task": "dmc_reacher_hard", "method": "dreamer", "seed": "1", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 0.4, 0.5, 1.3, 115.0, 184.0, 98.0, 95.8, 77.6, 143.7, 24.9, 484.7, 438.0, 477.4, 305.1, 294.2, 483.4, 773.3, 289.8, 287.6, 463.3, 753.9, 681.0, 689.3, 578.3, 382.3, 585.6, 578.5, 624.2, 841.5, 582.5, 386.8, 871.7, 777.4, 567.6, 773.8, 774.2, 857.3, 769.5, 780.7, 818.6, 583.7, 875.6, 772.4, 969.9, 977.0, 873.9, 821.8, 889.9, 967.8]}, {"task": "dmc_reacher_hard", "method": "dreamer", "seed": "4", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 2.9, 99.0, 85.6, 58.3, 81.9, 85.6, 4.0, 103.0, 219.5, 91.5, 474.7, 197.0, 575.8, 489.7, 386.3, 672.2, 671.3, 490.6, 583.7, 832.6, 783.8, 426.7, 324.6, 876.4, 666.8, 746.7, 486.1, 767.7, 782.0, 871.4, 923.4, 872.6, 775.9, 870.2, 673.6, 871.2, 676.7, 874.8, 777.6, 942.9, 965.7, 778.1, 860.9, 870.8, 961.9, 972.1, 866.7, 868.1, 965.1]}, {"task": "dmc_reacher_hard", "method": "dreamer", "seed": "3", "xs": [5000.0, 
105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 0.5, 4.7, 100.4, 231.9, 1.4, 479.0, 104.8, 99.7, 134.2, 192.6, 201.9, 623.2, 353.6, 100.0, 387.2, 381.3, 171.8, 280.7, 679.8, 584.2, 488.6, 575.0, 679.2, 482.5, 679.7, 482.3, 657.4, 675.2, 580.1, 956.5, 723.4, 866.2, 777.6, 771.8, 675.8, 676.7, 878.3, 579.9, 673.4, 579.2, 681.6, 678.5, 966.3, 782.2, 972.4, 593.7, 969.3, 784.9, 777.3]}, {"task": "dmc_reacher_hard", "method": "dreamer", "seed": "2", "xs": [5000.0, 105000.0, 205000.0, 305000.0, 405000.0, 505000.0, 605000.0, 705000.0, 805000.0, 905000.0, 1005000.0, 1105000.0, 1205000.0, 1305000.0, 1405000.0, 1505000.0, 1605000.0, 1705000.0, 1805000.0, 1905000.0, 2005000.0, 2105000.0, 2205000.0, 2305000.0, 2405000.0, 2505000.0, 2605000.0, 2705000.0, 2805000.0, 2905000.0, 3005000.0, 3105000.0, 3205000.0, 3305000.0, 3405000.0, 3505000.0, 3605000.0, 3705000.0, 3805000.0, 3905000.0, 4005000.0, 4105000.0, 4205000.0, 4305000.0, 4405000.0, 4505000.0, 4605000.0, 4705000.0, 4805000.0, 4905000.0], "ys": [0.0, 1.3, 21.9, 42.3, 19.5, 108.9, 71.2, 0.3, 489.7, 40.7, 537.3, 316.6, 561.2, 452.1, 573.9, 577.0, 645.6, 483.8, 470.2, 573.5, 630.0, 479.0, 844.8, 494.4, 481.9, 773.5, 460.6, 518.9, 712.4, 639.3, 776.2, 777.1, 971.7, 868.9, 879.6, 870.9, 583.1, 755.2, 871.4, 776.9, 680.1, 675.9, 830.0, 778.0, 872.7, 866.1, 672.4, 739.8, 810.9, 871.1]}] -------------------------------------------------------------------------------- /tools.py: 
# ---------------------------------------------------------------------------
# tools.py -- shared utilities for the Dreamer agent.
# ---------------------------------------------------------------------------
import datetime
import io
import pathlib
import pickle
import re
import uuid

import gym
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow_probability as tfp
from tensorflow.keras.mixed_precision import experimental as prec
from tensorflow_probability import distributions as tfd


class _MissingAttrKey(KeyError, AttributeError):
  """Raised by `AttrDict` when a missing key is read as an attribute.

  It is a `KeyError` so existing `except KeyError` handlers keep working.
  It is also an `AttributeError`, which is what the attribute protocol
  expects. That makes `hasattr`, `getattr(obj, name, default)`, `copy` and
  `pickle` behave correctly.
  """


class AttrDict(dict):
  """Dictionary whose items can also be read and written as attributes."""

  __setattr__ = dict.__setitem__

  def __getattr__(self, name):
    # Only called after normal attribute lookup failed, so dict methods and
    # class attributes take precedence over keys of the same name.
    try:
      return self[name]
    except KeyError:
      raise _MissingAttrKey(name) from None


class Module(tf.Module):
  """`tf.Module` with pickle checkpointing and lazily created sub-layers."""

  def save(self, filename):
    """Write the numpy values of all variables of this module to `filename`."""
    values = tf.nest.map_structure(lambda x: x.numpy(), self.variables)
    with pathlib.Path(filename).open('wb') as f:
      pickle.dump(values, f)

  def load(self, filename):
    """Assign variable values previously written by `save`.

    The variables must already exist and have the same structure as when
    the file was saved. This uses `pickle.load`, which can execute
    arbitrary code, so only load files from a trusted source.
    """
    with pathlib.Path(filename).open('rb') as f:
      values = pickle.load(f)
    tf.nest.map_structure(lambda x, y: x.assign(y), self.variables, values)

  def get(self, name, ctor, *args, **kwargs):
    """Return the sub-layer cached under `name`, creating it on first use.

    `ctor(*args, **kwargs)` only runs the first time `name` is requested.
    Arguments passed on later calls are ignored. This lets layers be used
    without declaring them in the constructor.
    """
    if not hasattr(self, '_modules'):
      self._modules = {}
    if name not in self._modules:
      self._modules[name] = ctor(*args, **kwargs)
    return self._modules[name]


def nest_summary(structure):
  """Replace every array-like leaf of a nested structure by its shape string.

  Dicts, lists and tuples are traversed recursively. A leaf with a `shape`
  attribute becomes e.g. `'16x64x64x3'`; any other leaf becomes `'?'`.
  Tuples (including namedtuples) are returned as plain tuples.
  """
  if isinstance(structure, dict):
    return {k: nest_summary(v) for k, v in structure.items()}
  if isinstance(structure, list):
    return [nest_summary(v) for v in structure]
  if isinstance(structure, tuple):
    return tuple(nest_summary(v) for v in structure)
  if hasattr(structure, 'shape'):
    return str(structure.shape).replace(', ', 'x').strip('(), ')
  return '?'
def graph_summary(writer, fn, *args):
  """Call the eager summary function `fn(*args)` from inside a TF graph.

  The summary step current at call time is captured. Inside the wrapped
  call it is restored and `writer` is made the default writer. The call
  runs through `tf.numpy_function`, so `args` arrive as numpy values.
  """
  step = tf.summary.experimental.get_step()

  def inner(*args):
    tf.summary.experimental.set_step(step)
    with writer.as_default():
      fn(*args)

  return tf.numpy_function(inner, args, [])


def video_summary(name, video, step=None, fps=20):
  """Log a batch of videos as one animated GIF, or as a still image grid.

  Args:
    name: Summary tag as `str` or UTF-8 `bytes`.
    video: Array of shape (B, T, H, W, C). Float values are scaled by 255
      and clipped to [0, 255]. C must be 1 or 3 for the GIF path.
    step: Summary step; None uses the default summary step.
    fps: Frame rate of the GIF.

  The B videos are tiled side by side, so every GIF frame has shape
  (H, B * W, C). If ffmpeg fails with an `OSError`, a single grid image of
  shape (B * H, T * W, C) is logged instead.
  """
  name = name if isinstance(name, str) else name.decode('utf-8')
  if np.issubdtype(video.dtype, np.floating):
    video = np.clip(255 * video, 0, 255).astype(np.uint8)
  B, T, H, W, C = video.shape
  try:
    frames = video.transpose((1, 2, 0, 3, 4)).reshape((T, H, B * W, C))
    summary = tf1.Summary()
    # The encoded frames are H pixels high and B * W pixels wide.
    image = tf1.Summary.Image(height=H, width=B * W, colorspace=C)
    image.encoded_image_string = encode_gif(frames, fps)
    summary.value.add(tag=name + '/gif', image=image)
    tf.summary.experimental.write_raw_pb(summary.SerializeToString(), step)
  except OSError as e:  # IOError is an alias; also covers a missing binary.
    print('GIF summaries require ffmpeg in $PATH.', e)
    frames = video.transpose((0, 2, 1, 3, 4)).reshape((1, B * H, T * W, C))
    tf.summary.image(name + '/grid', frames, step)


def encode_gif(frames, fps):
  """Encode uint8 frames of shape (T, H, W, C) as GIF bytes using ffmpeg.

  C must be 1 (grayscale) or 3 (RGB). ffmpeg generates one palette from the
  whole clip.

  Raises:
    OSError: If ffmpeg cannot be started (e.g. FileNotFoundError).
    IOError: If ffmpeg exits with a non-zero status. The message contains
      the command line and ffmpeg's stderr.
  """
  import subprocess
  h, w, c = frames[0].shape
  pxfmt = {1: 'gray', 3: 'rgb24'}[c]
  cmd = [
      'ffmpeg', '-y', '-f', 'rawvideo', '-vcodec', 'rawvideo',
      '-r', f'{fps:.02f}', '-s', f'{w}x{h}', '-pix_fmt', pxfmt, '-i', '-',
      '-filter_complex',
      '[0:v]split[x][z];[z]palettegen[y];[x]fifo[x];[x][y]paletteuse',
      '-r', f'{fps:.02f}', '-f', 'gif', '-']
  proc = subprocess.Popen(
      cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE,
      stderr=subprocess.PIPE)
  # communicate() feeds stdin while draining stdout/stderr. This avoids a
  # deadlock when ffmpeg fills an output pipe before reading all frames.
  data = b''.join(np.asarray(image).tobytes() for image in frames)
  out, err = proc.communicate(data)
  if proc.returncode:
    raise IOError('\n'.join([' '.join(cmd), err.decode('utf8', 'replace')]))
  return out


def simulate(agent, envs, steps=0, episodes=0, state=None):
  """Run `agent` in `envs` until every non-zero step/episode budget is met.

  Args:
    agent: Callable `(obs, done, agent_state) -> (action, agent_state)`.
      `obs` is a dict of arrays stacked over environments, and `done` is a
      bool array with one entry per environment.
    envs: Environments whose `reset(blocking=False)` and
      `step(action, blocking=False)` return zero-argument promises.
    steps: Minimum number of environment steps to simulate (0 = no limit).
      Steps are only counted when an episode finishes.
    episodes: Minimum number of episodes to finish (0 = no limit).
    state: Tuple returned by a previous call, to resume the simulation.

  Returns:
    Tuple `(step, episode, done, length, obs, agent_state)` that can be
    passed back as `state`. `step` and `episode` are the counters minus the
    requested budgets.
  """
  if state is None:
    step, episode = 0, 0
    done = np.ones(len(envs), bool)
    length = np.zeros(len(envs), np.int32)
    obs = [None] * len(envs)
    agent_state = None
  else:
    step, episode, done, length, obs, agent_state = state
  while (steps and step < steps) or (episodes and episode < episodes):
    # Reset the environments whose episode just ended, in parallel.
    if done.any():
      indices = [index for index, d in enumerate(done) if d]
      promises = [envs[i].reset(blocking=False) for i in indices]
      for index, promise in zip(indices, promises):
        obs[index] = promise()
    # Batch the per-environment observation dicts for the agent.
    batch = {k: np.stack([o[k] for o in obs]) for k in obs[0]}
    action, agent_state = agent(batch, done, agent_state)
    action = np.array(action)
    assert len(action) == len(envs)
    # Step all environments concurrently, then collect the results.
    promises = [e.step(a, blocking=False) for e, a in zip(envs, action)]
    obs, _, done = zip(*[p()[:3] for p in promises])
    obs = list(obs)
    done = np.stack(done)
    episode += int(done.sum())
    length += 1
    step += (done * length).sum()
    length *= (1 - done)
  return (step - steps, episode - episodes, done, length, obs, agent_state)


def count_episodes(directory):
  """Return `(episodes, steps)` for the `.npz` episodes in `directory`.

  Episode lengths are parsed from the `{timestamp}-{id}-{length}.npz` names
  written by `save_episodes`, so no file is opened. One is subtracted per
  episode, presumably for the initial reset transition stored by the
  collector. TODO confirm against the writer.
  """
  directory = pathlib.Path(directory).expanduser()
  lengths = [
      int(n.stem.rsplit('-', 1)[-1]) - 1 for n in directory.glob('*.npz')]
  return len(lengths), sum(lengths)


def save_episodes(directory, episodes):
  """Save every episode dict as a compressed `.npz` file in `directory`.

  Files are named `{timestamp}-{uuid}-{length}.npz`, where length is the
  number of rewards. Each archive is built in memory and then written to
  disk in one call.
  """
  directory = pathlib.Path(directory).expanduser()
  directory.mkdir(parents=True, exist_ok=True)
  timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
  for episode in episodes:
    identifier = str(uuid.uuid4().hex)
    length = len(episode['reward'])
    filename = directory / f'{timestamp}-{identifier}-{length}.npz'
    with io.BytesIO() as buffer:
      np.savez_compressed(buffer, **episode)
      filename.write_bytes(buffer.getvalue())


def load_episodes(directory, rescan, length=None, balance=False, seed=0):
  """Endlessly yield episodes, or random windows of them, from `directory`.

  The directory is rescanned for new `.npz` files before every batch of
  `rescan` draws. Loaded episodes are cached in memory. Files that fail to
  load, e.g. because they are still being written, are retried on the next
  scan.

  Args:
    directory: Folder of episodes written by `save_episodes`.
    rescan: Number of episodes drawn (with replacement) per scan.
    length: If set, yield windows of exactly `length` steps. Every start
      index from 0 to `total - length` can be drawn, and shorter episodes
      are skipped.
    balance: If set, the window start is drawn uniformly from [0, total)
      and clamped to the last valid start. This puts extra probability on
      the final window.
    seed: Seed for the sampling random number generator.

  Raises:
    ValueError: If no episode could be loaded from `directory`.
  """
  directory = pathlib.Path(directory).expanduser()
  random = np.random.RandomState(seed)
  cache = {}
  while True:
    for filename in directory.glob('*.npz'):
      if filename not in cache:
        try:
          with filename.open('rb') as f:
            episode = np.load(f)
            episode = {k: episode[k] for k in episode.keys()}
        except Exception as e:
          print(f'Could not load episode: {e}')
          continue
        cache[filename] = episode
    keys = list(cache.keys())
    if not keys:
      raise ValueError(f'No episodes found in {directory}.')
    for index in random.choice(len(keys), rescan):
      episode = cache[keys[index]]
      if length:
        total = len(next(iter(episode.values())))
        available = total - length
        if available < 0:
          print(f'Skipped short episode of length {total}.')
          continue
        if balance:
          index = min(random.randint(0, total), available)
        else:
          # randint excludes its upper bound; +1 keeps the last window.
          index = int(random.randint(0, available + 1))
        episode = {k: v[index: index + length] for k, v in episode.items()}
      yield episode
class DummyEnv:
  """Minimal random environment with a 64x64x3 `image` observation.

  Actions are 5-dimensional in [-1, 1] and rewards are uniform in [0, 1).
  Episodes end after 1000 steps. Observations are random samples of the
  observation space.
  """

  def __init__(self):
    self._random = np.random.RandomState(seed=0)
    self._step = None

  @property
  def observation_space(self):
    low = np.zeros([64, 64, 3], dtype=np.uint8)
    high = 255 * np.ones([64, 64, 3], dtype=np.uint8)
    spaces = {'image': gym.spaces.Box(low, high)}
    return gym.spaces.Dict(spaces)

  @property
  def action_space(self):
    low = -np.ones([5], dtype=np.float32)
    high = np.ones([5], dtype=np.float32)
    return gym.spaces.Box(low, high)

  def reset(self):
    self._step = 0
    obs = self.observation_space.sample()
    return obs

  def step(self, action):
    obs = self.observation_space.sample()
    reward = self._random.uniform(0, 1)
    self._step += 1
    done = self._step >= 1000
    info = {}
    return obs, reward, done, info


class SampleDist:
  """Wrap a distribution and estimate mean, mode and entropy from samples.

  Intended for distributions without closed forms for these statistics.
  All other attributes are forwarded to the wrapped distribution.
  """

  def __init__(self, dist, samples=100):
    self._dist = dist
    self._samples = samples

  @property
  def name(self):
    return 'SampleDist'

  def __getattr__(self, name):
    return getattr(self._dist, name)

  def mean(self):
    samples = self._dist.sample(self._samples)
    return tf.reduce_mean(samples, 0)

  def mode(self):
    """Return, per batch element, the drawn sample with highest log-prob."""
    sample = self._dist.sample(self._samples)
    logprob = self._dist.log_prob(sample)
    best = tf.argmax(logprob, 0)
    # Move the sample axis behind the batch axes so that every batch element
    # gathers its own best sample (works for any batch rank, including 0).
    batch_rank = logprob.shape.ndims - 1
    perm = (list(range(1, batch_rank + 1)) + [0] +
            list(range(batch_rank + 1, sample.shape.ndims)))
    sample = tf.transpose(sample, perm)
    return tf.gather(sample, best, axis=batch_rank, batch_dims=batch_rank)

  def entropy(self):
    sample = self._dist.sample(self._samples)
    logprob = self.log_prob(sample)
    return -tf.reduce_mean(logprob, 0)


class OneHotDist:
  """Categorical distribution whose events are one-hot vectors.

  `sample()` uses straight-through gradients: the forward value is the
  one-hot sample, and gradients flow to the class probabilities. Samples
  and modes use the mixed-precision compute dtype.
  """

  def __init__(self, logits=None, probs=None):
    self._dist = tfd.Categorical(logits=logits, probs=probs)
    self._num_classes = self.mean().shape[-1]
    self._dtype = prec.global_policy().compute_dtype

  @property
  def name(self):
    return 'OneHotDist'

  def __getattr__(self, name):
    return getattr(self._dist, name)

  def prob(self, events):
    indices = tf.argmax(events, axis=-1)
    return self._dist.prob(indices)

  def log_prob(self, events):
    indices = tf.argmax(events, axis=-1)
    return self._dist.log_prob(indices)

  def mean(self):
    return self._dist.probs_parameter()

  def mode(self):
    return self._one_hot(self._dist.mode())

  def sample(self, amount=None):
    amount = [amount] if amount else []
    indices = self._dist.sample(*amount)
    sample = self._one_hot(indices)
    probs = self._dist.probs_parameter()
    # Adds zero in the forward pass but passes gradients to `probs`.
    sample += tf.cast(probs - tf.stop_gradient(probs), self._dtype)
    return sample

  def _one_hot(self, indices):
    return tf.one_hot(indices, self._num_classes, dtype=self._dtype)


class TanhBijector(tfp.bijectors.Bijector):
  """Elementwise tanh bijector with a clipped inverse and stable Jacobian."""

  def __init__(self, validate_args=False, name='tanh'):
    super().__init__(
        forward_min_event_ndims=0,
        validate_args=validate_args,
        name=name)

  def _forward(self, x):
    return tf.nn.tanh(x)

  def _inverse(self, y):
    dtype = y.dtype
    y = tf.cast(y, tf.float32)
    # Pull |y| <= 1 slightly inside the open interval so that atanh stays
    # finite. Values with |y| > 1 are left unchanged.
    y = tf.where(
        tf.less_equal(tf.abs(y), 1.),
        tf.clip_by_value(y, -0.99999997, 0.99999997), y)
    y = tf.atanh(y)
    y = tf.cast(y, dtype)
    return y

  def _forward_log_det_jacobian(self, x):
    # log(1 - tanh(x)^2) == 2 * (log 2 - x - softplus(-2x)), which is stable
    # for large |x|.
    log2 = tf.math.log(tf.constant(2.0, dtype=x.dtype))
    return 2.0 * (log2 - x - tf.nn.softplus(-2.0 * x))


def lambda_return(
    reward, value, pcont, bootstrap, lambda_, axis):
  """Compute TD(lambda) returns along the time axis `axis`.

  Implements R_t = r_t + p_t * ((1 - lambda) * v_{t+1} + lambda * R_{t+1}).
  The value and return after the last step are both `bootstrap`, or zeros
  if `bootstrap` is None. `pcont` may be a Python scalar or a tensor shaped
  like `reward`. Setting lambda=1 gives a discounted Monte Carlo return;
  lambda=0 gives a one-step return.
  """
  assert reward.shape.ndims == value.shape.ndims, (reward.shape, value.shape)
  if isinstance(pcont, (int, float)):
    pcont = pcont * tf.ones_like(reward)
  # Swap `axis` with axis 0 so the scan runs over the leading dimension.
  dims = list(range(reward.shape.ndims))
  dims = [axis] + dims[1:axis] + [0] + dims[axis + 1:]
  if axis != 0:
    reward = tf.transpose(reward, dims)
    value = tf.transpose(value, dims)
    pcont = tf.transpose(pcont, dims)
  if bootstrap is None:
    bootstrap = tf.zeros_like(value[-1])
  next_values = tf.concat([value[1:], bootstrap[None]], 0)
  inputs = reward + pcont * next_values * (1 - lambda_)
  returns = static_scan(
      lambda agg, cur: cur[0] + cur[1] * lambda_ * agg,
      (inputs, pcont), bootstrap, reverse=True)
  if axis != 0:
    returns = tf.transpose(returns, dims)
  return returns


class Adam(tf.Module):
  """Adam with dynamic loss scaling, global-norm clipping and weight decay.

  Args:
    name: Name used in log messages and for weight-decay pattern matching.
    modules: Modules whose variables are optimized. The variables are
      collected on the first call.
    lr: Learning rate.
    clip: If set, clip gradients to this global norm.
    wd: If set, multiply matching variables by (1 - wd) before each update.
    wdpattern: Regex searched in `name + '/' + var.name` to select the
      variables that receive weight decay.
  """

  def __init__(self, name, modules, lr, clip=None, wd=None, wdpattern=r'.*'):
    self._name = name
    self._modules = modules
    self._clip = clip
    self._wd = wd
    self._wdpattern = wdpattern
    self._opt = tf.optimizers.Adam(lr)
    self._opt = prec.LossScaleOptimizer(self._opt, 'dynamic')
    self._variables = None

  @property
  def variables(self):
    return self._opt.variables()

  def __call__(self, tape, loss):
    """Apply one update for the scalar `loss` recorded on `tape`.

    Returns the global norm of the unscaled gradients, measured before
    clipping.
    """
    if self._variables is None:
      variables = [module.variables for module in self._modules]
      self._variables = tf.nest.flatten(variables)
      count = sum(np.prod(x.shape) for x in self._variables)
      print(f'Found {count} {self._name} parameters.')
    assert len(loss.shape) == 0, loss.shape
    # The loss scaling must be recorded on the tape to be differentiated.
    with tape:
      loss = self._opt.get_scaled_loss(loss)
    grads = tape.gradient(loss, self._variables)
    grads = self._opt.get_unscaled_gradients(grads)
    norm = tf.linalg.global_norm(grads)
    if self._clip:
      grads, _ = tf.clip_by_global_norm(grads, self._clip, norm)
    if self._wd:
      context = tf.distribute.get_replica_context()
      context.merge_call(self._apply_weight_decay)
    self._opt.apply_gradients(zip(grads, self._variables))
    return norm

  def _apply_weight_decay(self, strategy):
    """Decay matching variables in place: var <- (1 - wd) * var."""
    print('Applied weight decay to variables:')
    for var in self._variables:
      if re.search(self._wdpattern, self._name + '/' + var.name):
        print('- ' + self._name + '/' + var.name)
        # `update` only runs the function on the variable's device. The
        # function itself has to assign, otherwise the result is discarded.
        strategy.extended.update(
            var, lambda v: v.assign((1 - self._wd) * v))


def args_type(default):
  """Return an argparse `type` callable that parses strings like `default`.

  Booleans accept only 'False' or 'True'. Integer defaults also accept
  floats written with '.' or 'e'. Paths are expanded with `expanduser`.
  """
  if isinstance(default, bool):
    return lambda x: bool(['False', 'True'].index(x))
  if isinstance(default, int):
    return lambda x: float(x) if ('e' in x or '.' in x) else int(x)
  if isinstance(default, pathlib.Path):
    return lambda x: pathlib.Path(x).expanduser()
  return type(default)


def static_scan(fn, inputs, start, reverse=False):
  """Python-unrolled scan: `last = fn(last, inputs[t])` over the first axis.

  `inputs` and `start` may be nested structures. Returns the sequence of
  states stacked along axis 0, with the same structure as `start` and in
  input order (also when `reverse=True`).
  """
  last = start
  outputs = [[] for _ in tf.nest.flatten(start)]
  indices = range(len(tf.nest.flatten(inputs)[0]))
  if reverse:
    indices = reversed(indices)
  for index in indices:
    inp = tf.nest.map_structure(lambda x: x[index], inputs)
    last = fn(last, inp)
    for output, leaf in zip(outputs, tf.nest.flatten(last)):
      output.append(leaf)
  if reverse:
    outputs = [list(reversed(x)) for x in outputs]
  outputs = [tf.stack(x, 0) for x in outputs]
  return tf.nest.pack_sequence_as(start, outputs)


def _mnd_sample(self, sample_shape=(), seed=None, name='sample'):
  """Replacement for `tfd.MultivariateNormalDiag.sample`.

  Draws independent noise for every sample, batch and event element. The
  result has shape `sample_shape + batch_shape + event_shape`.
  """
  if isinstance(sample_shape, (int, np.integer)):
    sample_shape = (sample_shape,)
  sample_shape = tuple(int(d) for d in sample_shape)
  mean = self.mean()
  # The noise shape must include the batch dimensions. With only
  # `sample_shape + event_shape`, a single noise vector would be broadcast
  # across the whole batch.
  shape = tf.concat([
      tf.constant(sample_shape, tf.int32, [len(sample_shape)]),
      tf.shape(mean)], 0)
  sample = tf.random.normal(
      shape, mean, self.stddev(), self.dtype, seed, name)
  sample.set_shape(tf.TensorShape(sample_shape).concatenate(mean.shape))
  return sample


tfd.MultivariateNormalDiag.sample = _mnd_sample


def _cat_sample(self, sample_shape=(), seed=None, name='sample'):
  """Replacement for `tfd.Categorical.sample` for rank-2 (B, K) logits.

  Returns indices of shape (B,) for an empty `sample_shape`. For
  `sample_shape=(N,)` or `N` it returns (N, B), following TFP's
  sample-major convention.
  """
  if isinstance(sample_shape, (int, np.integer)):
    sample_shape = (sample_shape,)
  sample_shape = tuple(sample_shape)
  assert len(sample_shape) in (0, 1), sample_shape
  logits = self.logits_parameter()
  assert len(logits.shape) == 2
  indices = tf.random.categorical(
      logits, sample_shape[0] if sample_shape else 1, self.dtype, seed, name)
  if not sample_shape:
    return indices[..., 0]
  # tf.random.categorical returns (B, N); move the sample axis first.
  return tf.transpose(indices)


tfd.Categorical.sample = _cat_sample


class Every:
  """Callable that returns True once every `every` steps.

  The first call returns True and sets the reference step. Later calls
  return True when `step` is at least `every` past the reference, which
  then advances by exactly `every`. A large jump in `step` can therefore
  yield several True results in a row.
  """

  def __init__(self, every):
    self._every = every
    self._last = None

  def __call__(self, step):
    if self._last is None:
      self._last = step
      return True
    if step >= self._last + self._every:
      self._last += self._every
      return True
    return False


class Once:
  """Callable that returns True on its first call and False afterwards."""

  def __init__(self):
    self._once = True

  def __call__(self):
    if self._once:
      self._once = False
      return True
    return False


# ---------------------------------------------------------------------------
# wrappers.py -- environment wrappers.
# ---------------------------------------------------------------------------
import atexit
import functools
import sys
import threading
import traceback

import gym
import numpy as np
from PIL import Image


class DeepMindControl:
  """Gym-style wrapper for a dm_control suite task with rendered images.

  Args:
    name: '<domain>_<task>', e.g. 'walker_walk'. The domain 'cup' is mapped
      to 'ball_in_cup'.
    size: (height, width) of the `image` observation.
    camera: Camera id. Defaults to 2 for quadruped and 0 otherwise.
  """

  def __init__(self, name, size=(64, 64), camera=None):
    domain, task = name.split('_', 1)
    if domain == 'cup':  # Only domain with multiple words.
      domain = 'ball_in_cup'
    if isinstance(domain, str):
      from dm_control import suite
      self._env = suite.load(domain, task)
    else:
      # Not reachable for string names, since `split` always yields strings.
      assert task is None
      self._env = domain()
    self._size = size
    if camera is None:
      camera = dict(quadruped=2).get(domain, 0)
    self._camera = camera

  @property
  def observation_space(self):
    spaces = {}
    for key, value in self._env.observation_spec().items():
      spaces[key] = gym.spaces.Box(
          -np.inf, np.inf, value.shape, dtype=np.float32)
    spaces['image'] = gym.spaces.Box(
        0, 255, self._size + (3,), dtype=np.uint8)
    return gym.spaces.Dict(spaces)

  @property
  def action_space(self):
    spec = self._env.action_spec()
    return gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32)

  def step(self, action):
    time_step = self._env.step(action)
    obs = dict(time_step.observation)
    obs['image'] = self.render()
    reward = time_step.reward or 0
    done = time_step.last()
    info = {'discount': np.array(time_step.discount, np.float32)}
    return obs, reward, done, info

  def reset(self):
    time_step = self._env.reset()
    obs = dict(time_step.observation)
    obs['image'] = self.render()
    return obs

  def render(self, *args, **kwargs):
    if kwargs.get('mode', 'rgb_array') != 'rgb_array':
      raise ValueError("Only render mode 'rgb_array' is supported.")
    return self._env.physics.render(*self._size, camera_id=self._camera)


class Atari:
  """Atari environment with action repeat, frame max-pooling and no-ops.

  Args:
    name: Game name in snake case, e.g. 'pong' or 'space_invaders'.
    action_repeat: Number of emulator frames per step. The last two frames
      are max-pooled.
    size: (height, width) of the resized `image` observation.
    grayscale: Whether to return a single gray channel instead of RGB.
    noops: Maximum number of random no-op actions after reset (0 disables).
    life_done: Whether losing a life ends the episode.
    sticky_actions: Use the 'v0' game version if True, else 'v4'.
  """

  LOCK = threading.Lock()

  def __init__(
      self, name, action_repeat=4, size=(84, 84), grayscale=True, noops=30,
      life_done=False, sticky_actions=True):
    version = 0 if sticky_actions else 4
    name = ''.join(word.title() for word in name.split('_'))
    with self.LOCK:
      self._env = gym.make('{}NoFrameskip-v{}'.format(name, version))
    self._action_repeat = action_repeat
    self._size = size
    self._grayscale = grayscale
    self._noops = noops
    self._life_done = life_done
    self._lives = None
    shape = self._env.observation_space.shape[:2] + (() if grayscale else (3,))
    self._buffers = [np.empty(shape, dtype=np.uint8) for _ in range(2)]
    self._random = np.random.RandomState(seed=None)

  @property
  def observation_space(self):
    shape = self._size + (1 if self._grayscale else 3,)
    space = gym.spaces.Box(low=0, high=255, shape=shape, dtype=np.uint8)
    return gym.spaces.Dict({'image': space})

  @property
  def action_space(self):
    return self._env.action_space

  def close(self):
    return self._env.close()

  def reset(self):
    with self.LOCK:
      self._env.reset()
    # Between 1 and `noops` random no-ops; none when noops is 0.
    noops = self._random.randint(1, self._noops + 1) if self._noops else 0
    for _ in range(noops):
      done = self._env.step(0)[2]
      if done:
        with self.LOCK:
          self._env.reset()
    self._lives = self._env.ale.lives()
    if self._grayscale:
      self._env.ale.getScreenGrayscale(self._buffers[0])
    else:
      self._env.ale.getScreenRGB2(self._buffers[0])
    self._buffers[1].fill(0)
    return self._get_obs()

  def step(self, action):
    total_reward = 0.0
    for step in range(self._action_repeat):
      _, reward, done, info = self._env.step(action)
      total_reward += reward
      if self._life_done:
        lives = self._env.ale.lives()
        done = done or lives < self._lives
        self._lives = lives
      if done:
        break
      elif step >= self._action_repeat - 2:
        # Keep the last two frames for max-pooling.
        index = step - (self._action_repeat - 2)
        if self._grayscale:
          self._env.ale.getScreenGrayscale(self._buffers[index])
        else:
          self._env.ale.getScreenRGB2(self._buffers[index])
    obs = self._get_obs()
    return obs, total_reward, done, info

  def render(self, mode):
    return self._env.render(mode)

  def _get_obs(self):
    if self._action_repeat > 1:
      np.maximum(self._buffers[0], self._buffers[1], out=self._buffers[0])
    # PIL expects (width, height), while `self._size` is (height, width).
    image = np.array(Image.fromarray(self._buffers[0]).resize(
        tuple(self._size)[::-1], Image.BILINEAR))
    image = np.clip(image, 0, 255).astype(np.uint8)
    image = image[:, :, None] if self._grayscale else image
    return {'image': image}


class Collect:
  """Record transitions and hand every finished episode to `callbacks`.

  Observation values are cast to `precision` bits (16, 32 or 64) for floats
  and signed integers; uint8 is kept. The reset observation is stored with
  a zero action, zero reward and discount 1. When an episode ends, it is
  passed as a dict of stacked arrays to each callback and returned as
  `info['episode']`.
  """

  def __init__(self, env, callbacks=None, precision=32):
    self._env = env
    self._callbacks = callbacks or ()
    self._precision = precision
    self._episode = None

  def __getattr__(self, name):
    return getattr(self._env, name)

  def step(self, action):
    obs, reward, done, info = self._env.step(action)
    obs = {k: self._convert(v) for k, v in obs.items()}
    transition = obs.copy()
    transition['action'] = action
    transition['reward'] = reward
    transition['discount'] = info.get('discount', np.array(1 - float(done)))
    self._episode.append(transition)
    if done:
      episode = {k: [t[k] for t in self._episode] for k in self._episode[0]}
      episode = {k: self._convert(v) for k, v in episode.items()}
      info['episode'] = episode
      for callback in self._callbacks:
        callback(episode)
    return obs, reward, done, info

  def reset(self):
    obs = self._env.reset()
    # Convert like `step` so reset and step observations share dtypes.
    obs = {k: self._convert(v) for k, v in obs.items()}
    transition = obs.copy()
    transition['action'] = np.zeros(self._env.action_space.shape)
    transition['reward'] = 0.0
    transition['discount'] = 1.0
    self._episode = [transition]
    return obs

  def _convert(self, value):
    value = np.array(value)
    if np.issubdtype(value.dtype, np.floating):
      dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self._precision]
    elif np.issubdtype(value.dtype, np.signedinteger):
      dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision]
    elif np.issubdtype(value.dtype, np.uint8):
      dtype = np.uint8
    else:
      raise NotImplementedError(value.dtype)
    return value.astype(dtype)


class TimeLimit:
  """End episodes after `duration` steps.

  If the wrapped env reports no discount at the time limit, a discount of
  1.0 is added to `info`.
  """

  def __init__(self, env, duration):
    self._env = env
    self._duration = duration
    self._step = None

  def __getattr__(self, name):
    return getattr(self._env, name)

  def step(self, action):
    assert self._step is not None, 'Must reset environment.'
    obs, reward, done, info = self._env.step(action)
    self._step += 1
    if self._step >= self._duration:
      done = True
      if 'discount' not in info:
        info['discount'] = np.array(1.0).astype(np.float32)
      self._step = None
    return obs, reward, done, info

  def reset(self):
    self._step = 0
    return self._env.reset()


class ActionRepeat:
  """Repeat each action `amount` times, summing rewards, until done."""

  def __init__(self, env, amount):
    self._env = env
    self._amount = amount

  def __getattr__(self, name):
    return getattr(self._env, name)

  def step(self, action):
    done = False
    total_reward = 0
    current_step = 0
    while current_step < self._amount and not done:
      obs, reward, done, info = self._env.step(action)
      total_reward += reward
      current_step += 1
    return obs, total_reward, done, info


class NormalizeActions:
  """Expose finite action bounds as [-1, 1]; infinite bounds pass through."""

  def __init__(self, env):
    self._env = env
    self._mask = np.logical_and(
        np.isfinite(env.action_space.low),
        np.isfinite(env.action_space.high))
    self._low = np.where(self._mask, env.action_space.low, -1)
    self._high = np.where(self._mask, env.action_space.high, 1)

  def __getattr__(self, name):
    return getattr(self._env, name)

  @property
  def action_space(self):
    low = np.where(self._mask, -np.ones_like(self._low), self._low)
    high = np.where(self._mask, np.ones_like(self._low), self._high)
    return gym.spaces.Box(low, high, dtype=np.float32)

  def step(self, action):
    # Map [-1, 1] back to [low, high] on the bounded dimensions only.
    original = (action + 1) / 2 * (self._high - self._low) + self._low
    original = np.where(self._mask, original, action)
    return self._env.step(original)


class ObsDict:
  """Wrap non-dict observations into a dict under `key`."""

  def __init__(self, env, key='obs'):
    self._env = env
    self._key = key

  def __getattr__(self, name):
    return getattr(self._env, name)

  @property
  def observation_space(self):
    spaces = {self._key: self._env.observation_space}
    return gym.spaces.Dict(spaces)

  @property
  def action_space(self):
    return self._env.action_space

  def step(self, action):
    obs, reward, done, info = self._env.step(action)
    obs = {self._key: np.array(obs)}
    return obs, reward, done, info

  def reset(self):
    obs = self._env.reset()
    obs = {self._key: np.array(obs)}
    return obs


class OneHotAction:
  """Expose a `Discrete` action space as one-hot float vectors."""

  def __init__(self, env):
    assert isinstance(env.action_space, gym.spaces.Discrete)
    self._env = env
    # Needed by `_sample_action`. Without this attribute, `__getattr__`
    # would fall back to the wrapped env, which may not define it.
    self._random = np.random.RandomState()

  def __getattr__(self, name):
    return getattr(self._env, name)

  @property
  def action_space(self):
    shape = (self._env.action_space.n,)
    space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
    space.sample = self._sample_action
    return space

  def step(self, action):
    index = np.argmax(action).astype(int)
    reference = np.zeros_like(action)
    reference[index] = 1
    if not np.allclose(reference, action):
      raise ValueError(f'Invalid one-hot action:\n{action}')
    return self._env.step(index)

  def reset(self):
    return self._env.reset()

  def _sample_action(self):
    actions = self._env.action_space.n
    index = self._random.randint(0, actions)
    reference = np.zeros(actions, dtype=np.float32)
    reference[index] = 1.0
    return reference


class RewardObs:
  """Add the last reward to the observation dict under the key 'reward'."""

  def __init__(self, env):
    self._env = env

  def __getattr__(self, name):
    return getattr(self._env, name)

  @property
  def observation_space(self):
    # Copy so that the wrapped environment's space is never mutated.
    spaces = self._env.observation_space.spaces.copy()
    assert 'reward' not in spaces
    spaces['reward'] = gym.spaces.Box(
        -np.inf, np.inf, shape=(), dtype=np.float32)
    return gym.spaces.Dict(spaces)

  def step(self, action):
    obs, reward, done, info = self._env.step(action)
    obs['reward'] = reward
    return obs, reward, done, info

  def reset(self):
    obs = self._env.reset()
    obs['reward'] = 0.0
    return obs


class Async:
  """Run an environment in a worker process or thread, or in-process.

  Args:
    ctor: Zero-argument callable that creates the environment. With the
      'process' strategy it must be picklable.
    strategy: 'process', 'thread' or 'none' (run in the calling thread).

  `step` and `reset` accept `blocking=False` to return a zero-argument
  promise instead of the result.
  """

  _ACCESS = 1
  _CALL = 2
  _RESULT = 3
  _EXCEPTION = 4
  _CLOSE = 5

  def __init__(self, ctor, strategy='process'):
    # Set first so these lookups never fall through to `__getattr__`.
    self._obs_space = None
    self._action_space = None
    self._strategy = strategy
    if strategy == 'none':
      self._env = ctor()
    elif strategy == 'thread':
      import multiprocessing.dummy as mp
    elif strategy == 'process':
      import multiprocessing as mp
    else:
      raise NotImplementedError(strategy)
    if strategy != 'none':
      self._conn, conn = mp.Pipe()
      self._process = mp.Process(target=self._worker, args=(ctor, conn))
      atexit.register(self.close)
      self._process.start()

  @property
  def observation_space(self):
    if self._obs_space is None:
      self._obs_space = self.__getattr__('observation_space')
    return self._obs_space

  @property
  def action_space(self):
    if self._action_space is None:
      self._action_space = self.__getattr__('action_space')
    return self._action_space

  def __getattr__(self, name):
    if self._strategy == 'none':
      return getattr(self._env, name)
    self._conn.send((self._ACCESS, name))
    return self._receive()

  def call(self, name, *args, **kwargs):
    """Call method `name` of the environment.

    Returns the result, or a promise returning it if `blocking=False`.
    """
    blocking = kwargs.pop('blocking', True)
    if self._strategy == 'none':
      promise = functools.partial(getattr(self._env, name), *args, **kwargs)
    else:
      payload = name, args, kwargs
      self._conn.send((self._CALL, payload))
      promise = self._receive
    return promise() if blocking else promise

  def close(self):
    if self._strategy == 'none':
      try:
        self._env.close()
      except AttributeError:
        pass
      return
    try:
      self._conn.send((self._CLOSE, None))
      self._conn.close()
    except IOError:
      # The connection was already closed.
      pass
    self._process.join()

  def step(self, action, blocking=True):
    return self.call('step', action, blocking=blocking)

  def reset(self, blocking=True):
    return self.call('reset', blocking=blocking)

  def _receive(self):
    try:
      message, payload = self._conn.recv()
    except (ConnectionResetError, EOFError) as e:
      raise RuntimeError('Environment worker crashed.') from e
    # Re-raise exceptions in the main process.
    if message == self._EXCEPTION:
      stacktrace = payload
      raise Exception(stacktrace)
    if message == self._RESULT:
      return payload
    raise KeyError(f'Received message of unexpected type {message}')

  def _worker(self, ctor, conn):
    """Worker loop: serve attribute reads and method calls until closed."""
    try:
      env = ctor()
      while True:
        try:
          # Only block for short times to have keyboard exceptions be raised.
          if not conn.poll(0.1):
            continue
          message, payload = conn.recv()
        except (EOFError, KeyboardInterrupt):
          break
        if message == self._ACCESS:
          name = payload
          result = getattr(env, name)
          conn.send((self._RESULT, result))
          continue
        if message == self._CALL:
          name, args, kwargs = payload
          result = getattr(env, name)(*args, **kwargs)
          conn.send((self._RESULT, result))
          continue
        if message == self._CLOSE:
          assert payload is None
          break
        raise KeyError(f'Received message of unknown type {message}')
    except Exception:
      stacktrace = ''.join(traceback.format_exception(*sys.exc_info()))
      print(f'Error in environment process: {stacktrace}')
      conn.send((self._EXCEPTION, stacktrace))
    conn.close()