├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── dreamerv2 ├── agent.py ├── api.py ├── common │ ├── __init__.py │ ├── config.py │ ├── counter.py │ ├── dists.py │ ├── driver.py │ ├── envs.py │ ├── flags.py │ ├── logger.py │ ├── nets.py │ ├── other.py │ ├── plot.py │ ├── replay.py │ ├── tfutils.py │ └── when.py ├── configs.yaml ├── expl.py └── train.py ├── examples └── minigrid.py ├── scores ├── atari-dopamine.json ├── atari-dreamerv2-schedules.json ├── atari-dreamerv2.json ├── baselines.json ├── dmc-vision-dreamerv2.json ├── humanoid-dreamerv2.json └── montezuma-dreamerv2.json └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *.egg-info 4 | dist 5 | MUJOCO_LOG.TXT 6 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM tensorflow/tensorflow:2.4.2-gpu 2 | 3 | # System packages. 4 | RUN apt-get update && apt-get install -y \ 5 | ffmpeg \ 6 | libgl1-mesa-dev \ 7 | python3-pip \ 8 | unrar \ 9 | wget \ 10 | && apt-get clean 11 | 12 | # MuJoCo. 13 | ENV MUJOCO_GL egl 14 | RUN mkdir -p /root/.mujoco && \ 15 | wget -nv https://www.roboti.us/download/mujoco200_linux.zip -O mujoco.zip && \ 16 | unzip mujoco.zip -d /root/.mujoco && \ 17 | rm mujoco.zip 18 | 19 | # Python packages. 20 | RUN pip3 install --no-cache-dir \ 21 | 'gym[atari]' \ 22 | atari_py \ 23 | crafter \ 24 | dm_control \ 25 | ruamel.yaml \ 26 | tensorflow_probability==0.12.2 27 | 28 | # Atari ROMS. 29 | RUN wget -L -nv http://www.atarimania.com/roms/Roms.rar && \ 30 | unrar x Roms.rar && \ 31 | unzip ROMS.zip && \ 32 | python3 -m atari_py.import_roms ROMS && \ 33 | rm -rf Roms.rar ROMS.zip ROMS 34 | 35 | # MuJoCo key. 36 | ARG MUJOCO_KEY="" 37 | RUN echo "$MUJOCO_KEY" > /root/.mujoco/mjkey.txt 38 | RUN cat /root/.mujoco/mjkey.txt 39 | 40 | # DreamerV2. 41 | ENV TF_XLA_FLAGS --tf_xla_auto_jit=2 42 | COPY . /app 43 | WORKDIR /app 44 | CMD [ \ 45 | "python3", "dreamerv2/train.py", \ 46 | "--logdir", "/logdir/$(date +%Y%m%d-%H%M%S)", \ 47 | "--configs", "defaults", "atari", \ 48 | "--task", "atari_pong" \ 49 | ] 50 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020 Danijar Hafner 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 
20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **Status:** Stable release 2 | 3 | [![PyPI](https://img.shields.io/pypi/v/dreamerv2.svg)](https://pypi.python.org/pypi/dreamerv2/#history) 4 | 5 | # Mastering Atari with Discrete World Models 6 | 7 | Implementation of the [DreamerV2][website] agent in TensorFlow 2. Training 8 | curves for all 55 games are included. 9 | 10 |

11 | 12 |

13 | 14 | If you find this code useful, please reference in your paper: 15 | 16 | ``` 17 | @article{hafner2020dreamerv2, 18 | title={Mastering Atari with Discrete World Models}, 19 | author={Hafner, Danijar and Lillicrap, Timothy and Norouzi, Mohammad and Ba, Jimmy}, 20 | journal={arXiv preprint arXiv:2010.02193}, 21 | year={2020} 22 | } 23 | ``` 24 | 25 | [website]: https://danijar.com/dreamerv2 26 | 27 | ## Method 28 | 29 | DreamerV2 is the first world model agent that achieves human-level performance 30 | on the Atari benchmark. DreamerV2 also outperforms the final performance of the 31 | top model-free agents Rainbow and IQN using the same amount of experience and 32 | computation. The implementation in this repository alternates between training 33 | the world model, training the policy, and collecting experience and runs on a 34 | single GPU. 35 | 36 | ![World Model Learning](https://imgur.com/GRC9QAw.png) 37 | 38 | DreamerV2 learns a model of the environment directly from high-dimensional 39 | input images. For this, it predicts ahead using compact learned states. The 40 | states consist of a deterministic part and several categorical variables that 41 | are sampled. The prior for these categoricals is learned through a KL loss. The 42 | world model is learned end-to-end via straight-through gradients, meaning that 43 | the gradient of the density is set to the gradient of the sample. 44 | 45 | ![Actor Critic Learning](https://imgur.com/wH9kJ2O.png) 46 | 47 | DreamerV2 learns actor and critic networks from imagined trajectories of latent 48 | states. The trajectories start at encoded states of previously encountered 49 | sequences. The world model then predicts ahead using the selected actions and 50 | its learned state prior. The critic is trained using temporal difference 51 | learning and the actor is trained to maximize the value function via reinforce 52 | and straight-through gradients. 53 | 54 | For more information: 55 | 56 | - [Google AI Blog post](https://ai.googleblog.com/2021/02/mastering-atari-with-discrete-world.html) 57 | - [Project website](https://danijar.com/dreamerv2/) 58 | - [Research paper](https://arxiv.org/pdf/2010.02193.pdf) 59 | 60 | ## Using the Package 61 | 62 | The easiest way to run DreamerV2 on new environments is to install the package 63 | via `pip3 install dreamerv2`. The code automatically detects whether the 64 | environment uses discrete or continuous actions. Here is a usage example that 65 | trains DreamerV2 on the MiniGrid environment: 66 | 67 | ```python 68 | import gym 69 | import gym_minigrid 70 | import dreamerv2.api as dv2 71 | 72 | config = dv2.defaults.update({ 73 | 'logdir': '~/logdir/minigrid', 74 | 'log_every': 1e3, 75 | 'train_every': 10, 76 | 'prefill': 1e5, 77 | 'actor_ent': 3e-3, 78 | 'loss_scales.kl': 1.0, 79 | 'discount': 0.99, 80 | }).parse_flags() 81 | 82 | env = gym.make('MiniGrid-DoorKey-6x6-v0') 83 | env = gym_minigrid.wrappers.RGBImgPartialObsWrapper(env) 84 | dv2.train(env, config) 85 | ``` 86 | 87 | ## Manual Instructions 88 | 89 | To modify the DreamerV2 agent, clone the repository and follow the instructions 90 | below. There is also a Dockerfile available, in case you do not want to install 91 | the dependencies on your system. 
92 | 93 | Get dependencies: 94 | 95 | ```sh 96 | pip3 install tensorflow==2.6.0 tensorflow_probability ruamel.yaml 'gym[atari]' dm_control 97 | ``` 98 | 99 | Train on Atari: 100 | 101 | ```sh 102 | python3 dreamerv2/train.py --logdir ~/logdir/atari_pong/dreamerv2/1 \ 103 | --configs atari --task atari_pong 104 | ``` 105 | 106 | Train on DM Control: 107 | 108 | ```sh 109 | python3 dreamerv2/train.py --logdir ~/logdir/dmc_walker_walk/dreamerv2/1 \ 110 | --configs dmc_vision --task dmc_walker_walk 111 | ``` 112 | 113 | Monitor results: 114 | 115 | ```sh 116 | tensorboard --logdir ~/logdir 117 | ``` 118 | 119 | Generate plots: 120 | 121 | ```sh 122 | python3 common/plot.py --indir ~/logdir --outdir ~/plots \ 123 | --xaxis step --yaxis eval_return --bins 1e6 124 | ``` 125 | 126 | ## Docker Instructions 127 | 128 | The [Dockerfile](https://github.com/danijar/dreamerv2/blob/main/Dockerfile) 129 | lets you run DreamerV2 without installing its dependencies in your system. This 130 | requires you to have Docker with GPU access set up. 131 | 132 | Check your setup: 133 | 134 | ```sh 135 | docker run -it --rm --gpus all tensorflow/tensorflow:2.4.2-gpu nvidia-smi 136 | ``` 137 | 138 | Train on Atari: 139 | 140 | ```sh 141 | docker build -t dreamerv2 . 142 | docker run -it --rm --gpus all -v ~/logdir:/logdir dreamerv2 \ 143 | python3 dreamerv2/train.py --logdir /logdir/atari_pong/dreamerv2/1 \ 144 | --configs atari --task atari_pong 145 | ``` 146 | 147 | Train on DM Control: 148 | 149 | ```sh 150 | docker build -t dreamerv2 . --build-arg MUJOCO_KEY="$(cat ~/.mujoco/mjkey.txt)" 151 | docker run -it --rm --gpus all -v ~/logdir:/logdir dreamerv2 \ 152 | python3 dreamerv2/train.py --logdir /logdir/dmc_walker_walk/dreamerv2/1 \ 153 | --configs dmc_vision --task dmc_walker_walk 154 | ``` 155 | 156 | ## Tips 157 | 158 | - **Efficient debugging.** You can use the `debug` config as in `--configs 159 | atari debug`. This reduces the batch size, increases the evaluation 160 | frequency, and disables `tf.function` graph compilation for easy line-by-line 161 | debugging. 162 | 163 | - **Infinite gradient norms.** This is normal and described under loss scaling in 164 | the [mixed precision][mixed] guide. You can disable mixed precision by passing 165 | `--precision 32` to the training script. Mixed precision is faster but can in 166 | principle cause numerical instabilities. 167 | 168 | - **Accessing logged metrics.** The metrics are stored in both TensorBoard and 169 | JSON lines format. You can directly load them using `pandas.read_json()`. The 170 | plotting script also stores the binned and aggregated metrics of multiple runs 171 | into a single JSON file for easy manual plotting. 
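Building on the last tip, here is a minimal sketch for loading the logged metrics with pandas. It assumes the run directory contains a `metrics.jsonl` file written by the JSON lines output and that an `eval_return` scalar was logged; the exact filename and available columns depend on your logger configuration:

```python
import pandas as pd

# Each line of the JSON lines file is one logging event, containing the
# training step plus whatever scalars were written at that point.
df = pd.read_json('~/logdir/atari_pong/dreamerv2/1/metrics.jsonl', lines=True)

# Keep only the rows that actually contain the evaluation return.
returns = df[['step', 'eval_return']].dropna()
print(returns.tail())
```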
172 | 173 | [mixed]: https://www.tensorflow.org/guide/mixed_precision 174 | -------------------------------------------------------------------------------- /dreamerv2/agent.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import mixed_precision as prec 3 | 4 | import common 5 | import expl 6 | 7 | 8 | class Agent(common.Module): 9 | 10 | def __init__(self, config, obs_space, act_space, step): 11 | self.config = config 12 | self.obs_space = obs_space 13 | self.act_space = act_space['action'] 14 | self.step = step 15 | self.tfstep = tf.Variable(int(self.step), tf.int64) 16 | self.wm = WorldModel(config, obs_space, self.tfstep) 17 | self._task_behavior = ActorCritic(config, self.act_space, self.tfstep) 18 | if config.expl_behavior == 'greedy': 19 | self._expl_behavior = self._task_behavior 20 | else: 21 | self._expl_behavior = getattr(expl, config.expl_behavior)( 22 | self.config, self.act_space, self.wm, self.tfstep, 23 | lambda seq: self.wm.heads['reward'](seq['feat']).mode()) 24 | 25 | @tf.function 26 | def policy(self, obs, state=None, mode='train'): 27 | obs = tf.nest.map_structure(tf.tensor, obs) 28 | tf.py_function(lambda: self.tfstep.assign( 29 | int(self.step), read_value=False), [], []) 30 | if state is None: 31 | latent = self.wm.rssm.initial(len(obs['reward'])) 32 | action = tf.zeros((len(obs['reward']),) + self.act_space.shape) 33 | state = latent, action 34 | latent, action = state 35 | embed = self.wm.encoder(self.wm.preprocess(obs)) 36 | sample = (mode == 'train') or not self.config.eval_state_mean 37 | latent, _ = self.wm.rssm.obs_step( 38 | latent, action, embed, obs['is_first'], sample) 39 | feat = self.wm.rssm.get_feat(latent) 40 | if mode == 'eval': 41 | actor = self._task_behavior.actor(feat) 42 | action = actor.mode() 43 | noise = self.config.eval_noise 44 | elif mode == 'explore': 45 | actor = self._expl_behavior.actor(feat) 46 | action = actor.sample() 47 | noise = self.config.expl_noise 48 | elif mode == 'train': 49 | actor = self._task_behavior.actor(feat) 50 | action = actor.sample() 51 | noise = self.config.expl_noise 52 | action = common.action_noise(action, noise, self.act_space) 53 | outputs = {'action': action} 54 | state = (latent, action) 55 | return outputs, state 56 | 57 | @tf.function 58 | def train(self, data, state=None): 59 | metrics = {} 60 | state, outputs, mets = self.wm.train(data, state) 61 | metrics.update(mets) 62 | start = outputs['post'] 63 | reward = lambda seq: self.wm.heads['reward'](seq['feat']).mode() 64 | metrics.update(self._task_behavior.train( 65 | self.wm, start, data['is_terminal'], reward)) 66 | if self.config.expl_behavior != 'greedy': 67 | mets = self._expl_behavior.train(start, outputs, data)[-1] 68 | metrics.update({'expl_' + key: value for key, value in mets.items()}) 69 | return state, metrics 70 | 71 | @tf.function 72 | def report(self, data): 73 | report = {} 74 | data = self.wm.preprocess(data) 75 | for key in self.wm.heads['decoder'].cnn_keys: 76 | name = key.replace('/', '_') 77 | report[f'openl_{name}'] = self.wm.video_pred(data, key) 78 | return report 79 | 80 | 81 | class WorldModel(common.Module): 82 | 83 | def __init__(self, config, obs_space, tfstep): 84 | shapes = {k: tuple(v.shape) for k, v in obs_space.items()} 85 | self.config = config 86 | self.tfstep = tfstep 87 | self.rssm = common.EnsembleRSSM(**config.rssm) 88 | self.encoder = common.Encoder(shapes, **config.encoder) 89 | self.heads = {} 90 | self.heads['decoder'] = 
common.Decoder(shapes, **config.decoder) 91 | self.heads['reward'] = common.MLP([], **config.reward_head) 92 | if config.pred_discount: 93 | self.heads['discount'] = common.MLP([], **config.discount_head) 94 | for name in config.grad_heads: 95 | assert name in self.heads, name 96 | self.model_opt = common.Optimizer('model', **config.model_opt) 97 | 98 | def train(self, data, state=None): 99 | with tf.GradientTape() as model_tape: 100 | model_loss, state, outputs, metrics = self.loss(data, state) 101 | modules = [self.encoder, self.rssm, *self.heads.values()] 102 | metrics.update(self.model_opt(model_tape, model_loss, modules)) 103 | return state, outputs, metrics 104 | 105 | def loss(self, data, state=None): 106 | data = self.preprocess(data) 107 | embed = self.encoder(data) 108 | post, prior = self.rssm.observe( 109 | embed, data['action'], data['is_first'], state) 110 | kl_loss, kl_value = self.rssm.kl_loss(post, prior, **self.config.kl) 111 | assert len(kl_loss.shape) == 0 112 | likes = {} 113 | losses = {'kl': kl_loss} 114 | feat = self.rssm.get_feat(post) 115 | for name, head in self.heads.items(): 116 | grad_head = (name in self.config.grad_heads) 117 | inp = feat if grad_head else tf.stop_gradient(feat) 118 | out = head(inp) 119 | dists = out if isinstance(out, dict) else {name: out} 120 | for key, dist in dists.items(): 121 | like = tf.cast(dist.log_prob(data[key]), tf.float32) 122 | likes[key] = like 123 | losses[key] = -like.mean() 124 | model_loss = sum( 125 | self.config.loss_scales.get(k, 1.0) * v for k, v in losses.items()) 126 | outs = dict( 127 | embed=embed, feat=feat, post=post, 128 | prior=prior, likes=likes, kl=kl_value) 129 | metrics = {f'{name}_loss': value for name, value in losses.items()} 130 | metrics['model_kl'] = kl_value.mean() 131 | metrics['prior_ent'] = self.rssm.get_dist(prior).entropy().mean() 132 | metrics['post_ent'] = self.rssm.get_dist(post).entropy().mean() 133 | last_state = {k: v[:, -1] for k, v in post.items()} 134 | return model_loss, last_state, outs, metrics 135 | 136 | def imagine(self, policy, start, is_terminal, horizon): 137 | flatten = lambda x: x.reshape([-1] + list(x.shape[2:])) 138 | start = {k: flatten(v) for k, v in start.items()} 139 | start['feat'] = self.rssm.get_feat(start) 140 | start['action'] = tf.zeros_like(policy(start['feat']).mode()) 141 | seq = {k: [v] for k, v in start.items()} 142 | for _ in range(horizon): 143 | action = policy(tf.stop_gradient(seq['feat'][-1])).sample() 144 | state = self.rssm.img_step({k: v[-1] for k, v in seq.items()}, action) 145 | feat = self.rssm.get_feat(state) 146 | for key, value in {**state, 'action': action, 'feat': feat}.items(): 147 | seq[key].append(value) 148 | seq = {k: tf.stack(v, 0) for k, v in seq.items()} 149 | if 'discount' in self.heads: 150 | disc = self.heads['discount'](seq['feat']).mean() 151 | if is_terminal is not None: 152 | # Override discount prediction for the first step with the true 153 | # discount factor from the replay buffer. 154 | true_first = 1.0 - flatten(is_terminal).astype(disc.dtype) 155 | true_first *= self.config.discount 156 | disc = tf.concat([true_first[None], disc[1:]], 0) 157 | else: 158 | disc = self.config.discount * tf.ones(seq['feat'].shape[:-1]) 159 | seq['discount'] = disc 160 | # Shift discount factors because they imply whether the following state 161 | # will be valid, not whether the current state is valid. 
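    # For example, discounts [d0, d1, d2, ...] produce weights
    # [1, d0, d0 * d1, ...], so states imagined after a predicted episode
    # end are down-weighted in the actor and critic losses.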
162 | seq['weight'] = tf.math.cumprod( 163 | tf.concat([tf.ones_like(disc[:1]), disc[:-1]], 0), 0) 164 | return seq 165 | 166 | @tf.function 167 | def preprocess(self, obs): 168 | dtype = prec.global_policy().compute_dtype 169 | obs = obs.copy() 170 | for key, value in obs.items(): 171 | if key.startswith('log_'): 172 | continue 173 | if value.dtype == tf.int32: 174 | value = value.astype(dtype) 175 | if value.dtype == tf.uint8: 176 | value = value.astype(dtype) / 255.0 - 0.5 177 | obs[key] = value 178 | obs['reward'] = { 179 | 'identity': tf.identity, 180 | 'sign': tf.sign, 181 | 'tanh': tf.tanh, 182 | }[self.config.clip_rewards](obs['reward']) 183 | obs['discount'] = 1.0 - obs['is_terminal'].astype(dtype) 184 | obs['discount'] *= self.config.discount 185 | return obs 186 | 187 | @tf.function 188 | def video_pred(self, data, key): 189 | decoder = self.heads['decoder'] 190 | truth = data[key][:6] + 0.5 191 | embed = self.encoder(data) 192 | states, _ = self.rssm.observe( 193 | embed[:6, :5], data['action'][:6, :5], data['is_first'][:6, :5]) 194 | recon = decoder(self.rssm.get_feat(states))[key].mode()[:6] 195 | init = {k: v[:, -1] for k, v in states.items()} 196 | prior = self.rssm.imagine(data['action'][:6, 5:], init) 197 | openl = decoder(self.rssm.get_feat(prior))[key].mode() 198 | model = tf.concat([recon[:, :5] + 0.5, openl + 0.5], 1) 199 | error = (model - truth + 1) / 2 200 | video = tf.concat([truth, model, error], 2) 201 | B, T, H, W, C = video.shape 202 | return video.transpose((1, 2, 0, 3, 4)).reshape((T, H, B * W, C)) 203 | 204 | 205 | class ActorCritic(common.Module): 206 | 207 | def __init__(self, config, act_space, tfstep): 208 | self.config = config 209 | self.act_space = act_space 210 | self.tfstep = tfstep 211 | discrete = hasattr(act_space, 'n') 212 | if self.config.actor.dist == 'auto': 213 | self.config = self.config.update({ 214 | 'actor.dist': 'onehot' if discrete else 'trunc_normal'}) 215 | if self.config.actor_grad == 'auto': 216 | self.config = self.config.update({ 217 | 'actor_grad': 'reinforce' if discrete else 'dynamics'}) 218 | self.actor = common.MLP(act_space.shape[0], **self.config.actor) 219 | self.critic = common.MLP([], **self.config.critic) 220 | if self.config.slow_target: 221 | self._target_critic = common.MLP([], **self.config.critic) 222 | self._updates = tf.Variable(0, tf.int64) 223 | else: 224 | self._target_critic = self.critic 225 | self.actor_opt = common.Optimizer('actor', **self.config.actor_opt) 226 | self.critic_opt = common.Optimizer('critic', **self.config.critic_opt) 227 | self.rewnorm = common.StreamNorm(**self.config.reward_norm) 228 | 229 | def train(self, world_model, start, is_terminal, reward_fn): 230 | metrics = {} 231 | hor = self.config.imag_horizon 232 | # The weights are is_terminal flags for the imagination start states. 233 | # Technically, they should multiply the losses from the second trajectory 234 | # step onwards, which is the first imagined step. However, we are not 235 | # training the action that led into the first step anyway, so we can use 236 | # them to scale the whole sequence. 
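    # One training step: imagine a trajectory under the current actor,
    # normalize the predicted rewards, compute lambda-return targets, and
    # update the actor and critic with their separate optimizers.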
237 | with tf.GradientTape() as actor_tape: 238 | seq = world_model.imagine(self.actor, start, is_terminal, hor) 239 | reward = reward_fn(seq) 240 | seq['reward'], mets1 = self.rewnorm(reward) 241 | mets1 = {f'reward_{k}': v for k, v in mets1.items()} 242 | target, mets2 = self.target(seq) 243 | actor_loss, mets3 = self.actor_loss(seq, target) 244 | with tf.GradientTape() as critic_tape: 245 | critic_loss, mets4 = self.critic_loss(seq, target) 246 | metrics.update(self.actor_opt(actor_tape, actor_loss, self.actor)) 247 | metrics.update(self.critic_opt(critic_tape, critic_loss, self.critic)) 248 | metrics.update(**mets1, **mets2, **mets3, **mets4) 249 | self.update_slow_target() # Variables exist after first forward pass. 250 | return metrics 251 | 252 | def actor_loss(self, seq, target): 253 | # Actions: 0 [a1] [a2] a3 254 | # ^ | ^ | ^ | 255 | # / v / v / v 256 | # States: [z0]->[z1]-> z2 -> z3 257 | # Targets: t0 [t1] [t2] 258 | # Baselines: [v0] [v1] v2 v3 259 | # Entropies: [e1] [e2] 260 | # Weights: [ 1] [w1] w2 w3 261 | # Loss: l1 l2 262 | metrics = {} 263 | # Two states are lost at the end of the trajectory, one for the boostrap 264 | # value prediction and one because the corresponding action does not lead 265 | # anywhere anymore. One target is lost at the start of the trajectory 266 | # because the initial state comes from the replay buffer. 267 | policy = self.actor(tf.stop_gradient(seq['feat'][:-2])) 268 | if self.config.actor_grad == 'dynamics': 269 | objective = target[1:] 270 | elif self.config.actor_grad == 'reinforce': 271 | baseline = self._target_critic(seq['feat'][:-2]).mode() 272 | advantage = tf.stop_gradient(target[1:] - baseline) 273 | action = tf.stop_gradient(seq['action'][1:-1]) 274 | objective = policy.log_prob(action) * advantage 275 | elif self.config.actor_grad == 'both': 276 | baseline = self._target_critic(seq['feat'][:-2]).mode() 277 | advantage = tf.stop_gradient(target[1:] - baseline) 278 | objective = policy.log_prob(seq['action'][1:-1]) * advantage 279 | mix = common.schedule(self.config.actor_grad_mix, self.tfstep) 280 | objective = mix * target[1:] + (1 - mix) * objective 281 | metrics['actor_grad_mix'] = mix 282 | else: 283 | raise NotImplementedError(self.config.actor_grad) 284 | ent = policy.entropy() 285 | ent_scale = common.schedule(self.config.actor_ent, self.tfstep) 286 | objective += ent_scale * ent 287 | weight = tf.stop_gradient(seq['weight']) 288 | actor_loss = -(weight[:-2] * objective).mean() 289 | metrics['actor_ent'] = ent.mean() 290 | metrics['actor_ent_scale'] = ent_scale 291 | return actor_loss, metrics 292 | 293 | def critic_loss(self, seq, target): 294 | # States: [z0] [z1] [z2] z3 295 | # Rewards: [r0] [r1] [r2] r3 296 | # Values: [v0] [v1] [v2] v3 297 | # Weights: [ 1] [w1] [w2] w3 298 | # Targets: [t0] [t1] [t2] 299 | # Loss: l0 l1 l2 300 | dist = self.critic(seq['feat'][:-1]) 301 | target = tf.stop_gradient(target) 302 | weight = tf.stop_gradient(seq['weight']) 303 | critic_loss = -(dist.log_prob(target) * weight[:-1]).mean() 304 | metrics = {'critic': dist.mode().mean()} 305 | return critic_loss, metrics 306 | 307 | def target(self, seq): 308 | # States: [z0] [z1] [z2] [z3] 309 | # Rewards: [r0] [r1] [r2] r3 310 | # Values: [v0] [v1] [v2] [v3] 311 | # Discount: [d0] [d1] [d2] d3 312 | # Targets: t0 t1 t2 313 | reward = tf.cast(seq['reward'], tf.float32) 314 | disc = tf.cast(seq['discount'], tf.float32) 315 | value = self._target_critic(seq['feat']).mode() 316 | # Skipping last time step because it is used for bootstrapping. 
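    # The lambda-return follows the standard recursion
    #   V_t = r_t + d_t * ((1 - lambda) * v_{t+1} + lambda * V_{t+1}),
    # bootstrapped from the target critic's value at the final imagined step.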
317 | target = common.lambda_return( 318 | reward[:-1], value[:-1], disc[:-1], 319 | bootstrap=value[-1], 320 | lambda_=self.config.discount_lambda, 321 | axis=0) 322 | metrics = {} 323 | metrics['critic_slow'] = value.mean() 324 | metrics['critic_target'] = target.mean() 325 | return target, metrics 326 | 327 | def update_slow_target(self): 328 | if self.config.slow_target: 329 | if self._updates % self.config.slow_target_update == 0: 330 | mix = 1.0 if self._updates == 0 else float( 331 | self.config.slow_target_fraction) 332 | for s, d in zip(self.critic.variables, self._target_critic.variables): 333 | d.assign(mix * s + (1 - mix) * d) 334 | self._updates.assign_add(1) 335 | -------------------------------------------------------------------------------- /dreamerv2/api.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import logging 3 | import os 4 | import pathlib 5 | import re 6 | import sys 7 | import warnings 8 | 9 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 10 | logging.getLogger().setLevel('ERROR') 11 | warnings.filterwarnings('ignore', '.*box bound precision lowered.*') 12 | 13 | sys.path.append(str(pathlib.Path(__file__).parent)) 14 | sys.path.append(str(pathlib.Path(__file__).parent.parent)) 15 | 16 | import numpy as np 17 | import ruamel.yaml as yaml 18 | 19 | import agent 20 | import common 21 | 22 | from common import Config 23 | from common import GymWrapper 24 | from common import RenderImage 25 | from common import TerminalOutput 26 | from common import JSONLOutput 27 | from common import TensorBoardOutput 28 | 29 | configs = yaml.safe_load( 30 | (pathlib.Path(__file__).parent / 'configs.yaml').read_text()) 31 | defaults = common.Config(configs.pop('defaults')) 32 | 33 | 34 | def train(env, config, outputs=None): 35 | 36 | logdir = pathlib.Path(config.logdir).expanduser() 37 | logdir.mkdir(parents=True, exist_ok=True) 38 | config.save(logdir / 'config.yaml') 39 | print(config, '\n') 40 | print('Logdir', logdir) 41 | 42 | outputs = outputs or [ 43 | common.TerminalOutput(), 44 | common.JSONLOutput(config.logdir), 45 | common.TensorBoardOutput(config.logdir), 46 | ] 47 | replay = common.Replay(logdir / 'train_episodes', **config.replay) 48 | step = common.Counter(replay.stats['total_steps']) 49 | logger = common.Logger(step, outputs, multiplier=config.action_repeat) 50 | metrics = collections.defaultdict(list) 51 | 52 | should_train = common.Every(config.train_every) 53 | should_log = common.Every(config.log_every) 54 | should_video = common.Every(config.log_every) 55 | should_expl = common.Until(config.expl_until) 56 | 57 | def per_episode(ep): 58 | length = len(ep['reward']) - 1 59 | score = float(ep['reward'].astype(np.float64).sum()) 60 | print(f'Episode has {length} steps and return {score:.1f}.') 61 | logger.scalar('return', score) 62 | logger.scalar('length', length) 63 | for key, value in ep.items(): 64 | if re.match(config.log_keys_sum, key): 65 | logger.scalar(f'sum_{key}', ep[key].sum()) 66 | if re.match(config.log_keys_mean, key): 67 | logger.scalar(f'mean_{key}', ep[key].mean()) 68 | if re.match(config.log_keys_max, key): 69 | logger.scalar(f'max_{key}', ep[key].max(0).mean()) 70 | if should_video(step): 71 | for key in config.log_keys_video: 72 | logger.video(f'policy_{key}', ep[key]) 73 | logger.add(replay.stats) 74 | logger.write() 75 | 76 | env = common.GymWrapper(env) 77 | env = common.ResizeImage(env) 78 | if hasattr(env.act_space['action'], 'n'): 79 | env = common.OneHotAction(env) 80 | else: 
81 | env = common.NormalizeAction(env) 82 | env = common.TimeLimit(env, config.time_limit) 83 | 84 | driver = common.Driver([env]) 85 | driver.on_episode(per_episode) 86 | driver.on_step(lambda tran, worker: step.increment()) 87 | driver.on_step(replay.add_step) 88 | driver.on_reset(replay.add_step) 89 | 90 | prefill = max(0, config.prefill - replay.stats['total_steps']) 91 | if prefill: 92 | print(f'Prefill dataset ({prefill} steps).') 93 | random_agent = common.RandomAgent(env.act_space) 94 | driver(random_agent, steps=prefill, episodes=1) 95 | driver.reset() 96 | 97 | print('Create agent.') 98 | agnt = agent.Agent(config, env.obs_space, env.act_space, step) 99 | dataset = iter(replay.dataset(**config.dataset)) 100 | train_agent = common.CarryOverState(agnt.train) 101 | train_agent(next(dataset)) 102 | if (logdir / 'variables.pkl').exists(): 103 | agnt.load(logdir / 'variables.pkl') 104 | else: 105 | print('Pretrain agent.') 106 | for _ in range(config.pretrain): 107 | train_agent(next(dataset)) 108 | policy = lambda *args: agnt.policy( 109 | *args, mode='explore' if should_expl(step) else 'train') 110 | 111 | def train_step(tran, worker): 112 | if should_train(step): 113 | for _ in range(config.train_steps): 114 | mets = train_agent(next(dataset)) 115 | [metrics[key].append(value) for key, value in mets.items()] 116 | if should_log(step): 117 | for name, values in metrics.items(): 118 | logger.scalar(name, np.array(values, np.float64).mean()) 119 | metrics[name].clear() 120 | logger.add(agnt.report(next(dataset))) 121 | logger.write(fps=True) 122 | driver.on_step(train_step) 123 | 124 | while step < config.steps: 125 | logger.write() 126 | driver(policy, steps=config.eval_every) 127 | agnt.save(logdir / 'variables.pkl') 128 | -------------------------------------------------------------------------------- /dreamerv2/common/__init__.py: -------------------------------------------------------------------------------- 1 | # General tools. 2 | from .config import * 3 | from .counter import * 4 | from .flags import * 5 | from .logger import * 6 | from .when import * 7 | 8 | # RL tools. 9 | from .other import * 10 | from .driver import * 11 | from .envs import * 12 | from .replay import * 13 | 14 | # TensorFlow tools. 15 | from .tfutils import * 16 | from .dists import * 17 | from .nets import * 18 | -------------------------------------------------------------------------------- /dreamerv2/common/config.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pathlib 3 | import re 4 | 5 | 6 | class Config(dict): 7 | 8 | SEP = '.' 9 | IS_PATTERN = re.compile(r'.*[^A-Za-z0-9_.-].*') 10 | 11 | def __init__(self, *args, **kwargs): 12 | mapping = dict(*args, **kwargs) 13 | mapping = self._flatten(mapping) 14 | mapping = self._ensure_keys(mapping) 15 | mapping = self._ensure_values(mapping) 16 | self._flat = mapping 17 | self._nested = self._nest(mapping) 18 | # Need to assign the values to the base class dictionary so that 19 | # conversion to dict does not lose the content. 
20 | super().__init__(self._nested) 21 | 22 | @property 23 | def flat(self): 24 | return self._flat.copy() 25 | 26 | def save(self, filename): 27 | filename = pathlib.Path(filename) 28 | if filename.suffix == '.json': 29 | filename.write_text(json.dumps(dict(self))) 30 | elif filename.suffix in ('.yml', '.yaml'): 31 | import ruamel.yaml as yaml 32 | with filename.open('w') as f: 33 | yaml.safe_dump(dict(self), f) 34 | else: 35 | raise NotImplementedError(filename.suffix) 36 | 37 | @classmethod 38 | def load(cls, filename): 39 | filename = pathlib.Path(filename) 40 | if filename.suffix == '.json': 41 | return cls(json.loads(filename.read_text())) 42 | elif filename.suffix in ('.yml', '.yaml'): 43 | import ruamel.yaml as yaml 44 | return cls(yaml.safe_load(filename.read_text())) 45 | else: 46 | raise NotImplementedError(filename.suffix) 47 | 48 | def parse_flags(self, argv=None, known_only=False, help_exists=None): 49 | from . import flags 50 | return flags.Flags(self).parse(argv, known_only, help_exists) 51 | 52 | def __contains__(self, name): 53 | try: 54 | self[name] 55 | return True 56 | except KeyError: 57 | return False 58 | 59 | def __getattr__(self, name): 60 | if name.startswith('_'): 61 | return super().__getattr__(name) 62 | try: 63 | return self[name] 64 | except KeyError: 65 | raise AttributeError(name) 66 | 67 | def __getitem__(self, name): 68 | result = self._nested 69 | for part in name.split(self.SEP): 70 | result = result[part] 71 | if isinstance(result, dict): 72 | result = type(self)(result) 73 | return result 74 | 75 | def __setattr__(self, key, value): 76 | if key.startswith('_'): 77 | return super().__setattr__(key, value) 78 | message = f"Tried to set key '{key}' on immutable config. Use update()." 79 | raise AttributeError(message) 80 | 81 | def __setitem__(self, key, value): 82 | if key.startswith('_'): 83 | return super().__setitem__(key, value) 84 | message = f"Tried to set key '{key}' on immutable config. Use update()." 85 | raise AttributeError(message) 86 | 87 | def __reduce__(self): 88 | return (type(self), (dict(self),)) 89 | 90 | def __str__(self): 91 | lines = ['\nConfig:'] 92 | keys, vals, typs = [], [], [] 93 | for key, val in self.flat.items(): 94 | keys.append(key + ':') 95 | vals.append(self._format_value(val)) 96 | typs.append(self._format_type(val)) 97 | max_key = max(len(k) for k in keys) if keys else 0 98 | max_val = max(len(v) for v in vals) if vals else 0 99 | for key, val, typ in zip(keys, vals, typs): 100 | key = key.ljust(max_key) 101 | val = val.ljust(max_val) 102 | lines.append(f'{key} {val} ({typ})') 103 | return '\n'.join(lines) 104 | 105 | def update(self, *args, **kwargs): 106 | result = self._flat.copy() 107 | inputs = self._flatten(dict(*args, **kwargs)) 108 | for key, new in inputs.items(): 109 | if self.IS_PATTERN.match(key): 110 | pattern = re.compile(key) 111 | keys = {k for k in result if pattern.match(k)} 112 | else: 113 | keys = [key] 114 | if not keys: 115 | raise KeyError(f'Unknown key or pattern {key}.') 116 | for key in keys: 117 | old = result[key] 118 | try: 119 | if isinstance(old, int) and isinstance(new, float): 120 | if float(int(new)) != new: 121 | message = f"Cannot convert fractional float {new} to int." 
122 | raise ValueError(message) 123 | result[key] = type(old)(new) 124 | except (ValueError, TypeError): 125 | raise TypeError( 126 | f"Cannot convert '{new}' to type '{type(old).__name__}' " + 127 | f"of value '{old}' for key '{key}'.") 128 | return type(self)(result) 129 | 130 | def _flatten(self, mapping): 131 | result = {} 132 | for key, value in mapping.items(): 133 | if isinstance(value, dict): 134 | for k, v in self._flatten(value).items(): 135 | if self.IS_PATTERN.match(key) or self.IS_PATTERN.match(k): 136 | combined = f'{key}\\{self.SEP}{k}' 137 | else: 138 | combined = f'{key}{self.SEP}{k}' 139 | result[combined] = v 140 | else: 141 | result[key] = value 142 | return result 143 | 144 | def _nest(self, mapping): 145 | result = {} 146 | for key, value in mapping.items(): 147 | parts = key.split(self.SEP) 148 | node = result 149 | for part in parts[:-1]: 150 | if part not in node: 151 | node[part] = {} 152 | node = node[part] 153 | node[parts[-1]] = value 154 | return result 155 | 156 | def _ensure_keys(self, mapping): 157 | for key in mapping: 158 | assert not self.IS_PATTERN.match(key), key 159 | return mapping 160 | 161 | def _ensure_values(self, mapping): 162 | result = json.loads(json.dumps(mapping)) 163 | for key, value in result.items(): 164 | if isinstance(value, list): 165 | value = tuple(value) 166 | if isinstance(value, tuple): 167 | if len(value) == 0: 168 | message = 'Empty lists are disallowed because their type is unclear.' 169 | raise TypeError(message) 170 | if not isinstance(value[0], (str, float, int, bool)): 171 | message = 'Lists can only contain strings, floats, ints, bools' 172 | message += f' but not {type(value[0])}' 173 | raise TypeError(message) 174 | if not all(isinstance(x, type(value[0])) for x in value[1:]): 175 | message = 'Elements of a list must all be of the same type.' 176 | raise TypeError(message) 177 | result[key] = value 178 | return result 179 | 180 | def _format_value(self, value): 181 | if isinstance(value, (list, tuple)): 182 | return '[' + ', '.join(self._format_value(x) for x in value) + ']' 183 | return str(value) 184 | 185 | def _format_type(self, value): 186 | if isinstance(value, (list, tuple)): 187 | assert len(value) > 0, value 188 | return self._format_type(value[0]) + 's' 189 | return str(type(value).__name__) 190 | -------------------------------------------------------------------------------- /dreamerv2/common/counter.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | 4 | @functools.total_ordering 5 | class Counter: 6 | 7 | def __init__(self, initial=0): 8 | self.value = initial 9 | 10 | def __int__(self): 11 | return int(self.value) 12 | 13 | def __eq__(self, other): 14 | return int(self) == other 15 | 16 | def __ne__(self, other): 17 | return int(self) != other 18 | 19 | def __lt__(self, other): 20 | return int(self) < other 21 | 22 | def __add__(self, other): 23 | return int(self) + other 24 | 25 | def increment(self, amount=1): 26 | self.value += amount 27 | -------------------------------------------------------------------------------- /dreamerv2/common/dists.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow_probability as tfp 3 | from tensorflow_probability import distributions as tfd 4 | 5 | 6 | # Patch to ignore seed to avoid synchronization across GPUs. 
7 | _orig_random_categorical = tf.random.categorical 8 | def random_categorical(*args, **kwargs): 9 | kwargs['seed'] = None 10 | return _orig_random_categorical(*args, **kwargs) 11 | tf.random.categorical = random_categorical 12 | 13 | # Patch to ignore seed to avoid synchronization across GPUs. 14 | _orig_random_normal = tf.random.normal 15 | def random_normal(*args, **kwargs): 16 | kwargs['seed'] = None 17 | return _orig_random_normal(*args, **kwargs) 18 | tf.random.normal = random_normal 19 | 20 | 21 | class SampleDist: 22 | 23 | def __init__(self, dist, samples=100): 24 | self._dist = dist 25 | self._samples = samples 26 | 27 | @property 28 | def name(self): 29 | return 'SampleDist' 30 | 31 | def __getattr__(self, name): 32 | return getattr(self._dist, name) 33 | 34 | def mean(self): 35 | samples = self._dist.sample(self._samples) 36 | return samples.mean(0) 37 | 38 | def mode(self): 39 | sample = self._dist.sample(self._samples) 40 | logprob = self._dist.log_prob(sample) 41 | return tf.gather(sample, tf.argmax(logprob))[0] 42 | 43 | def entropy(self): 44 | sample = self._dist.sample(self._samples) 45 | logprob = self.log_prob(sample) 46 | return -logprob.mean(0) 47 | 48 | 49 | class OneHotDist(tfd.OneHotCategorical): 50 | 51 | def __init__(self, logits=None, probs=None, dtype=None): 52 | self._sample_dtype = dtype or tf.float32 53 | super().__init__(logits=logits, probs=probs) 54 | 55 | def mode(self): 56 | return tf.cast(super().mode(), self._sample_dtype) 57 | 58 | def sample(self, sample_shape=(), seed=None): 59 | # Straight through biased gradient estimator. 60 | sample = tf.cast(super().sample(sample_shape, seed), self._sample_dtype) 61 | probs = self._pad(super().probs_parameter(), sample.shape) 62 | sample += tf.cast(probs - tf.stop_gradient(probs), self._sample_dtype) 63 | return sample 64 | 65 | def _pad(self, tensor, shape): 66 | tensor = super().probs_parameter() 67 | while len(tensor.shape) < len(shape): 68 | tensor = tensor[None] 69 | return tensor 70 | 71 | 72 | class TruncNormalDist(tfd.TruncatedNormal): 73 | 74 | def __init__(self, loc, scale, low, high, clip=1e-6, mult=1): 75 | super().__init__(loc, scale, low, high) 76 | self._clip = clip 77 | self._mult = mult 78 | 79 | def sample(self, *args, **kwargs): 80 | event = super().sample(*args, **kwargs) 81 | if self._clip: 82 | clipped = tf.clip_by_value( 83 | event, self.low + self._clip, self.high - self._clip) 84 | event = event - tf.stop_gradient(event) + tf.stop_gradient(clipped) 85 | if self._mult: 86 | event *= self._mult 87 | return event 88 | 89 | 90 | class TanhBijector(tfp.bijectors.Bijector): 91 | 92 | def __init__(self, validate_args=False, name='tanh'): 93 | super().__init__( 94 | forward_min_event_ndims=0, 95 | validate_args=validate_args, 96 | name=name) 97 | 98 | def _forward(self, x): 99 | return tf.nn.tanh(x) 100 | 101 | def _inverse(self, y): 102 | dtype = y.dtype 103 | y = tf.cast(y, tf.float32) 104 | y = tf.where( 105 | tf.less_equal(tf.abs(y), 1.), 106 | tf.clip_by_value(y, -0.99999997, 0.99999997), y) 107 | y = tf.atanh(y) 108 | y = tf.cast(y, dtype) 109 | return y 110 | 111 | def _forward_log_det_jacobian(self, x): 112 | log2 = tf.math.log(tf.constant(2.0, dtype=x.dtype)) 113 | return 2.0 * (log2 - x - tf.nn.softplus(-2.0 * x)) 114 | -------------------------------------------------------------------------------- /dreamerv2/common/driver.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Driver: 5 | 6 | def __init__(self, 
envs, **kwargs): 7 | self._envs = envs 8 | self._kwargs = kwargs 9 | self._on_steps = [] 10 | self._on_resets = [] 11 | self._on_episodes = [] 12 | self._act_spaces = [env.act_space for env in envs] 13 | self.reset() 14 | 15 | def on_step(self, callback): 16 | self._on_steps.append(callback) 17 | 18 | def on_reset(self, callback): 19 | self._on_resets.append(callback) 20 | 21 | def on_episode(self, callback): 22 | self._on_episodes.append(callback) 23 | 24 | def reset(self): 25 | self._obs = [None] * len(self._envs) 26 | self._eps = [None] * len(self._envs) 27 | self._state = None 28 | 29 | def __call__(self, policy, steps=0, episodes=0): 30 | step, episode = 0, 0 31 | while step < steps or episode < episodes: 32 | obs = { 33 | i: self._envs[i].reset() 34 | for i, ob in enumerate(self._obs) if ob is None or ob['is_last']} 35 | for i, ob in obs.items(): 36 | self._obs[i] = ob() if callable(ob) else ob 37 | act = {k: np.zeros(v.shape) for k, v in self._act_spaces[i].items()} 38 | tran = {k: self._convert(v) for k, v in {**ob, **act}.items()} 39 | [fn(tran, worker=i, **self._kwargs) for fn in self._on_resets] 40 | self._eps[i] = [tran] 41 | obs = {k: np.stack([o[k] for o in self._obs]) for k in self._obs[0]} 42 | actions, self._state = policy(obs, self._state, **self._kwargs) 43 | actions = [ 44 | {k: np.array(actions[k][i]) for k in actions} 45 | for i in range(len(self._envs))] 46 | assert len(actions) == len(self._envs) 47 | obs = [e.step(a) for e, a in zip(self._envs, actions)] 48 | obs = [ob() if callable(ob) else ob for ob in obs] 49 | for i, (act, ob) in enumerate(zip(actions, obs)): 50 | tran = {k: self._convert(v) for k, v in {**ob, **act}.items()} 51 | [fn(tran, worker=i, **self._kwargs) for fn in self._on_steps] 52 | self._eps[i].append(tran) 53 | step += 1 54 | if ob['is_last']: 55 | ep = self._eps[i] 56 | ep = {k: self._convert([t[k] for t in ep]) for k in ep[0]} 57 | [fn(ep, **self._kwargs) for fn in self._on_episodes] 58 | episode += 1 59 | self._obs = obs 60 | 61 | def _convert(self, value): 62 | value = np.array(value) 63 | if np.issubdtype(value.dtype, np.floating): 64 | return value.astype(np.float32) 65 | elif np.issubdtype(value.dtype, np.signedinteger): 66 | return value.astype(np.int32) 67 | elif np.issubdtype(value.dtype, np.uint8): 68 | return value.astype(np.uint8) 69 | return value 70 | -------------------------------------------------------------------------------- /dreamerv2/common/envs.py: -------------------------------------------------------------------------------- 1 | import atexit 2 | import os 3 | import sys 4 | import threading 5 | import traceback 6 | 7 | import cloudpickle 8 | import gym 9 | import numpy as np 10 | 11 | 12 | class GymWrapper: 13 | 14 | def __init__(self, env, obs_key='image', act_key='action'): 15 | self._env = env 16 | self._obs_is_dict = hasattr(self._env.observation_space, 'spaces') 17 | self._act_is_dict = hasattr(self._env.action_space, 'spaces') 18 | self._obs_key = obs_key 19 | self._act_key = act_key 20 | 21 | def __getattr__(self, name): 22 | if name.startswith('__'): 23 | raise AttributeError(name) 24 | try: 25 | return getattr(self._env, name) 26 | except AttributeError: 27 | raise ValueError(name) 28 | 29 | @property 30 | def obs_space(self): 31 | if self._obs_is_dict: 32 | spaces = self._env.observation_space.spaces.copy() 33 | else: 34 | spaces = {self._obs_key: self._env.observation_space} 35 | return { 36 | **spaces, 37 | 'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), 38 | 'is_first': gym.spaces.Box(0, 
1, (), dtype=np.bool), 39 | 'is_last': gym.spaces.Box(0, 1, (), dtype=np.bool), 40 | 'is_terminal': gym.spaces.Box(0, 1, (), dtype=np.bool), 41 | } 42 | 43 | @property 44 | def act_space(self): 45 | if self._act_is_dict: 46 | return self._env.action_space.spaces.copy() 47 | else: 48 | return {self._act_key: self._env.action_space} 49 | 50 | def step(self, action): 51 | if not self._act_is_dict: 52 | action = action[self._act_key] 53 | obs, reward, done, info = self._env.step(action) 54 | if not self._obs_is_dict: 55 | obs = {self._obs_key: obs} 56 | obs['reward'] = float(reward) 57 | obs['is_first'] = False 58 | obs['is_last'] = done 59 | obs['is_terminal'] = info.get('is_terminal', done) 60 | return obs 61 | 62 | def reset(self): 63 | obs = self._env.reset() 64 | if not self._obs_is_dict: 65 | obs = {self._obs_key: obs} 66 | obs['reward'] = 0.0 67 | obs['is_first'] = True 68 | obs['is_last'] = False 69 | obs['is_terminal'] = False 70 | return obs 71 | 72 | 73 | class DMC: 74 | 75 | def __init__(self, name, action_repeat=1, size=(64, 64), camera=None): 76 | os.environ['MUJOCO_GL'] = 'egl' 77 | domain, task = name.split('_', 1) 78 | if domain == 'cup': # Only domain with multiple words. 79 | domain = 'ball_in_cup' 80 | if domain == 'manip': 81 | from dm_control import manipulation 82 | self._env = manipulation.load(task + '_vision') 83 | elif domain == 'locom': 84 | from dm_control.locomotion.examples import basic_rodent_2020 85 | self._env = getattr(basic_rodent_2020, task)() 86 | else: 87 | from dm_control import suite 88 | self._env = suite.load(domain, task) 89 | self._action_repeat = action_repeat 90 | self._size = size 91 | if camera in (-1, None): 92 | camera = dict( 93 | quadruped_walk=2, quadruped_run=2, quadruped_escape=2, 94 | quadruped_fetch=2, locom_rodent_maze_forage=1, 95 | locom_rodent_two_touch=1, 96 | ).get(name, 0) 97 | self._camera = camera 98 | self._ignored_keys = [] 99 | for key, value in self._env.observation_spec().items(): 100 | if value.shape == (0,): 101 | print(f"Ignoring empty observation key '{key}'.") 102 | self._ignored_keys.append(key) 103 | 104 | @property 105 | def obs_space(self): 106 | spaces = { 107 | 'image': gym.spaces.Box(0, 255, self._size + (3,), dtype=np.uint8), 108 | 'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), 109 | 'is_first': gym.spaces.Box(0, 1, (), dtype=np.bool), 110 | 'is_last': gym.spaces.Box(0, 1, (), dtype=np.bool), 111 | 'is_terminal': gym.spaces.Box(0, 1, (), dtype=np.bool), 112 | } 113 | for key, value in self._env.observation_spec().items(): 114 | if key in self._ignored_keys: 115 | continue 116 | if value.dtype == np.float64: 117 | spaces[key] = gym.spaces.Box(-np.inf, np.inf, value.shape, np.float32) 118 | elif value.dtype == np.uint8: 119 | spaces[key] = gym.spaces.Box(0, 255, value.shape, np.uint8) 120 | else: 121 | raise NotImplementedError(value.dtype) 122 | return spaces 123 | 124 | @property 125 | def act_space(self): 126 | spec = self._env.action_spec() 127 | action = gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32) 128 | return {'action': action} 129 | 130 | def step(self, action): 131 | assert np.isfinite(action['action']).all(), action['action'] 132 | reward = 0.0 133 | for _ in range(self._action_repeat): 134 | time_step = self._env.step(action['action']) 135 | reward += time_step.reward or 0.0 136 | if time_step.last(): 137 | break 138 | assert time_step.discount in (0, 1) 139 | obs = { 140 | 'reward': reward, 141 | 'is_first': False, 142 | 'is_last': time_step.last(), 143 | 
'is_terminal': time_step.discount == 0, 144 | 'image': self._env.physics.render(*self._size, camera_id=self._camera), 145 | } 146 | obs.update({ 147 | k: v for k, v in dict(time_step.observation).items() 148 | if k not in self._ignored_keys}) 149 | return obs 150 | 151 | def reset(self): 152 | time_step = self._env.reset() 153 | obs = { 154 | 'reward': 0.0, 155 | 'is_first': True, 156 | 'is_last': False, 157 | 'is_terminal': False, 158 | 'image': self._env.physics.render(*self._size, camera_id=self._camera), 159 | } 160 | obs.update({ 161 | k: v for k, v in dict(time_step.observation).items() 162 | if k not in self._ignored_keys}) 163 | return obs 164 | 165 | 166 | class Atari: 167 | 168 | LOCK = threading.Lock() 169 | 170 | def __init__( 171 | self, name, action_repeat=4, size=(84, 84), grayscale=True, noops=30, 172 | life_done=False, sticky=True, all_actions=False): 173 | assert size[0] == size[1] 174 | import gym.wrappers 175 | import gym.envs.atari 176 | if name == 'james_bond': 177 | name = 'jamesbond' 178 | with self.LOCK: 179 | env = gym.envs.atari.AtariEnv( 180 | game=name, obs_type='image', frameskip=1, 181 | repeat_action_probability=0.25 if sticky else 0.0, 182 | full_action_space=all_actions) 183 | # Avoid unnecessary rendering in inner env. 184 | env._get_obs = lambda: None 185 | # Tell wrapper that the inner env has no action repeat. 186 | env.spec = gym.envs.registration.EnvSpec('NoFrameskip-v0') 187 | self._env = gym.wrappers.AtariPreprocessing( 188 | env, noops, action_repeat, size[0], life_done, grayscale) 189 | self._size = size 190 | self._grayscale = grayscale 191 | 192 | @property 193 | def obs_space(self): 194 | shape = self._size + (1 if self._grayscale else 3,) 195 | return { 196 | 'image': gym.spaces.Box(0, 255, shape, np.uint8), 197 | 'ram': gym.spaces.Box(0, 255, (128,), np.uint8), 198 | 'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), 199 | 'is_first': gym.spaces.Box(0, 1, (), dtype=np.bool), 200 | 'is_last': gym.spaces.Box(0, 1, (), dtype=np.bool), 201 | 'is_terminal': gym.spaces.Box(0, 1, (), dtype=np.bool), 202 | } 203 | 204 | @property 205 | def act_space(self): 206 | return {'action': self._env.action_space} 207 | 208 | def step(self, action): 209 | image, reward, done, info = self._env.step(action['action']) 210 | if self._grayscale: 211 | image = image[..., None] 212 | return { 213 | 'image': image, 214 | 'ram': self._env.env._get_ram(), 215 | 'reward': reward, 216 | 'is_first': False, 217 | 'is_last': done, 218 | 'is_terminal': done, 219 | } 220 | 221 | def reset(self): 222 | with self.LOCK: 223 | image = self._env.reset() 224 | if self._grayscale: 225 | image = image[..., None] 226 | return { 227 | 'image': image, 228 | 'ram': self._env.env._get_ram(), 229 | 'reward': 0.0, 230 | 'is_first': True, 231 | 'is_last': False, 232 | 'is_terminal': False, 233 | } 234 | 235 | def close(self): 236 | return self._env.close() 237 | 238 | 239 | class Crafter: 240 | 241 | def __init__(self, outdir=None, reward=True, seed=None): 242 | import crafter 243 | self._env = crafter.Env(reward=reward, seed=seed) 244 | self._env = crafter.Recorder( 245 | self._env, outdir, 246 | save_stats=True, 247 | save_video=False, 248 | save_episode=False, 249 | ) 250 | self._achievements = crafter.constants.achievements.copy() 251 | 252 | @property 253 | def obs_space(self): 254 | spaces = { 255 | 'image': self._env.observation_space, 256 | 'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), 257 | 'is_first': gym.spaces.Box(0, 1, (), dtype=np.bool), 258 | 
'is_last': gym.spaces.Box(0, 1, (), dtype=np.bool), 259 | 'is_terminal': gym.spaces.Box(0, 1, (), dtype=np.bool), 260 | 'log_reward': gym.spaces.Box(-np.inf, np.inf, (), np.float32), 261 | } 262 | spaces.update({ 263 | f'log_achievement_{k}': gym.spaces.Box(0, 2 ** 31 - 1, (), np.int32) 264 | for k in self._achievements}) 265 | return spaces 266 | 267 | @property 268 | def act_space(self): 269 | return {'action': self._env.action_space} 270 | 271 | def step(self, action): 272 | image, reward, done, info = self._env.step(action['action']) 273 | obs = { 274 | 'image': image, 275 | 'reward': reward, 276 | 'is_first': False, 277 | 'is_last': done, 278 | 'is_terminal': info['discount'] == 0, 279 | 'log_reward': info['reward'], 280 | } 281 | obs.update({ 282 | f'log_achievement_{k}': v 283 | for k, v in info['achievements'].items()}) 284 | return obs 285 | 286 | def reset(self): 287 | obs = { 288 | 'image': self._env.reset(), 289 | 'reward': 0.0, 290 | 'is_first': True, 291 | 'is_last': False, 292 | 'is_terminal': False, 293 | 'log_reward': 0.0, 294 | } 295 | obs.update({ 296 | f'log_achievement_{k}': 0 297 | for k in self._achievements}) 298 | return obs 299 | 300 | 301 | class Dummy: 302 | 303 | def __init__(self): 304 | pass 305 | 306 | @property 307 | def obs_space(self): 308 | return { 309 | 'image': gym.spaces.Box(0, 255, (64, 64, 3), dtype=np.uint8), 310 | 'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), 311 | 'is_first': gym.spaces.Box(0, 1, (), dtype=np.bool), 312 | 'is_last': gym.spaces.Box(0, 1, (), dtype=np.bool), 313 | 'is_terminal': gym.spaces.Box(0, 1, (), dtype=np.bool), 314 | } 315 | 316 | @property 317 | def act_space(self): 318 | return {'action': gym.spaces.Box(-1, 1, (6,), dtype=np.float32)} 319 | 320 | def step(self, action): 321 | return { 322 | 'image': np.zeros((64, 64, 3)), 323 | 'reward': 0.0, 324 | 'is_first': False, 325 | 'is_last': False, 326 | 'is_terminal': False, 327 | } 328 | 329 | def reset(self): 330 | return { 331 | 'image': np.zeros((64, 64, 3)), 332 | 'reward': 0.0, 333 | 'is_first': True, 334 | 'is_last': False, 335 | 'is_terminal': False, 336 | } 337 | 338 | 339 | class TimeLimit: 340 | 341 | def __init__(self, env, duration): 342 | self._env = env 343 | self._duration = duration 344 | self._step = None 345 | 346 | def __getattr__(self, name): 347 | if name.startswith('__'): 348 | raise AttributeError(name) 349 | try: 350 | return getattr(self._env, name) 351 | except AttributeError: 352 | raise ValueError(name) 353 | 354 | def step(self, action): 355 | assert self._step is not None, 'Must reset environment.' 
356 | obs = self._env.step(action) 357 | self._step += 1 358 | if self._duration and self._step >= self._duration: 359 | obs['is_last'] = True 360 | self._step = None 361 | return obs 362 | 363 | def reset(self): 364 | self._step = 0 365 | return self._env.reset() 366 | 367 | 368 | class NormalizeAction: 369 | 370 | def __init__(self, env, key='action'): 371 | self._env = env 372 | self._key = key 373 | space = env.act_space[key] 374 | self._mask = np.isfinite(space.low) & np.isfinite(space.high) 375 | self._low = np.where(self._mask, space.low, -1) 376 | self._high = np.where(self._mask, space.high, 1) 377 | 378 | def __getattr__(self, name): 379 | if name.startswith('__'): 380 | raise AttributeError(name) 381 | try: 382 | return getattr(self._env, name) 383 | except AttributeError: 384 | raise ValueError(name) 385 | 386 | @property 387 | def act_space(self): 388 | low = np.where(self._mask, -np.ones_like(self._low), self._low) 389 | high = np.where(self._mask, np.ones_like(self._low), self._high) 390 | space = gym.spaces.Box(low, high, dtype=np.float32) 391 | return {**self._env.act_space, self._key: space} 392 | 393 | def step(self, action): 394 | orig = (action[self._key] + 1) / 2 * (self._high - self._low) + self._low 395 | orig = np.where(self._mask, orig, action[self._key]) 396 | return self._env.step({**action, self._key: orig}) 397 | 398 | 399 | class OneHotAction: 400 | 401 | def __init__(self, env, key='action'): 402 | assert hasattr(env.act_space[key], 'n') 403 | self._env = env 404 | self._key = key 405 | self._random = np.random.RandomState() 406 | 407 | def __getattr__(self, name): 408 | if name.startswith('__'): 409 | raise AttributeError(name) 410 | try: 411 | return getattr(self._env, name) 412 | except AttributeError: 413 | raise ValueError(name) 414 | 415 | @property 416 | def act_space(self): 417 | shape = (self._env.act_space[self._key].n,) 418 | space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32) 419 | space.sample = self._sample_action 420 | space.n = shape[0] 421 | return {**self._env.act_space, self._key: space} 422 | 423 | def step(self, action): 424 | index = np.argmax(action[self._key]).astype(int) 425 | reference = np.zeros_like(action[self._key]) 426 | reference[index] = 1 427 | if not np.allclose(reference, action[self._key]): 428 | raise ValueError(f'Invalid one-hot action:\n{action}') 429 | return self._env.step({**action, self._key: index}) 430 | 431 | def reset(self): 432 | return self._env.reset() 433 | 434 | def _sample_action(self): 435 | actions = self._env.act_space.n 436 | index = self._random.randint(0, actions) 437 | reference = np.zeros(actions, dtype=np.float32) 438 | reference[index] = 1.0 439 | return reference 440 | 441 | 442 | class ResizeImage: 443 | 444 | def __init__(self, env, size=(64, 64)): 445 | self._env = env 446 | self._size = size 447 | self._keys = [ 448 | k for k, v in env.obs_space.items() 449 | if len(v.shape) > 1 and v.shape[:2] != size] 450 | print(f'Resizing keys {",".join(self._keys)} to {self._size}.') 451 | if self._keys: 452 | from PIL import Image 453 | self._Image = Image 454 | 455 | def __getattr__(self, name): 456 | if name.startswith('__'): 457 | raise AttributeError(name) 458 | try: 459 | return getattr(self._env, name) 460 | except AttributeError: 461 | raise ValueError(name) 462 | 463 | @property 464 | def obs_space(self): 465 | spaces = self._env.obs_space 466 | for key in self._keys: 467 | shape = self._size + spaces[key].shape[2:] 468 | spaces[key] = gym.spaces.Box(0, 255, shape, np.uint8) 
469 | return spaces 470 | 471 | def step(self, action): 472 | obs = self._env.step(action) 473 | for key in self._keys: 474 | obs[key] = self._resize(obs[key]) 475 | return obs 476 | 477 | def reset(self): 478 | obs = self._env.reset() 479 | for key in self._keys: 480 | obs[key] = self._resize(obs[key]) 481 | return obs 482 | 483 | def _resize(self, image): 484 | image = self._Image.fromarray(image) 485 | image = image.resize(self._size, self._Image.NEAREST) 486 | image = np.array(image) 487 | return image 488 | 489 | 490 | class RenderImage: 491 | 492 | def __init__(self, env, key='image'): 493 | self._env = env 494 | self._key = key 495 | self._shape = self._env.render().shape 496 | 497 | def __getattr__(self, name): 498 | if name.startswith('__'): 499 | raise AttributeError(name) 500 | try: 501 | return getattr(self._env, name) 502 | except AttributeError: 503 | raise ValueError(name) 504 | 505 | @property 506 | def obs_space(self): 507 | spaces = self._env.obs_space 508 | spaces[self._key] = gym.spaces.Box(0, 255, self._shape, np.uint8) 509 | return spaces 510 | 511 | def step(self, action): 512 | obs = self._env.step(action) 513 | obs[self._key] = self._env.render('rgb_array') 514 | return obs 515 | 516 | def reset(self): 517 | obs = self._env.reset() 518 | obs[self._key] = self._env.render('rgb_array') 519 | return obs 520 | 521 | 522 | class Async: 523 | 524 | # Message types for communication via the pipe. 525 | _ACCESS = 1 526 | _CALL = 2 527 | _RESULT = 3 528 | _CLOSE = 4 529 | _EXCEPTION = 5 530 | 531 | def __init__(self, constructor, strategy='thread'): 532 | self._pickled_ctor = cloudpickle.dumps(constructor) 533 | if strategy == 'process': 534 | import multiprocessing as mp 535 | context = mp.get_context('spawn') 536 | elif strategy == 'thread': 537 | import multiprocessing.dummy as context 538 | else: 539 | raise NotImplementedError(strategy) 540 | self._strategy = strategy 541 | self._conn, conn = context.Pipe() 542 | self._process = context.Process(target=self._worker, args=(conn,)) 543 | atexit.register(self.close) 544 | self._process.start() 545 | self._receive() # Ready. 546 | self._obs_space = None 547 | self._act_space = None 548 | 549 | def access(self, name): 550 | self._conn.send((self._ACCESS, name)) 551 | return self._receive 552 | 553 | def call(self, name, *args, **kwargs): 554 | payload = name, args, kwargs 555 | self._conn.send((self._CALL, payload)) 556 | return self._receive 557 | 558 | def close(self): 559 | try: 560 | self._conn.send((self._CLOSE, None)) 561 | self._conn.close() 562 | except IOError: 563 | pass # The connection was already closed. 564 | self._process.join(5) 565 | 566 | @property 567 | def obs_space(self): 568 | if not self._obs_space: 569 | self._obs_space = self.access('obs_space')() 570 | return self._obs_space 571 | 572 | @property 573 | def act_space(self): 574 | if not self._act_space: 575 | self._act_space = self.access('act_space')() 576 | return self._act_space 577 | 578 | def step(self, action, blocking=False): 579 | promise = self.call('step', action) 580 | if blocking: 581 | return promise() 582 | else: 583 | return promise 584 | 585 | def reset(self, blocking=False): 586 | promise = self.call('reset') 587 | if blocking: 588 | return promise() 589 | else: 590 | return promise 591 | 592 | def _receive(self): 593 | try: 594 | message, payload = self._conn.recv() 595 | except (OSError, EOFError): 596 | raise RuntimeError('Lost connection to environment worker.') 597 | # Re-raise exceptions in the main process. 
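Async.step and Async.reset are non-blocking by default: the request is sent to the worker immediately and the returned promise blocks only when called. A minimal sketch using the Dummy environment from this file:

    env = Async(Dummy, strategy='thread')
    promise = env.step({'action': np.zeros(6, np.float32)})  # request sent, returns immediately
    obs = promise()                                          # blocks until the worker replies
    env.close()

obs_space and act_space are fetched from the worker lazily and cached on first access.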
598 | if message == self._EXCEPTION: 599 | stacktrace = payload 600 | raise Exception(stacktrace) 601 | if message == self._RESULT: 602 | return payload 603 | raise KeyError('Received message of unexpected type {}'.format(message)) 604 | 605 | def _worker(self, conn): 606 | try: 607 | ctor = cloudpickle.loads(self._pickled_ctor) 608 | env = ctor() 609 | conn.send((self._RESULT, None)) # Ready. 610 | while True: 611 | try: 612 | # Only block for short times to have keyboard exceptions be raised. 613 | if not conn.poll(0.1): 614 | continue 615 | message, payload = conn.recv() 616 | except (EOFError, KeyboardInterrupt): 617 | break 618 | if message == self._ACCESS: 619 | name = payload 620 | result = getattr(env, name) 621 | conn.send((self._RESULT, result)) 622 | continue 623 | if message == self._CALL: 624 | name, args, kwargs = payload 625 | result = getattr(env, name)(*args, **kwargs) 626 | conn.send((self._RESULT, result)) 627 | continue 628 | if message == self._CLOSE: 629 | break 630 | raise KeyError('Received message of unknown type {}'.format(message)) 631 | except Exception: 632 | stacktrace = ''.join(traceback.format_exception(*sys.exc_info())) 633 | print('Error in environment process: {}'.format(stacktrace)) 634 | conn.send((self._EXCEPTION, stacktrace)) 635 | finally: 636 | try: 637 | conn.close() 638 | except IOError: 639 | pass # The connection was already closed. 640 | -------------------------------------------------------------------------------- /dreamerv2/common/flags.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | 4 | 5 | class Flags: 6 | 7 | def __init__(self, *args, **kwargs): 8 | from .config import Config 9 | self._config = Config(*args, **kwargs) 10 | 11 | def parse(self, argv=None, known_only=False, help_exists=None): 12 | if help_exists is None: 13 | help_exists = not known_only 14 | if argv is None: 15 | argv = sys.argv[1:] 16 | if '--help' in argv: 17 | print('\nHelp:') 18 | lines = str(self._config).split('\n')[2:] 19 | print('\n'.join('--' + re.sub(r'[:,\[\]]', '', x) for x in lines)) 20 | help_exists and sys.exit() 21 | parsed = {} 22 | remaining = [] 23 | key = None 24 | vals = None 25 | for arg in argv: 26 | if arg.startswith('--'): 27 | if key: 28 | self._submit_entry(key, vals, parsed, remaining) 29 | if '=' in arg: 30 | key, val = arg.split('=', 1) 31 | vals = [val] 32 | else: 33 | key, vals = arg, [] 34 | else: 35 | if key: 36 | vals.append(arg) 37 | else: 38 | remaining.append(arg) 39 | self._submit_entry(key, vals, parsed, remaining) 40 | parsed = self._config.update(parsed) 41 | if known_only: 42 | return parsed, remaining 43 | else: 44 | for flag in remaining: 45 | if flag.startswith('--'): 46 | raise ValueError(f"Flag '{flag}' did not match any config keys.") 47 | assert not remaining, remaining 48 | return parsed 49 | 50 | def _submit_entry(self, key, vals, parsed, remaining): 51 | if not key and not vals: 52 | return 53 | if not key: 54 | vals = ', '.join(f"'{x}'" for x in vals) 55 | raise ValueError(f"Values {vals} were not preceeded by any flag.") 56 | name = key[len('--'):] 57 | if '=' in name: 58 | remaining.extend([key] + vals) 59 | return 60 | if self._config.IS_PATTERN.match(name): 61 | pattern = re.compile(name) 62 | keys = {k for k in self._config.flat if pattern.match(k)} 63 | elif name in self._config: 64 | keys = [name] 65 | else: 66 | keys = [] 67 | if not keys: 68 | remaining.extend([key] + vals) 69 | return 70 | if not vals: 71 | raise ValueError(f"Flag '{key}' 
was not followed by any values.") 72 | for key in keys: 73 | parsed[key] = self._parse_flag_value(self._config[key], vals, key) 74 | 75 | def _parse_flag_value(self, default, value, key): 76 | value = value if isinstance(value, (tuple, list)) else (value,) 77 | if isinstance(default, (tuple, list)): 78 | if len(value) == 1 and ',' in value[0]: 79 | value = value[0].split(',') 80 | return tuple(self._parse_flag_value(default[0], [x], key) for x in value) 81 | assert len(value) == 1, value 82 | value = str(value[0]) 83 | if default is None: 84 | return value 85 | if isinstance(default, bool): 86 | try: 87 | return bool(['False', 'True'].index(value)) 88 | except ValueError: 89 | message = f"Expected bool but got '{value}' for key '{key}'." 90 | raise TypeError(message) 91 | if isinstance(default, int): 92 | value = float(value) # Allow scientific notation for integers. 93 | if float(int(value)) != value: 94 | message = f"Expected int but got float '{value}' for key '{key}'." 95 | raise TypeError(message) 96 | return int(value) 97 | return type(default)(value) 98 | -------------------------------------------------------------------------------- /dreamerv2/common/logger.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pathlib 4 | import time 5 | 6 | import numpy as np 7 | 8 | 9 | class Logger: 10 | 11 | def __init__(self, step, outputs, multiplier=1): 12 | self._step = step 13 | self._outputs = outputs 14 | self._multiplier = multiplier 15 | self._last_step = None 16 | self._last_time = None 17 | self._metrics = [] 18 | 19 | def add(self, mapping, prefix=None): 20 | step = int(self._step) * self._multiplier 21 | for name, value in dict(mapping).items(): 22 | name = f'{prefix}_{name}' if prefix else name 23 | value = np.array(value) 24 | if len(value.shape) not in (0, 2, 3, 4): 25 | raise ValueError( 26 | f"Shape {value.shape} for name '{name}' cannot be " 27 | "interpreted as scalar, image, or video.") 28 | self._metrics.append((step, name, value)) 29 | 30 | def scalar(self, name, value): 31 | self.add({name: value}) 32 | 33 | def image(self, name, value): 34 | self.add({name: value}) 35 | 36 | def video(self, name, value): 37 | self.add({name: value}) 38 | 39 | def write(self, fps=False): 40 | fps and self.scalar('fps', self._compute_fps()) 41 | if not self._metrics: 42 | return 43 | for output in self._outputs: 44 | output(self._metrics) 45 | self._metrics.clear() 46 | 47 | def _compute_fps(self): 48 | step = int(self._step) * self._multiplier 49 | if self._last_step is None: 50 | self._last_time = time.time() 51 | self._last_step = step 52 | return 0 53 | steps = step - self._last_step 54 | duration = time.time() - self._last_time 55 | self._last_time += duration 56 | self._last_step = step 57 | return steps / duration 58 | 59 | 60 | class TerminalOutput: 61 | 62 | def __call__(self, summaries): 63 | step = max(s for s, _, _, in summaries) 64 | scalars = {k: float(v) for _, k, v in summaries if len(v.shape) == 0} 65 | formatted = {k: self._format_value(v) for k, v in scalars.items()} 66 | print(f'[{step}]', ' / '.join(f'{k} {v}' for k, v in formatted.items())) 67 | 68 | def _format_value(self, value): 69 | if value == 0: 70 | return '0' 71 | elif 0.01 < abs(value) < 10000: 72 | value = f'{value:.2f}' 73 | value = value.rstrip('0') 74 | value = value.rstrip('0') 75 | value = value.rstrip('.') 76 | return value 77 | else: 78 | value = f'{value:.1e}' 79 | value = value.replace('.0e', 'e') 80 | value = 
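Flags (defined above) turns command-line arguments into updates on a Config: each value is parsed to the type of the corresponding default, comma-separated values fill list defaults element-wise, and regex keys can update several entries at once. A sketch, assuming common.Config accepts keyword defaults and exposes them as attributes:

    flags = Flags(logdir='/tmp/run', lr=1e-4, layers=[400, 400])
    config = flags.parse(['--lr', '3e-4', '--layers', '200,200'])
    # config.lr == 3e-4 (float, like its default); config.layers == (200, 200) (ints, like theirs)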
value.replace('+0', '') 81 | value = value.replace('+', '') 82 | value = value.replace('-0', '-') 83 | return value 84 | 85 | 86 | class JSONLOutput: 87 | 88 | def __init__(self, logdir): 89 | self._logdir = pathlib.Path(logdir).expanduser() 90 | 91 | def __call__(self, summaries): 92 | scalars = {k: float(v) for _, k, v in summaries if len(v.shape) == 0} 93 | step = max(s for s, _, _, in summaries) 94 | with (self._logdir / 'metrics.jsonl').open('a') as f: 95 | f.write(json.dumps({'step': step, **scalars}) + '\n') 96 | 97 | 98 | class TensorBoardOutput: 99 | 100 | def __init__(self, logdir, fps=20): 101 | # The TensorFlow summary writer supports file protocols like gs://. We use 102 | # os.path over pathlib here to preserve those prefixes. 103 | self._logdir = os.path.expanduser(logdir) 104 | self._writer = None 105 | self._fps = fps 106 | 107 | def __call__(self, summaries): 108 | import tensorflow as tf 109 | self._ensure_writer() 110 | self._writer.set_as_default() 111 | for step, name, value in summaries: 112 | if len(value.shape) == 0: 113 | tf.summary.scalar('scalars/' + name, value, step) 114 | elif len(value.shape) == 2: 115 | tf.summary.image(name, value, step) 116 | elif len(value.shape) == 3: 117 | tf.summary.image(name, value, step) 118 | elif len(value.shape) == 4: 119 | self._video_summary(name, value, step) 120 | self._writer.flush() 121 | 122 | def _ensure_writer(self): 123 | if not self._writer: 124 | import tensorflow as tf 125 | self._writer = tf.summary.create_file_writer( 126 | self._logdir, max_queue=1000) 127 | 128 | def _video_summary(self, name, video, step): 129 | import tensorflow as tf 130 | import tensorflow.compat.v1 as tf1 131 | name = name if isinstance(name, str) else name.decode('utf-8') 132 | if np.issubdtype(video.dtype, np.floating): 133 | video = np.clip(255 * video, 0, 255).astype(np.uint8) 134 | try: 135 | T, H, W, C = video.shape 136 | summary = tf1.Summary() 137 | image = tf1.Summary.Image(height=H, width=W, colorspace=C) 138 | image.encoded_image_string = encode_gif(video, self._fps) 139 | summary.value.add(tag=name, image=image) 140 | tf.summary.experimental.write_raw_pb(summary.SerializeToString(), step) 141 | except (IOError, OSError) as e: 142 | print('GIF summaries require ffmpeg in $PATH.', e) 143 | tf.summary.image(name, video, step) 144 | 145 | 146 | def encode_gif(frames, fps): 147 | from subprocess import Popen, PIPE 148 | h, w, c = frames[0].shape 149 | pxfmt = {1: 'gray', 3: 'rgb24'}[c] 150 | cmd = ' '.join([ 151 | 'ffmpeg -y -f rawvideo -vcodec rawvideo', 152 | f'-r {fps:.02f} -s {w}x{h} -pix_fmt {pxfmt} -i - -filter_complex', 153 | '[0:v]split[x][z];[z]palettegen[y];[x]fifo[x];[x][y]paletteuse', 154 | f'-r {fps:.02f} -f gif -']) 155 | proc = Popen(cmd.split(' '), stdin=PIPE, stdout=PIPE, stderr=PIPE) 156 | for image in frames: 157 | proc.stdin.write(image.tobytes()) 158 | out, err = proc.communicate() 159 | if proc.returncode: 160 | raise IOError('\n'.join([' '.join(cmd), err.decode('utf8')])) 161 | del proc 162 | return out 163 | -------------------------------------------------------------------------------- /dreamerv2/common/nets.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | from tensorflow.keras import layers as tfkl 6 | from tensorflow_probability import distributions as tfd 7 | from tensorflow.keras.mixed_precision import experimental as prec 8 | 9 | import common 10 | 11 | 12 | class EnsembleRSSM(common.Module): 13 
| 14 | def __init__( 15 | self, ensemble=5, stoch=30, deter=200, hidden=200, discrete=False, 16 | act='elu', norm='none', std_act='softplus', min_std=0.1): 17 | super().__init__() 18 | self._ensemble = ensemble 19 | self._stoch = stoch 20 | self._deter = deter 21 | self._hidden = hidden 22 | self._discrete = discrete 23 | self._act = get_act(act) 24 | self._norm = norm 25 | self._std_act = std_act 26 | self._min_std = min_std 27 | self._cell = GRUCell(self._deter, norm=True) 28 | self._cast = lambda x: tf.cast(x, prec.global_policy().compute_dtype) 29 | 30 | def initial(self, batch_size): 31 | dtype = prec.global_policy().compute_dtype 32 | if self._discrete: 33 | state = dict( 34 | logit=tf.zeros([batch_size, self._stoch, self._discrete], dtype), 35 | stoch=tf.zeros([batch_size, self._stoch, self._discrete], dtype), 36 | deter=self._cell.get_initial_state(None, batch_size, dtype)) 37 | else: 38 | state = dict( 39 | mean=tf.zeros([batch_size, self._stoch], dtype), 40 | std=tf.zeros([batch_size, self._stoch], dtype), 41 | stoch=tf.zeros([batch_size, self._stoch], dtype), 42 | deter=self._cell.get_initial_state(None, batch_size, dtype)) 43 | return state 44 | 45 | @tf.function 46 | def observe(self, embed, action, is_first, state=None): 47 | swap = lambda x: tf.transpose(x, [1, 0] + list(range(2, len(x.shape)))) 48 | if state is None: 49 | state = self.initial(tf.shape(action)[0]) 50 | post, prior = common.static_scan( 51 | lambda prev, inputs: self.obs_step(prev[0], *inputs), 52 | (swap(action), swap(embed), swap(is_first)), (state, state)) 53 | post = {k: swap(v) for k, v in post.items()} 54 | prior = {k: swap(v) for k, v in prior.items()} 55 | return post, prior 56 | 57 | @tf.function 58 | def imagine(self, action, state=None): 59 | swap = lambda x: tf.transpose(x, [1, 0] + list(range(2, len(x.shape)))) 60 | if state is None: 61 | state = self.initial(tf.shape(action)[0]) 62 | assert isinstance(state, dict), state 63 | action = swap(action) 64 | prior = common.static_scan(self.img_step, action, state) 65 | prior = {k: swap(v) for k, v in prior.items()} 66 | return prior 67 | 68 | def get_feat(self, state): 69 | stoch = self._cast(state['stoch']) 70 | if self._discrete: 71 | shape = stoch.shape[:-2] + [self._stoch * self._discrete] 72 | stoch = tf.reshape(stoch, shape) 73 | return tf.concat([stoch, state['deter']], -1) 74 | 75 | def get_dist(self, state, ensemble=False): 76 | if ensemble: 77 | state = self._suff_stats_ensemble(state['deter']) 78 | if self._discrete: 79 | logit = state['logit'] 80 | logit = tf.cast(logit, tf.float32) 81 | dist = tfd.Independent(common.OneHotDist(logit), 1) 82 | else: 83 | mean, std = state['mean'], state['std'] 84 | mean = tf.cast(mean, tf.float32) 85 | std = tf.cast(std, tf.float32) 86 | dist = tfd.MultivariateNormalDiag(mean, std) 87 | return dist 88 | 89 | @tf.function 90 | def obs_step(self, prev_state, prev_action, embed, is_first, sample=True): 91 | # if is_first.any(): 92 | prev_state, prev_action = tf.nest.map_structure( 93 | lambda x: tf.einsum( 94 | 'b,b...->b...', 1.0 - is_first.astype(x.dtype), x), 95 | (prev_state, prev_action)) 96 | prior = self.img_step(prev_state, prev_action, sample) 97 | x = tf.concat([prior['deter'], embed], -1) 98 | x = self.get('obs_out', tfkl.Dense, self._hidden)(x) 99 | x = self.get('obs_out_norm', NormLayer, self._norm)(x) 100 | x = self._act(x) 101 | stats = self._suff_stats_layer('obs_dist', x) 102 | dist = self.get_dist(stats) 103 | stoch = dist.sample() if sample else dist.mode() 104 | post = {'stoch': stoch, 
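A shape sketch for the recurrent state with illustrative sizes (30 categorical latents with 32 classes each and a 200-unit deterministic GRU state): get_feat flattens the categorical sample and concatenates it with the deterministic state.

    rssm = EnsembleRSSM(ensemble=5, stoch=30, deter=200, hidden=200, discrete=32)
    state = rssm.initial(batch_size=16)
    # state['logit'] and state['stoch'] have shape [16, 30, 32]; state['deter'] is [16, 200]
    feat = rssm.get_feat(state)   # shape [16, 30 * 32 + 200] = [16, 1160]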
'deter': prior['deter'], **stats} 105 | return post, prior 106 | 107 | @tf.function 108 | def img_step(self, prev_state, prev_action, sample=True): 109 | prev_stoch = self._cast(prev_state['stoch']) 110 | prev_action = self._cast(prev_action) 111 | if self._discrete: 112 | shape = prev_stoch.shape[:-2] + [self._stoch * self._discrete] 113 | prev_stoch = tf.reshape(prev_stoch, shape) 114 | x = tf.concat([prev_stoch, prev_action], -1) 115 | x = self.get('img_in', tfkl.Dense, self._hidden)(x) 116 | x = self.get('img_in_norm', NormLayer, self._norm)(x) 117 | x = self._act(x) 118 | deter = prev_state['deter'] 119 | x, deter = self._cell(x, [deter]) 120 | deter = deter[0] # Keras wraps the state in a list. 121 | stats = self._suff_stats_ensemble(x) 122 | index = tf.random.uniform((), 0, self._ensemble, tf.int32) 123 | stats = {k: v[index] for k, v in stats.items()} 124 | dist = self.get_dist(stats) 125 | stoch = dist.sample() if sample else dist.mode() 126 | prior = {'stoch': stoch, 'deter': deter, **stats} 127 | return prior 128 | 129 | def _suff_stats_ensemble(self, inp): 130 | bs = list(inp.shape[:-1]) 131 | inp = inp.reshape([-1, inp.shape[-1]]) 132 | stats = [] 133 | for k in range(self._ensemble): 134 | x = self.get(f'img_out_{k}', tfkl.Dense, self._hidden)(inp) 135 | x = self.get(f'img_out_norm_{k}', NormLayer, self._norm)(x) 136 | x = self._act(x) 137 | stats.append(self._suff_stats_layer(f'img_dist_{k}', x)) 138 | stats = { 139 | k: tf.stack([x[k] for x in stats], 0) 140 | for k, v in stats[0].items()} 141 | stats = { 142 | k: v.reshape([v.shape[0]] + bs + list(v.shape[2:])) 143 | for k, v in stats.items()} 144 | return stats 145 | 146 | def _suff_stats_layer(self, name, x): 147 | if self._discrete: 148 | x = self.get(name, tfkl.Dense, self._stoch * self._discrete, None)(x) 149 | logit = tf.reshape(x, x.shape[:-1] + [self._stoch, self._discrete]) 150 | return {'logit': logit} 151 | else: 152 | x = self.get(name, tfkl.Dense, 2 * self._stoch, None)(x) 153 | mean, std = tf.split(x, 2, -1) 154 | std = { 155 | 'softplus': lambda: tf.nn.softplus(std), 156 | 'sigmoid': lambda: tf.nn.sigmoid(std), 157 | 'sigmoid2': lambda: 2 * tf.nn.sigmoid(std / 2), 158 | }[self._std_act]() 159 | std = std + self._min_std 160 | return {'mean': mean, 'std': std} 161 | 162 | def kl_loss(self, post, prior, forward, balance, free, free_avg): 163 | kld = tfd.kl_divergence 164 | sg = lambda x: tf.nest.map_structure(tf.stop_gradient, x) 165 | lhs, rhs = (prior, post) if forward else (post, prior) 166 | mix = balance if forward else (1 - balance) 167 | if balance == 0.5: 168 | value = kld(self.get_dist(lhs), self.get_dist(rhs)) 169 | loss = tf.maximum(value, free).mean() 170 | else: 171 | value_lhs = value = kld(self.get_dist(lhs), self.get_dist(sg(rhs))) 172 | value_rhs = kld(self.get_dist(sg(lhs)), self.get_dist(rhs)) 173 | if free_avg: 174 | loss_lhs = tf.maximum(value_lhs.mean(), free) 175 | loss_rhs = tf.maximum(value_rhs.mean(), free) 176 | else: 177 | loss_lhs = tf.maximum(value_lhs, free).mean() 178 | loss_rhs = tf.maximum(value_rhs, free).mean() 179 | loss = mix * loss_lhs + (1 - mix) * loss_rhs 180 | return loss, value 181 | 182 | 183 | class Encoder(common.Module): 184 | 185 | def __init__( 186 | self, shapes, cnn_keys=r'.*', mlp_keys=r'.*', act='elu', norm='none', 187 | cnn_depth=48, cnn_kernels=(4, 4, 4, 4), mlp_layers=[400, 400, 400, 400]): 188 | self.shapes = shapes 189 | self.cnn_keys = [ 190 | k for k, v in shapes.items() if re.match(cnn_keys, k) and len(v) == 3] 191 | self.mlp_keys = [ 192 | k for k, 
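kl_loss implements KL balancing: the same divergence is computed twice with stop-gradients on opposite sides, so the posterior and the prior can be trained at different rates. With the default forward=False the result reduces to

    loss = (1 - balance) * KL(post || sg(prior)) + balance * KL(sg(post) || prior)

so a balance above 0.5 trains the prior toward the posterior more strongly than it pulls the posterior toward the prior, while free (and free_avg) impose a lower bound of free nats, either per element or after averaging.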
v in shapes.items() if re.match(mlp_keys, k) and len(v) == 1] 193 | print('Encoder CNN inputs:', list(self.cnn_keys)) 194 | print('Encoder MLP inputs:', list(self.mlp_keys)) 195 | self._act = get_act(act) 196 | self._norm = norm 197 | self._cnn_depth = cnn_depth 198 | self._cnn_kernels = cnn_kernels 199 | self._mlp_layers = mlp_layers 200 | 201 | @tf.function 202 | def __call__(self, data): 203 | key, shape = list(self.shapes.items())[0] 204 | batch_dims = data[key].shape[:-len(shape)] 205 | data = { 206 | k: tf.reshape(v, (-1,) + tuple(v.shape)[len(batch_dims):]) 207 | for k, v in data.items()} 208 | outputs = [] 209 | if self.cnn_keys: 210 | outputs.append(self._cnn({k: data[k] for k in self.cnn_keys})) 211 | if self.mlp_keys: 212 | outputs.append(self._mlp({k: data[k] for k in self.mlp_keys})) 213 | output = tf.concat(outputs, -1) 214 | return output.reshape(batch_dims + output.shape[1:]) 215 | 216 | def _cnn(self, data): 217 | x = tf.concat(list(data.values()), -1) 218 | x = x.astype(prec.global_policy().compute_dtype) 219 | for i, kernel in enumerate(self._cnn_kernels): 220 | depth = 2 ** i * self._cnn_depth 221 | x = self.get(f'conv{i}', tfkl.Conv2D, depth, kernel, 2)(x) 222 | x = self.get(f'convnorm{i}', NormLayer, self._norm)(x) 223 | x = self._act(x) 224 | return x.reshape(tuple(x.shape[:-3]) + (-1,)) 225 | 226 | def _mlp(self, data): 227 | x = tf.concat(list(data.values()), -1) 228 | x = x.astype(prec.global_policy().compute_dtype) 229 | for i, width in enumerate(self._mlp_layers): 230 | x = self.get(f'dense{i}', tfkl.Dense, width)(x) 231 | x = self.get(f'densenorm{i}', NormLayer, self._norm)(x) 232 | x = self._act(x) 233 | return x 234 | 235 | 236 | class Decoder(common.Module): 237 | 238 | def __init__( 239 | self, shapes, cnn_keys=r'.*', mlp_keys=r'.*', act='elu', norm='none', 240 | cnn_depth=48, cnn_kernels=(4, 4, 4, 4), mlp_layers=[400, 400, 400, 400]): 241 | self._shapes = shapes 242 | self.cnn_keys = [ 243 | k for k, v in shapes.items() if re.match(cnn_keys, k) and len(v) == 3] 244 | self.mlp_keys = [ 245 | k for k, v in shapes.items() if re.match(mlp_keys, k) and len(v) == 1] 246 | print('Decoder CNN outputs:', list(self.cnn_keys)) 247 | print('Decoder MLP outputs:', list(self.mlp_keys)) 248 | self._act = get_act(act) 249 | self._norm = norm 250 | self._cnn_depth = cnn_depth 251 | self._cnn_kernels = cnn_kernels 252 | self._mlp_layers = mlp_layers 253 | 254 | def __call__(self, features): 255 | features = tf.cast(features, prec.global_policy().compute_dtype) 256 | outputs = {} 257 | if self.cnn_keys: 258 | outputs.update(self._cnn(features)) 259 | if self.mlp_keys: 260 | outputs.update(self._mlp(features)) 261 | return outputs 262 | 263 | def _cnn(self, features): 264 | channels = {k: self._shapes[k][-1] for k in self.cnn_keys} 265 | ConvT = tfkl.Conv2DTranspose 266 | x = self.get('convin', tfkl.Dense, 32 * self._cnn_depth)(features) 267 | x = tf.reshape(x, [-1, 1, 1, 32 * self._cnn_depth]) 268 | for i, kernel in enumerate(self._cnn_kernels): 269 | depth = 2 ** (len(self._cnn_kernels) - i - 2) * self._cnn_depth 270 | act, norm = self._act, self._norm 271 | if i == len(self._cnn_kernels) - 1: 272 | depth, act, norm = sum(channels.values()), tf.identity, 'none' 273 | x = self.get(f'conv{i}', ConvT, depth, kernel, 2)(x) 274 | x = self.get(f'convnorm{i}', NormLayer, norm)(x) 275 | x = act(x) 276 | x = x.reshape(features.shape[:-1] + x.shape[1:]) 277 | means = tf.split(x, list(channels.values()), -1) 278 | dists = { 279 | key: tfd.Independent(tfd.Normal(mean, 1), 3) 280 | for 
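The Encoder above (and the Decoder that follows) route observation keys by regex and rank: 3-D shapes go through the CNN stack, 1-D shapes through the MLP stack, and the resulting embeddings are concatenated. A small sketch of the key selection with a hypothetical observation spec:

    shapes = {'image': (64, 64, 3), 'flat': (10,), 'reward': ()}   # hypothetical shapes
    enc = Encoder(shapes, cnn_keys=r'image', mlp_keys=r'flat')
    # enc.cnn_keys == ['image']   (3-D and matched by cnn_keys)
    # enc.mlp_keys == ['flat']    (1-D and matched by mlp_keys); 0-D keys are never encoded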
(key, shape), mean in zip(channels.items(), means)} 281 | return dists 282 | 283 | def _mlp(self, features): 284 | shapes = {k: self._shapes[k] for k in self.mlp_keys} 285 | x = features 286 | for i, width in enumerate(self._mlp_layers): 287 | x = self.get(f'dense{i}', tfkl.Dense, width)(x) 288 | x = self.get(f'densenorm{i}', NormLayer, self._norm)(x) 289 | x = self._act(x) 290 | dists = {} 291 | for key, shape in shapes.items(): 292 | dists[key] = self.get(f'dense_{key}', DistLayer, shape)(x) 293 | return dists 294 | 295 | 296 | class MLP(common.Module): 297 | 298 | def __init__(self, shape, layers, units, act='elu', norm='none', **out): 299 | self._shape = (shape,) if isinstance(shape, int) else shape 300 | self._layers = layers 301 | self._units = units 302 | self._norm = norm 303 | self._act = get_act(act) 304 | self._out = out 305 | 306 | def __call__(self, features): 307 | x = tf.cast(features, prec.global_policy().compute_dtype) 308 | x = x.reshape([-1, x.shape[-1]]) 309 | for index in range(self._layers): 310 | x = self.get(f'dense{index}', tfkl.Dense, self._units)(x) 311 | x = self.get(f'norm{index}', NormLayer, self._norm)(x) 312 | x = self._act(x) 313 | x = x.reshape(features.shape[:-1] + [x.shape[-1]]) 314 | return self.get('out', DistLayer, self._shape, **self._out)(x) 315 | 316 | 317 | class GRUCell(tf.keras.layers.AbstractRNNCell): 318 | 319 | def __init__(self, size, norm=False, act='tanh', update_bias=-1, **kwargs): 320 | super().__init__() 321 | self._size = size 322 | self._act = get_act(act) 323 | self._norm = norm 324 | self._update_bias = update_bias 325 | self._layer = tfkl.Dense(3 * size, use_bias=norm is not None, **kwargs) 326 | if norm: 327 | self._norm = tfkl.LayerNormalization(dtype=tf.float32) 328 | 329 | @property 330 | def state_size(self): 331 | return self._size 332 | 333 | @tf.function 334 | def call(self, inputs, state): 335 | state = state[0] # Keras wraps the state in a list. 
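The call() body that follows implements a layer-normalized GRU variant: a single dense layer produces the reset, candidate, and update pre-activations in one pass, and update_bias=-1 biases the update gate toward keeping the previous state early in training. Writing r, c, u for the three chunks of parts:

    reset  = sigmoid(r)
    cand   = act(reset * c)
    update = sigmoid(u + update_bias)
    output = update * cand + (1 - update) * state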
336 | parts = self._layer(tf.concat([inputs, state], -1)) 337 | if self._norm: 338 | dtype = parts.dtype 339 | parts = tf.cast(parts, tf.float32) 340 | parts = self._norm(parts) 341 | parts = tf.cast(parts, dtype) 342 | reset, cand, update = tf.split(parts, 3, -1) 343 | reset = tf.nn.sigmoid(reset) 344 | cand = self._act(reset * cand) 345 | update = tf.nn.sigmoid(update + self._update_bias) 346 | output = update * cand + (1 - update) * state 347 | return output, [output] 348 | 349 | 350 | class DistLayer(common.Module): 351 | 352 | def __init__( 353 | self, shape, dist='mse', min_std=0.1, init_std=0.0): 354 | self._shape = shape 355 | self._dist = dist 356 | self._min_std = min_std 357 | self._init_std = init_std 358 | 359 | def __call__(self, inputs): 360 | out = self.get('out', tfkl.Dense, np.prod(self._shape))(inputs) 361 | out = tf.reshape(out, tf.concat([tf.shape(inputs)[:-1], self._shape], 0)) 362 | out = tf.cast(out, tf.float32) 363 | if self._dist in ('normal', 'tanh_normal', 'trunc_normal'): 364 | std = self.get('std', tfkl.Dense, np.prod(self._shape))(inputs) 365 | std = tf.reshape(std, tf.concat([tf.shape(inputs)[:-1], self._shape], 0)) 366 | std = tf.cast(std, tf.float32) 367 | if self._dist == 'mse': 368 | dist = tfd.Normal(out, 1.0) 369 | return tfd.Independent(dist, len(self._shape)) 370 | if self._dist == 'normal': 371 | dist = tfd.Normal(out, std) 372 | return tfd.Independent(dist, len(self._shape)) 373 | if self._dist == 'binary': 374 | dist = tfd.Bernoulli(out) 375 | return tfd.Independent(dist, len(self._shape)) 376 | if self._dist == 'tanh_normal': 377 | mean = 5 * tf.tanh(out / 5) 378 | std = tf.nn.softplus(std + self._init_std) + self._min_std 379 | dist = tfd.Normal(mean, std) 380 | dist = tfd.TransformedDistribution(dist, common.TanhBijector()) 381 | dist = tfd.Independent(dist, len(self._shape)) 382 | return common.SampleDist(dist) 383 | if self._dist == 'trunc_normal': 384 | std = 2 * tf.nn.sigmoid((std + self._init_std) / 2) + self._min_std 385 | dist = common.TruncNormalDist(tf.tanh(out), std, -1, 1) 386 | return tfd.Independent(dist, 1) 387 | if self._dist == 'onehot': 388 | return common.OneHotDist(out) 389 | raise NotImplementedError(self._dist) 390 | 391 | 392 | class NormLayer(common.Module): 393 | 394 | def __init__(self, name): 395 | if name == 'none': 396 | self._layer = None 397 | elif name == 'layer': 398 | self._layer = tfkl.LayerNormalization() 399 | else: 400 | raise NotImplementedError(name) 401 | 402 | def __call__(self, features): 403 | if not self._layer: 404 | return features 405 | return self._layer(features) 406 | 407 | 408 | def get_act(name): 409 | if name == 'none': 410 | return tf.identity 411 | if name == 'mish': 412 | return lambda x: x * tf.math.tanh(tf.nn.softplus(x)) 413 | elif hasattr(tf.nn, name): 414 | return getattr(tf.nn, name) 415 | elif hasattr(tf, name): 416 | return getattr(tf, name) 417 | else: 418 | raise NotImplementedError(name) 419 | -------------------------------------------------------------------------------- /dreamerv2/common/other.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import contextlib 3 | import re 4 | import time 5 | 6 | import numpy as np 7 | import tensorflow as tf 8 | from tensorflow_probability import distributions as tfd 9 | 10 | from . import dists 11 | from . 
import tfutils 12 | 13 | 14 | class RandomAgent: 15 | 16 | def __init__(self, act_space, logprob=False): 17 | self.act_space = act_space['action'] 18 | self.logprob = logprob 19 | if hasattr(self.act_space, 'n'): 20 | self._dist = dists.OneHotDist(tf.zeros(self.act_space.n)) 21 | else: 22 | dist = tfd.Uniform(self.act_space.low, self.act_space.high) 23 | self._dist = tfd.Independent(dist, 1) 24 | 25 | def __call__(self, obs, state=None, mode=None): 26 | action = self._dist.sample(len(obs['is_first'])) 27 | output = {'action': action} 28 | if self.logprob: 29 | output['logprob'] = self._dist.log_prob(action) 30 | return output, None 31 | 32 | 33 | def static_scan(fn, inputs, start, reverse=False): 34 | last = start 35 | outputs = [[] for _ in tf.nest.flatten(start)] 36 | indices = range(tf.nest.flatten(inputs)[0].shape[0]) 37 | if reverse: 38 | indices = reversed(indices) 39 | for index in indices: 40 | inp = tf.nest.map_structure(lambda x: x[index], inputs) 41 | last = fn(last, inp) 42 | [o.append(l) for o, l in zip(outputs, tf.nest.flatten(last))] 43 | if reverse: 44 | outputs = [list(reversed(x)) for x in outputs] 45 | outputs = [tf.stack(x, 0) for x in outputs] 46 | return tf.nest.pack_sequence_as(start, outputs) 47 | 48 | 49 | def schedule(string, step): 50 | try: 51 | return float(string) 52 | except ValueError: 53 | step = tf.cast(step, tf.float32) 54 | match = re.match(r'linear\((.+),(.+),(.+)\)', string) 55 | if match: 56 | initial, final, duration = [float(group) for group in match.groups()] 57 | mix = tf.clip_by_value(step / duration, 0, 1) 58 | return (1 - mix) * initial + mix * final 59 | match = re.match(r'warmup\((.+),(.+)\)', string) 60 | if match: 61 | warmup, value = [float(group) for group in match.groups()] 62 | scale = tf.clip_by_value(step / warmup, 0, 1) 63 | return scale * value 64 | match = re.match(r'exp\((.+),(.+),(.+)\)', string) 65 | if match: 66 | initial, final, halflife = [float(group) for group in match.groups()] 67 | return (initial - final) * 0.5 ** (step / halflife) + final 68 | match = re.match(r'horizon\((.+),(.+),(.+)\)', string) 69 | if match: 70 | initial, final, duration = [float(group) for group in match.groups()] 71 | mix = tf.clip_by_value(step / duration, 0, 1) 72 | horizon = (1 - mix) * initial + mix * final 73 | return 1 - 1 / horizon 74 | raise NotImplementedError(string) 75 | 76 | 77 | def lambda_return( 78 | reward, value, pcont, bootstrap, lambda_, axis): 79 | # Setting lambda=1 gives a discounted Monte Carlo return. 80 | # Setting lambda=0 gives a fixed 1-step return. 
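Spelling out the recursion that static_scan evaluates below (in reverse over time):

    R_t = r_t + pcont_t * ((1 - lambda) * V(s_{t+1}) + lambda * R_{t+1})

with the recursion initialized at the bootstrap value (zeros when bootstrap is None), so lambda interpolates between the 1-step target and the discounted Monte Carlo return.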
81 | assert reward.shape.ndims == value.shape.ndims, (reward.shape, value.shape) 82 | if isinstance(pcont, (int, float)): 83 | pcont = pcont * tf.ones_like(reward) 84 | dims = list(range(reward.shape.ndims)) 85 | dims = [axis] + dims[1:axis] + [0] + dims[axis + 1:] 86 | if axis != 0: 87 | reward = tf.transpose(reward, dims) 88 | value = tf.transpose(value, dims) 89 | pcont = tf.transpose(pcont, dims) 90 | if bootstrap is None: 91 | bootstrap = tf.zeros_like(value[-1]) 92 | next_values = tf.concat([value[1:], bootstrap[None]], 0) 93 | inputs = reward + pcont * next_values * (1 - lambda_) 94 | returns = static_scan( 95 | lambda agg, cur: cur[0] + cur[1] * lambda_ * agg, 96 | (inputs, pcont), bootstrap, reverse=True) 97 | if axis != 0: 98 | returns = tf.transpose(returns, dims) 99 | return returns 100 | 101 | 102 | def action_noise(action, amount, act_space): 103 | if amount == 0: 104 | return action 105 | amount = tf.cast(amount, action.dtype) 106 | if hasattr(act_space, 'n'): 107 | probs = amount / action.shape[-1] + (1 - amount) * action 108 | return dists.OneHotDist(probs=probs).sample() 109 | else: 110 | return tf.clip_by_value(tfd.Normal(action, amount).sample(), -1, 1) 111 | 112 | 113 | class StreamNorm(tfutils.Module): 114 | 115 | def __init__(self, shape=(), momentum=0.99, scale=1.0, eps=1e-8): 116 | # Momentum of 0 normalizes only based on the current batch. 117 | # Momentum of 1 disables normalization. 118 | self._shape = tuple(shape) 119 | self._momentum = momentum 120 | self._scale = scale 121 | self._eps = eps 122 | self.mag = tf.Variable(tf.ones(shape, tf.float64), False) 123 | 124 | def __call__(self, inputs): 125 | metrics = {} 126 | self.update(inputs) 127 | metrics['mean'] = inputs.mean() 128 | metrics['std'] = inputs.std() 129 | outputs = self.transform(inputs) 130 | metrics['normed_mean'] = outputs.mean() 131 | metrics['normed_std'] = outputs.std() 132 | return outputs, metrics 133 | 134 | def reset(self): 135 | self.mag.assign(tf.ones_like(self.mag)) 136 | 137 | def update(self, inputs): 138 | batch = inputs.reshape((-1,) + self._shape) 139 | mag = tf.abs(batch).mean(0).astype(tf.float64) 140 | self.mag.assign(self._momentum * self.mag + (1 - self._momentum) * mag) 141 | 142 | def transform(self, inputs): 143 | values = inputs.reshape((-1,) + self._shape) 144 | values /= self.mag.astype(inputs.dtype)[None] + self._eps 145 | values *= self._scale 146 | return values.reshape(inputs.shape) 147 | 148 | 149 | class Timer: 150 | 151 | def __init__(self): 152 | self._indurs = collections.defaultdict(list) 153 | self._outdurs = collections.defaultdict(list) 154 | self._start_times = {} 155 | self._end_times = {} 156 | 157 | @contextlib.contextmanager 158 | def section(self, name): 159 | self.start(name) 160 | yield 161 | self.end(name) 162 | 163 | def wrap(self, function, name): 164 | def wrapped(*args, **kwargs): 165 | with self.section(name): 166 | return function(*args, **kwargs) 167 | return wrapped 168 | 169 | def start(self, name): 170 | now = time.time() 171 | self._start_times[name] = now 172 | if name in self._end_times: 173 | last = self._end_times[name] 174 | self._outdurs[name].append(now - last) 175 | 176 | def end(self, name): 177 | now = time.time() 178 | self._end_times[name] = now 179 | self._indurs[name].append(now - self._start_times[name]) 180 | 181 | def result(self): 182 | metrics = {} 183 | for key in self._indurs: 184 | indurs = self._indurs[key] 185 | outdurs = self._outdurs[key] 186 | metrics[f'timer_count_{key}'] = len(indurs) 187 | 
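action_noise above implements epsilon-style exploration: for discrete (one-hot) action spaces it samples from a mixture of the policy's action and the uniform distribution, and for continuous spaces it adds Gaussian noise clipped to [-1, 1]. For a one-hot action over 4 choices with amount=0.1:

    probs = 0.1 / 4 + 0.9 * action   # the chosen action keeps probability 0.925, the others get 0.025 each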
metrics[f'timer_inside_{key}'] = np.sum(indurs) 188 | metrics[f'timer_outside_{key}'] = np.sum(outdurs) 189 | indurs.clear() 190 | outdurs.clear() 191 | return metrics 192 | 193 | 194 | class CarryOverState: 195 | 196 | def __init__(self, fn): 197 | self._fn = fn 198 | self._state = None 199 | 200 | def __call__(self, *args): 201 | self._state, out = self._fn(*args, self._state) 202 | return out 203 | -------------------------------------------------------------------------------- /dreamerv2/common/plot.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | import functools 4 | import itertools 5 | import json 6 | import multiprocessing as mp 7 | import os 8 | import pathlib 9 | import re 10 | import subprocess 11 | import warnings 12 | 13 | os.environ['NO_AT_BRIDGE'] = '1' # Hide X org false warning. 14 | 15 | import matplotlib 16 | matplotlib.use('Agg') 17 | import matplotlib.pyplot as plt 18 | import matplotlib.ticker as ticker 19 | import numpy as np 20 | import pandas as pd 21 | 22 | np.set_string_function(lambda x: f'') 23 | 24 | Run = collections.namedtuple('Run', 'task method seed xs ys') 25 | 26 | PALETTES = dict( 27 | discrete=( 28 | '#377eb8', '#4daf4a', '#984ea3', '#e41a1c', '#ff7f00', '#a65628', 29 | '#f781bf', '#888888', '#a6cee3', '#b2df8a', '#cab2d6', '#fb9a99', 30 | ), 31 | contrast=( 32 | '#0022ff', '#33aa00', '#ff0011', '#ddaa00', '#cc44dd', '#0088aa', 33 | '#001177', '#117700', '#990022', '#885500', '#553366', '#006666', 34 | ), 35 | gradient=( 36 | '#fde725', '#a0da39', '#4ac16d', '#1fa187', '#277f8e', '#365c8d', 37 | '#46327e', '#440154', 38 | ), 39 | baselines=( 40 | '#222222', '#666666', '#aaaaaa', '#cccccc', 41 | ), 42 | ) 43 | 44 | LEGEND = dict( 45 | fontsize='medium', numpoints=1, labelspacing=0, columnspacing=1.2, 46 | handlelength=1.5, handletextpad=0.5, loc='lower center') 47 | 48 | DEFAULT_BASELINES = ['d4pg', 'rainbow_sticky', 'human_gamer', 'impala'] 49 | 50 | 51 | def find_keys(args): 52 | filenames = [] 53 | for indir in args.indir: 54 | task = next(indir.iterdir()) # First only. 55 | for method in task.iterdir(): 56 | seed = next(indir.iterdir()) # First only. 
57 | filenames += list(seed.glob('**/*.jsonl')) 58 | keys = set() 59 | for filename in filenames: 60 | keys |= set(load_jsonl(filename).columns) 61 | print(f'Keys ({len(keys)}):', ', '.join(keys), flush=True) 62 | 63 | 64 | def load_runs(args): 65 | total, toload = [], [] 66 | for indir in args.indir: 67 | filenames = list(indir.glob('**/*.jsonl')) 68 | total += filenames 69 | for filename in filenames: 70 | task, method, seed = filename.relative_to(indir).parts[:-1] 71 | if not any(p.search(task) for p in args.tasks): 72 | continue 73 | if not any(p.search(method) for p in args.methods): 74 | continue 75 | toload.append((filename, indir)) 76 | print(f'Loading {len(toload)} of {len(total)} runs...') 77 | jobs = [functools.partial(load_run, f, i, args) for f, i in toload] 78 | # Disable async data loading: 79 | # runs = [j() for j in jobs] 80 | with mp.Pool(10) as pool: 81 | promises = [pool.apply_async(j) for j in jobs] 82 | runs = [p.get() for p in promises] 83 | runs = [r for r in runs if r is not None] 84 | return runs 85 | 86 | 87 | def load_run(filename, indir, args): 88 | task, method, seed = filename.relative_to(indir).parts[:-1] 89 | prefix = f'indir{args.indir.index(indir)+1}_' 90 | if task == 'atari_jamesbond': 91 | task = 'atari_james_bond' 92 | seed = prefix + seed 93 | if args.prefix: 94 | method = prefix + method 95 | df = load_jsonl(filename) 96 | if df is None: 97 | print('Skipping empty run') 98 | return 99 | try: 100 | df = df[[args.xaxis, args.yaxis]].dropna() 101 | if args.maxval: 102 | df = df.replace([+np.inf], +args.maxval) 103 | df = df.replace([-np.inf], -args.maxval) 104 | df[args.yaxis] = df[args.yaxis].clip(-args.maxval, +args.maxval) 105 | except KeyError: 106 | return 107 | xs = df[args.xaxis].to_numpy() 108 | if args.xmult != 1: 109 | xs = xs.astype(np.float32) * args.xmult 110 | ys = df[args.yaxis].to_numpy() 111 | bins = { 112 | 'atari': 1e6, 113 | 'dmc': 1e4, 114 | 'crafter': 1e4, 115 | }.get(task.split('_')[0], 1e5) if args.bins == -1 else args.bins 116 | if bins: 117 | borders = np.arange(0, xs.max() + 1e-8, bins) 118 | xs, ys = bin_scores(xs, ys, borders) 119 | if not len(xs): 120 | print('Skipping empty run', task, method, seed) 121 | return 122 | return Run(task, method, seed, xs, ys) 123 | 124 | 125 | def load_baselines(patterns, prefix=False): 126 | runs = [] 127 | directory = pathlib.Path(__file__).parent.parent / 'scores' 128 | for filename in directory.glob('**/*_baselines.json'): 129 | for task, methods in json.loads(filename.read_text()).items(): 130 | for method, score in methods.items(): 131 | if prefix: 132 | method = f'baseline_{method}' 133 | if not any(p.search(method) for p in patterns): 134 | continue 135 | runs.append(Run(task, method, None, None, score)) 136 | return runs 137 | 138 | 139 | def stats(runs, baselines): 140 | tasks = sorted(set(r.task for r in runs)) 141 | methods = sorted(set(r.method for r in runs)) 142 | seeds = sorted(set(r.seed for r in runs)) 143 | baseline = sorted(set(r.method for r in baselines)) 144 | print('Loaded', len(runs), 'runs.') 145 | print(f'Tasks ({len(tasks)}):', ', '.join(tasks)) 146 | print(f'Methods ({len(methods)}):', ', '.join(methods)) 147 | print(f'Seeds ({len(seeds)}):', ', '.join(seeds)) 148 | print(f'Baselines ({len(baseline)}):', ', '.join(baseline)) 149 | 150 | 151 | def order_methods(runs, baselines, args): 152 | methods = [] 153 | for pattern in args.methods: 154 | for method in sorted(set(r.method for r in runs)): 155 | if pattern.search(method): 156 | if method not in methods: 157 | 
methods.append(method) 158 | if method not in args.colors: 159 | index = len(args.colors) % len(args.palette) 160 | args.colors[method] = args.palette[index] 161 | non_baseline_colors = len(args.colors) 162 | for pattern in args.baselines: 163 | for method in sorted(set(r.method for r in baselines)): 164 | if pattern.search(method): 165 | if method not in methods: 166 | methods.append(method) 167 | if method not in args.colors: 168 | index = len(args.colors) - non_baseline_colors 169 | index = index % len(PALETTES['baselines']) 170 | args.colors[method] = PALETTES['baselines'][index] 171 | return methods 172 | 173 | 174 | def figure(runs, methods, args): 175 | tasks = sorted(set(r.task for r in runs if r.xs is not None)) 176 | rows = int(np.ceil((len(tasks) + len(args.add)) / args.cols)) 177 | figsize = args.size[0] * args.cols, args.size[1] * rows 178 | fig, axes = plt.subplots(rows, args.cols, figsize=figsize, squeeze=False) 179 | for task, ax in zip(tasks, axes.flatten()): 180 | relevant = [r for r in runs if r.task == task] 181 | plot(task, ax, relevant, methods, args) 182 | for name, ax in zip(args.add, axes.flatten()[len(tasks):]): 183 | ax.set_facecolor((0.9, 0.9, 0.9)) 184 | if name == 'median': 185 | plot_combined( 186 | 'combined_median', ax, runs, methods, args, 187 | agg=lambda x: np.nanmedian(x, -1)) 188 | elif name == 'mean': 189 | plot_combined( 190 | 'combined_mean', ax, runs, methods, args, 191 | agg=lambda x: np.nanmean(x, -1)) 192 | elif name == 'gamer_median': 193 | plot_combined( 194 | 'combined_gamer_median', ax, runs, methods, args, 195 | lo='random', hi='human_gamer', 196 | agg=lambda x: np.nanmedian(x, -1)) 197 | elif name == 'gamer_mean': 198 | plot_combined( 199 | 'combined_gamer_mean', ax, runs, methods, args, 200 | lo='random', hi='human_gamer', 201 | agg=lambda x: np.nanmean(x, -1)) 202 | elif name == 'record_mean': 203 | plot_combined( 204 | 'combined_record_mean', ax, runs, methods, args, 205 | lo='random', hi='record', 206 | agg=lambda x: np.nanmean(x, -1)) 207 | elif name == 'clip_record_mean': 208 | plot_combined( 209 | 'combined_clipped_record_mean', ax, runs, methods, args, 210 | lo='random', hi='record', clip=True, 211 | agg=lambda x: np.nanmean(x, -1)) 212 | elif name == 'seeds': 213 | plot_combined( 214 | 'combined_seeds', ax, runs, methods, args, 215 | agg=lambda x: np.isfinite(x).sum(-1)) 216 | elif name == 'human_above': 217 | plot_combined( 218 | 'combined_above_human_gamer', ax, runs, methods, args, 219 | agg=lambda y: (y >= 1.0).astype(float).sum(-1)) 220 | elif name == 'human_below': 221 | plot_combined( 222 | 'combined_below_human_gamer', ax, runs, methods, args, 223 | agg=lambda y: (y <= 1.0).astype(float).sum(-1)) 224 | else: 225 | raise NotImplementedError(name) 226 | if args.xlim: 227 | for ax in axes[:-1].flatten(): 228 | ax.xaxis.get_offset_text().set_visible(False) 229 | if args.xlabel: 230 | for ax in axes[-1]: 231 | ax.set_xlabel(args.xlabel) 232 | if args.ylabel: 233 | for ax in axes[:, 0]: 234 | ax.set_ylabel(args.ylabel) 235 | for ax in axes.flatten()[len(tasks) + len(args.add):]: 236 | ax.axis('off') 237 | legend(fig, args.labels, ncol=args.legendcols, **LEGEND) 238 | return fig 239 | 240 | 241 | def plot(task, ax, runs, methods, args): 242 | assert runs 243 | try: 244 | title = task.split('_', 1)[1].replace('_', ' ').title() 245 | except IndexError: 246 | title = task.title() 247 | ax.set_title(title) 248 | xlim = [+np.inf, -np.inf] 249 | for index, method in enumerate(methods): 250 | relevant = [r for r in runs if r.method == 
method] 251 | if not relevant: 252 | continue 253 | if any(r.xs is None for r in relevant): 254 | baseline(index, method, ax, relevant, args) 255 | else: 256 | if args.agg == 'none': 257 | xs, ys = curve_lines(index, task, method, ax, relevant, args) 258 | else: 259 | xs, ys = curve_area(index, task, method, ax, relevant, args) 260 | if len(xs) == len(ys) == 0: 261 | print(f'Skipping empty: {task} {method}') 262 | continue 263 | xlim = [min(xlim[0], np.nanmin(xs)), max(xlim[1], np.nanmax(xs))] 264 | ax.ticklabel_format(axis='x', style='sci', scilimits=(0, 0)) 265 | steps = [1, 2, 2.5, 5, 10] 266 | ax.xaxis.set_major_locator(ticker.MaxNLocator(args.xticks, steps=steps)) 267 | ax.yaxis.set_major_locator(ticker.MaxNLocator(args.yticks, steps=steps)) 268 | if np.isfinite(xlim).all(): 269 | ax.set_xlim(args.xlim or xlim) 270 | if args.xlim: 271 | ticks = sorted({*ax.get_xticks(), *args.xlim}) 272 | ticks = [x for x in ticks if args.xlim[0] <= x <= args.xlim[1]] 273 | ax.set_xticks(ticks) 274 | if args.ylim: 275 | ax.set_ylim(args.ylim) 276 | if args.ylimticks: 277 | ticks = sorted({*ax.get_yticks(), *args.ylim}) 278 | ticks = [x for x in ticks if args.ylim[0] <= x <= args.ylim[1]] 279 | ax.set_yticks(ticks) 280 | 281 | 282 | def plot_combined( 283 | name, ax, runs, methods, args, agg, lo=None, hi=None, clip=False): 284 | tasks = sorted(set(run.task for run in runs if run.xs is not None)) 285 | seeds = list(set(run.seed for run in runs)) 286 | runs = [r for r in runs if r.task in tasks] # Discard unused baselines. 287 | # Bin all runs onto the same X steps. 288 | borders = sorted( 289 | [r.xs for r in runs if r.xs is not None], 290 | key=lambda x: np.nanmax(x))[-1] 291 | for index, run in enumerate(runs): 292 | if run.xs is None: 293 | continue 294 | xs, ys = bin_scores(run.xs, run.ys, borders, fill='last') 295 | runs[index] = run._replace(xs=xs, ys=ys) 296 | # Per-task normalization by low and high baseline. 297 | if lo or hi: 298 | mins = collections.defaultdict(list) 299 | maxs = collections.defaultdict(list) 300 | [mins[r.task].append(r.ys) for r in load_baselines([re.compile(lo)])] 301 | [maxs[r.task].append(r.ys) for r in load_baselines([re.compile(hi)])] 302 | mins = {task: min(ys) for task, ys in mins.items() if task in tasks} 303 | maxs = {task: max(ys) for task, ys in maxs.items() if task in tasks} 304 | missing_baselines = [] 305 | for task in tasks: 306 | if task not in mins or task not in maxs: 307 | missing_baselines.append(task) 308 | if set(missing_baselines) == set(tasks): 309 | print(f'No baselines found to normalize any tasks in {name} plot.') 310 | else: 311 | for task in missing_baselines: 312 | print(f'No baselines found to normalize {task} in {name} plot.') 313 | for index, run in enumerate(runs): 314 | if run.task not in mins or run.task not in maxs: 315 | continue 316 | ys = (run.ys - mins[run.task]) / (maxs[run.task] - mins[run.task]) 317 | if clip: 318 | ys = np.minimum(ys, 1.0) 319 | runs[index] = run._replace(ys=ys) 320 | # Aggregate across tasks but not methods or seeds. 321 | combined = [] 322 | for method, seed in itertools.product(methods, seeds): 323 | relevant = [r for r in runs if r.method == method and r.seed == seed] 324 | if not relevant: 325 | continue 326 | if relevant[0].xs is None: 327 | xs, ys = None, np.array([r.ys for r in relevant]) 328 | else: 329 | xs, ys = stack_scores(*zip(*[(r.xs, r.ys) for r in relevant])) 330 | with warnings.catch_warnings(): # Ignore empty slice warnings. 
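Before aggregating across tasks, plot_combined normalizes each curve by a low and a high reference score (for example random and human_gamer), so a normalized value of 1.0 means the high baseline was reached:

    normalized = (score - low) / (high - low)   # optionally clipped at 1.0 for clip_record_mean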
331 | warnings.simplefilter('ignore', category=RuntimeWarning) 332 | combined.append(Run('combined', method, seed, xs, agg(ys))) 333 | plot(name, ax, combined, methods, args) 334 | 335 | 336 | def curve_lines(index, task, method, ax, runs, args): 337 | zorder = 10000 - 10 * index - 1 338 | for run in runs: 339 | color = args.colors[method] 340 | ax.plot(run.xs, run.ys, label=method, color=color, zorder=zorder) 341 | xs, ys = stack_scores(*zip(*[(r.xs, r.ys) for r in runs])) 342 | return xs, ys 343 | 344 | 345 | def curve_area(index, task, method, ax, runs, args): 346 | xs, ys = stack_scores(*zip(*[(r.xs, r.ys) for r in runs])) 347 | with warnings.catch_warnings(): # NaN buckets remain NaN. 348 | warnings.simplefilter('ignore', category=RuntimeWarning) 349 | if args.agg == 'std1': 350 | mean, std = np.nanmean(ys, -1), np.nanstd(ys, -1) 351 | lo, mi, hi = mean - std, mean, mean + std 352 | elif args.agg == 'per0': 353 | lo, mi, hi = [np.nanpercentile(ys, k, -1) for k in (0, 50, 100)] 354 | elif args.agg == 'per5': 355 | lo, mi, hi = [np.nanpercentile(ys, k, -1) for k in (5, 50, 95)] 356 | elif args.agg == 'per25': 357 | lo, mi, hi = [np.nanpercentile(ys, k, -1) for k in (25, 50, 75)] 358 | else: 359 | raise NotImplementedError(args.agg) 360 | color = args.colors[method] 361 | kw = dict(color=color, zorder=1000 - 10 * index, alpha=0.1, linewidths=0) 362 | mask = ~np.isnan(mi) 363 | xs, lo, mi, hi = xs[mask], lo[mask], mi[mask], hi[mask] 364 | ax.fill_between(xs, lo, hi, **kw) 365 | ax.plot(xs, mi, label=method, color=color, zorder=10000 - 10 * index - 1) 366 | return xs, mi 367 | 368 | 369 | def baseline(index, method, ax, runs, args): 370 | assert all(run.xs is None for run in runs) 371 | ys = np.array([run.ys for run in runs]) 372 | mean, std = ys.mean(), ys.std() 373 | color = args.colors[method] 374 | kw = dict(color=color, zorder=500 - 20 * index - 1, alpha=0.1, linewidths=0) 375 | ax.fill_between([-np.inf, np.inf], [mean - std] * 2, [mean + std] * 2, **kw) 376 | kw = dict(ls='--', color=color, zorder=5000 - 10 * index - 1) 377 | ax.axhline(mean, label=method, **kw) 378 | 379 | 380 | def legend(fig, mapping=None, **kwargs): 381 | entries = {} 382 | for ax in fig.axes: 383 | for handle, label in zip(*ax.get_legend_handles_labels()): 384 | if mapping and label in mapping: 385 | label = mapping[label] 386 | entries[label] = handle 387 | leg = fig.legend(entries.values(), entries.keys(), **kwargs) 388 | leg.get_frame().set_edgecolor('white') 389 | extent = leg.get_window_extent(fig.canvas.get_renderer()) 390 | extent = extent.transformed(fig.transFigure.inverted()) 391 | yloc, xloc = kwargs['loc'].split() 392 | y0 = dict(lower=extent.y1, center=0, upper=0)[yloc] 393 | y1 = dict(lower=1, center=1, upper=extent.y0)[yloc] 394 | x0 = dict(left=extent.x1, center=0, right=0)[xloc] 395 | x1 = dict(left=1, center=1, right=extent.x0)[xloc] 396 | fig.tight_layout(rect=[x0, y0, x1, y1], h_pad=0.5, w_pad=0.5) 397 | 398 | 399 | def save(fig, args): 400 | args.outdir.mkdir(parents=True, exist_ok=True) 401 | filename = args.outdir / 'curves.png' 402 | fig.savefig(filename, dpi=args.dpi) 403 | print('Saved to', filename) 404 | filename = args.outdir / 'curves.pdf' 405 | fig.savefig(filename) 406 | try: 407 | subprocess.call(['pdfcrop', str(filename), str(filename)]) 408 | except FileNotFoundError: 409 | print('Install texlive-extra-utils to crop PDF outputs.') 410 | 411 | 412 | def bin_scores(xs, ys, borders, reducer=np.nanmean, fill='nan'): 413 | order = np.argsort(xs) 414 | xs, ys = xs[order], ys[order] 
415 | binned = [] 416 | with warnings.catch_warnings(): # Empty buckets become NaN. 417 | warnings.simplefilter('ignore', category=RuntimeWarning) 418 | for start, stop in zip(borders[:-1], borders[1:]): 419 | left = (xs <= start).sum() 420 | right = (xs <= stop).sum() 421 | if left < right: 422 | value = reducer(ys[left:right]) 423 | elif binned: 424 | value = {'nan': np.nan, 'last': binned[-1]}[fill] 425 | else: 426 | value = np.nan 427 | binned.append(value) 428 | return borders[1:], np.array(binned) 429 | 430 | 431 | def stack_scores(multiple_xs, multiple_ys, fill='last'): 432 | longest_xs = sorted(multiple_xs, key=lambda x: len(x))[-1] 433 | multiple_padded_ys = [] 434 | for xs, ys in zip(multiple_xs, multiple_ys): 435 | assert (longest_xs[:len(xs)] == xs).all(), (list(xs), list(longest_xs)) 436 | value = {'nan': np.nan, 'last': ys[-1]}[fill] 437 | padding = [value] * (len(longest_xs) - len(xs)) 438 | padded_ys = np.concatenate([ys, padding]) 439 | multiple_padded_ys.append(padded_ys) 440 | stacked_ys = np.stack(multiple_padded_ys, -1) 441 | return longest_xs, stacked_ys 442 | 443 | 444 | def load_jsonl(filename): 445 | try: 446 | with filename.open() as f: 447 | lines = list(f.readlines()) 448 | records = [] 449 | for index, line in enumerate(lines): 450 | try: 451 | records.append(json.loads(line)) 452 | except Exception: 453 | if index == len(lines) - 1: 454 | continue # Silently skip last line if it is incomplete. 455 | raise ValueError( 456 | f'Skipping invalid JSON line ({index+1}/{len(lines)+1}) in' 457 | f'{filename}: {line}') 458 | return pd.DataFrame(records) 459 | except ValueError as e: 460 | print('Invalid', filename, e) 461 | return None 462 | 463 | 464 | def save_runs(runs, filename): 465 | filename.parent.mkdir(parents=True, exist_ok=True) 466 | records = [] 467 | for run in runs: 468 | if run.xs is None: 469 | continue 470 | records.append(dict( 471 | task=run.task, method=run.method, seed=run.seed, 472 | xs=run.xs.tolist(), ys=run.ys.tolist())) 473 | runs = json.dumps(records) 474 | filename.write_text(runs) 475 | print('Saved', filename) 476 | 477 | 478 | def main(args): 479 | find_keys(args) 480 | runs = load_runs(args) 481 | save_runs(runs, args.outdir / 'runs.json') 482 | baselines = load_baselines(args.baselines, args.prefix) 483 | stats(runs, baselines) 484 | methods = order_methods(runs, baselines, args) 485 | if not runs: 486 | print('Noting to plot.') 487 | return 488 | # Adjust options based on loaded runs. 
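A small worked example of bin_scores (defined above): scores are averaged into buckets delimited by borders, the right edge of each bucket is returned as its x value, and points lying exactly on the first border are dropped because buckets are half-open on the left.

    xs = np.array([0, 1, 2, 3])
    ys = np.array([1.0, 2.0, 3.0, 4.0])
    bin_scores(xs, ys, borders=np.array([0, 2, 4]))
    # -> (array([2, 4]), array([2.5, 4.0]))
    # bucket (0, 2] averages the ys for xs in {1, 2}; bucket (2, 4] averages the ys for xs in {3}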
489 | tasks = set(r.task for r in runs) 490 | if 'auto' in args.add: 491 | index = args.add.index('auto') 492 | del args.add[index] 493 | atari = any(run.task.startswith('atari_') for run in runs) 494 | if len(tasks) < 2: 495 | pass 496 | elif atari: 497 | args.add[index:index] = [ 498 | 'gamer_median', 'gamer_mean', 'record_mean', 'clip_record_mean', 499 | ] 500 | else: 501 | args.add[index:index] = ['mean', 'median'] 502 | args.cols = min(args.cols, len(tasks) + len(args.add)) 503 | args.legendcols = min(args.legendcols, args.cols) 504 | print('Plotting...') 505 | fig = figure(runs + baselines, methods, args) 506 | save(fig, args) 507 | 508 | 509 | def parse_args(): 510 | boolean = lambda x: bool(['False', 'True'].index(x)) 511 | parser = argparse.ArgumentParser() 512 | parser.add_argument('--indir', nargs='+', type=pathlib.Path, required=True) 513 | parser.add_argument('--indir-prefix', type=pathlib.Path) 514 | parser.add_argument('--outdir', type=pathlib.Path, required=True) 515 | parser.add_argument('--subdir', type=boolean, default=True) 516 | parser.add_argument('--xaxis', type=str, default='step') 517 | parser.add_argument('--yaxis', type=str, default='eval_return') 518 | parser.add_argument('--tasks', nargs='+', default=[r'.*']) 519 | parser.add_argument('--methods', nargs='+', default=[r'.*']) 520 | parser.add_argument('--baselines', nargs='+', default=DEFAULT_BASELINES) 521 | parser.add_argument('--prefix', type=boolean, default=False) 522 | parser.add_argument('--bins', type=float, default=-1) 523 | parser.add_argument('--agg', type=str, default='std1') 524 | parser.add_argument('--size', nargs=2, type=float, default=[2.5, 2.3]) 525 | parser.add_argument('--dpi', type=int, default=80) 526 | parser.add_argument('--cols', type=int, default=6) 527 | parser.add_argument('--xlim', nargs=2, type=float, default=None) 528 | parser.add_argument('--ylim', nargs=2, type=float, default=None) 529 | parser.add_argument('--ylimticks', type=boolean, default=True) 530 | parser.add_argument('--xlabel', type=str, default=None) 531 | parser.add_argument('--ylabel', type=str, default=None) 532 | parser.add_argument('--xticks', type=int, default=6) 533 | parser.add_argument('--yticks', type=int, default=5) 534 | parser.add_argument('--xmult', type=float, default=1) 535 | parser.add_argument('--labels', nargs='+', default=None) 536 | parser.add_argument('--palette', nargs='+', default=['contrast']) 537 | parser.add_argument('--legendcols', type=int, default=4) 538 | parser.add_argument('--colors', nargs='+', default={}) 539 | parser.add_argument('--maxval', type=float, default=0) 540 | parser.add_argument('--add', nargs='+', type=str, default=['auto', 'seeds']) 541 | args = parser.parse_args() 542 | if args.subdir: 543 | args.outdir /= args.indir[0].stem 544 | if args.indir_prefix: 545 | args.indir = [args.indir_prefix / indir for indir in args.indir] 546 | args.indir = [d.expanduser() for d in args.indir] 547 | args.outdir = args.outdir.expanduser() 548 | if args.labels: 549 | assert len(args.labels) % 2 == 0 550 | args.labels = {k: v for k, v in zip(args.labels[:-1], args.labels[1:])} 551 | if args.colors: 552 | assert len(args.colors) % 2 == 0 553 | args.colors = {k: v for k, v in zip(args.colors[:-1], args.colors[1:])} 554 | args.tasks = [re.compile(p) for p in args.tasks] 555 | args.methods = [re.compile(p) for p in args.methods] 556 | args.baselines = [re.compile(p) for p in args.baselines] 557 | if 'return' not in args.yaxis: 558 | args.baselines = [] 559 | if args.prefix is None: 560 | 
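The --labels and --colors flags take a flat list of alternating keys and values. Note that the dict comprehension zips the list against itself shifted by one, so every adjacent pair is inserted rather than only the (key, value) pairs; in practice lookups use the original method names, so the extra entries are harmless. With hypothetical method names:

    labels = ['dreamerv2', 'DreamerV2', 'rainbow', 'Rainbow']
    {k: v for k, v in zip(labels[:-1], labels[1:])}
    # -> {'dreamerv2': 'DreamerV2', 'DreamerV2': 'rainbow', 'rainbow': 'Rainbow'}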
args.prefix = len(args.indir) > 1 561 | if len(args.palette) == 1 and args.palette[0] in PALETTES: 562 | args.palette = 10 * PALETTES[args.palette[0]] 563 | if len(args.add) == 1 and args.add[0] == 'none': 564 | args.add = [] 565 | return args 566 | 567 | 568 | if __name__ == '__main__': 569 | main(parse_args()) 570 | -------------------------------------------------------------------------------- /dreamerv2/common/replay.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import datetime 3 | import io 4 | import pathlib 5 | import uuid 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | 10 | 11 | class Replay: 12 | 13 | def __init__( 14 | self, directory, capacity=0, ongoing=False, minlen=1, maxlen=0, 15 | prioritize_ends=False): 16 | self._directory = pathlib.Path(directory).expanduser() 17 | self._directory.mkdir(parents=True, exist_ok=True) 18 | self._capacity = capacity 19 | self._ongoing = ongoing 20 | self._minlen = minlen 21 | self._maxlen = maxlen 22 | self._prioritize_ends = prioritize_ends 23 | self._random = np.random.RandomState() 24 | # filename -> key -> value_sequence 25 | self._complete_eps = load_episodes(self._directory, capacity, minlen) 26 | # worker -> key -> value_sequence 27 | self._ongoing_eps = collections.defaultdict( 28 | lambda: collections.defaultdict(list)) 29 | self._total_episodes, self._total_steps = count_episodes(directory) 30 | self._loaded_episodes = len(self._complete_eps) 31 | self._loaded_steps = sum(eplen(x) for x in self._complete_eps.values()) 32 | 33 | @property 34 | def stats(self): 35 | return { 36 | 'total_steps': self._total_steps, 37 | 'total_episodes': self._total_episodes, 38 | 'loaded_steps': self._loaded_steps, 39 | 'loaded_episodes': self._loaded_episodes, 40 | } 41 | 42 | def add_step(self, transition, worker=0): 43 | episode = self._ongoing_eps[worker] 44 | for key, value in transition.items(): 45 | episode[key].append(value) 46 | if transition['is_last']: 47 | self.add_episode(episode) 48 | episode.clear() 49 | 50 | def add_episode(self, episode): 51 | length = eplen(episode) 52 | if length < self._minlen: 53 | print(f'Skipping short episode of length {length}.') 54 | return 55 | self._total_steps += length 56 | self._loaded_steps += length 57 | self._total_episodes += 1 58 | self._loaded_episodes += 1 59 | episode = {key: convert(value) for key, value in episode.items()} 60 | filename = save_episode(self._directory, episode) 61 | self._complete_eps[str(filename)] = episode 62 | self._enforce_limit() 63 | 64 | def dataset(self, batch, length): 65 | example = next(iter(self._generate_chunks(length))) 66 | dataset = tf.data.Dataset.from_generator( 67 | lambda: self._generate_chunks(length), 68 | {k: v.dtype for k, v in example.items()}, 69 | {k: v.shape for k, v in example.items()}) 70 | dataset = dataset.batch(batch, drop_remainder=True) 71 | dataset = dataset.prefetch(5) 72 | return dataset 73 | 74 | def _generate_chunks(self, length): 75 | sequence = self._sample_sequence() 76 | while True: 77 | chunk = collections.defaultdict(list) 78 | added = 0 79 | while added < length: 80 | needed = length - added 81 | adding = {k: v[:needed] for k, v in sequence.items()} 82 | sequence = {k: v[needed:] for k, v in sequence.items()} 83 | for key, value in adding.items(): 84 | chunk[key].append(value) 85 | added += len(adding['action']) 86 | if len(sequence['action']) < 1: 87 | sequence = self._sample_sequence() 88 | chunk = {k: np.concatenate(v) for k, v in chunk.items()} 89 | 
yield chunk 90 | 91 | def _sample_sequence(self): 92 | episodes = list(self._complete_eps.values()) 93 | if self._ongoing: 94 | episodes += [ 95 | x for x in self._ongoing_eps.values() 96 | if eplen(x) >= self._minlen] 97 | episode = self._random.choice(episodes) 98 | total = len(episode['action']) 99 | length = total 100 | if self._maxlen: 101 | length = min(length, self._maxlen) 102 | # Randomize length to avoid all chunks ending at the same time in case the 103 | # episodes are all of the same length. 104 | length -= self._random.randint(self._minlen) 105 | length = max(self._minlen, length) 106 | upper = total - length + 1 107 | if self._prioritize_ends: 108 | upper += self._minlen 109 | index = min(self._random.randint(upper), total - length) 110 | sequence = { 111 | k: convert(v[index: index + length]) 112 | for k, v in episode.items() if not k.startswith('log_')} 113 | sequence['is_first'] = np.zeros(len(sequence['action']), bool) 114 | sequence['is_first'][0] = True 115 | if self._maxlen: 116 | assert self._minlen <= len(sequence['action']) <= self._maxlen 117 | return sequence 118 | 119 | def _enforce_limit(self): 120 | if not self._capacity: 121 | return 122 | while self._loaded_episodes > 1 and self._loaded_steps > self._capacity: 123 | # Relying on Python preserving the insertion order of dicts. 124 | oldest, episode = next(iter(self._complete_eps.items())) 125 | self._loaded_steps -= eplen(episode) 126 | self._loaded_episodes -= 1 127 | del self._complete_eps[oldest] 128 | 129 | 130 | def count_episodes(directory): 131 | filenames = list(directory.glob('*.npz')) 132 | num_episodes = len(filenames) 133 | num_steps = sum(int(str(n).split('-')[-1][:-4]) - 1 for n in filenames) 134 | return num_episodes, num_steps 135 | 136 | 137 | def save_episode(directory, episode): 138 | timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S') 139 | identifier = str(uuid.uuid4().hex) 140 | length = eplen(episode) 141 | filename = directory / f'{timestamp}-{identifier}-{length}.npz' 142 | with io.BytesIO() as f1: 143 | np.savez_compressed(f1, **episode) 144 | f1.seek(0) 145 | with filename.open('wb') as f2: 146 | f2.write(f1.read()) 147 | return filename 148 | 149 | 150 | def load_episodes(directory, capacity=None, minlen=1): 151 | # The returned dictionary from filenames to episodes is guaranteed to be in 152 | # temporally sorted order.
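# Sorting works because save_episode above prefixes every filename with a
# timestamp, so lexicographic order equals the order in which episodes were
# written. When a capacity is given, the loop below walks backwards from the
# newest file and loads only the most recent episodes that fit into roughly
# `capacity` environment steps.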
153 | filenames = sorted(directory.glob('*.npz')) 154 | if capacity: 155 | num_steps = 0 156 | num_episodes = 0 157 | for filename in reversed(filenames): 158 | length = int(str(filename).split('-')[-1][:-4]) 159 | num_steps += length 160 | num_episodes += 1 161 | if num_steps >= capacity: 162 | break 163 | filenames = filenames[-num_episodes:] 164 | episodes = {} 165 | for filename in filenames: 166 | try: 167 | with filename.open('rb') as f: 168 | episode = np.load(f) 169 | episode = {k: episode[k] for k in episode.keys()} 170 | except Exception as e: 171 | print(f'Could not load episode {str(filename)}: {e}') 172 | continue 173 | episodes[str(filename)] = episode 174 | return episodes 175 | 176 | 177 | def convert(value): 178 | value = np.array(value) 179 | if np.issubdtype(value.dtype, np.floating): 180 | return value.astype(np.float32) 181 | elif np.issubdtype(value.dtype, np.signedinteger): 182 | return value.astype(np.int32) 183 | elif np.issubdtype(value.dtype, np.uint8): 184 | return value.astype(np.uint8) 185 | return value 186 | 187 | 188 | def eplen(episode): 189 | return len(episode['action']) - 1 190 | -------------------------------------------------------------------------------- /dreamerv2/common/tfutils.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import pickle 3 | import re 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | from tensorflow.keras import mixed_precision as prec 8 | 9 | try: 10 | from tensorflow.python.distribute import values 11 | except Exception: 12 | from google3.third_party.tensorflow.python.distribute import values 13 | 14 | tf.tensor = tf.convert_to_tensor 15 | for base in (tf.Tensor, tf.Variable, values.PerReplica): 16 | base.mean = tf.math.reduce_mean 17 | base.std = tf.math.reduce_std 18 | base.var = tf.math.reduce_variance 19 | base.sum = tf.math.reduce_sum 20 | base.any = tf.math.reduce_any 21 | base.all = tf.math.reduce_all 22 | base.min = tf.math.reduce_min 23 | base.max = tf.math.reduce_max 24 | base.abs = tf.math.abs 25 | base.logsumexp = tf.math.reduce_logsumexp 26 | base.transpose = tf.transpose 27 | base.reshape = tf.reshape 28 | base.astype = tf.cast 29 | 30 | 31 | # values.PerReplica.dtype = property(lambda self: self.values[0].dtype) 32 | 33 | # tf.TensorHandle.__repr__ = lambda x: '' 34 | # tf.TensorHandle.__str__ = lambda x: '' 35 | # np.set_printoptions(threshold=5, edgeitems=0) 36 | 37 | 38 | class Module(tf.Module): 39 | 40 | def save(self, filename): 41 | values = tf.nest.map_structure(lambda x: x.numpy(), self.variables) 42 | amount = len(tf.nest.flatten(values)) 43 | count = int(sum(np.prod(x.shape) for x in tf.nest.flatten(values))) 44 | print(f'Save checkpoint with {amount} tensors and {count} parameters.') 45 | with pathlib.Path(filename).open('wb') as f: 46 | pickle.dump(values, f) 47 | 48 | def load(self, filename): 49 | with pathlib.Path(filename).open('rb') as f: 50 | values = pickle.load(f) 51 | amount = len(tf.nest.flatten(values)) 52 | count = int(sum(np.prod(x.shape) for x in tf.nest.flatten(values))) 53 | print(f'Load checkpoint with {amount} tensors and {count} parameters.') 54 | tf.nest.map_structure(lambda x, y: x.assign(y), self.variables, values) 55 | 56 | def get(self, name, ctor, *args, **kwargs): 57 | # Create or get layer by name to avoid mentioning it in the constructor. 
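# Illustrative usage inside a subclass (the layer name 'dense1' and width 400
# are hypothetical, not taken from this repository):
#   x = self.get('dense1', tf.keras.layers.Dense, 400)(x)
# The first call constructs the layer via ctor(*args, **kwargs); later calls
# with the same name return the cached instance, so variables are created
# exactly once.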
58 | if not hasattr(self, '_modules'): 59 | self._modules = {} 60 | if name not in self._modules: 61 | self._modules[name] = ctor(*args, **kwargs) 62 | return self._modules[name] 63 | 64 | 65 | class Optimizer(tf.Module): 66 | 67 | def __init__( 68 | self, name, lr, eps=1e-4, clip=None, wd=None, 69 | opt='adam', wd_pattern=r'.*'): 70 | assert 0 <= wd < 1 71 | assert not clip or 1 <= clip 72 | self._name = name 73 | self._clip = clip 74 | self._wd = wd 75 | self._wd_pattern = wd_pattern 76 | self._opt = { 77 | 'adam': lambda: tf.optimizers.Adam(lr, epsilon=eps), 78 | 'nadam': lambda: tf.optimizers.Nadam(lr, epsilon=eps), 79 | 'adamax': lambda: tf.optimizers.Adamax(lr, epsilon=eps), 80 | 'sgd': lambda: tf.optimizers.SGD(lr), 81 | 'momentum': lambda: tf.optimizers.SGD(lr, 0.9), 82 | }[opt]() 83 | self._mixed = (prec.global_policy().compute_dtype == tf.float16) 84 | if self._mixed: 85 | self._opt = prec.LossScaleOptimizer(self._opt, dynamic=True) 86 | self._once = True 87 | 88 | @property 89 | def variables(self): 90 | return self._opt.variables() 91 | 92 | def __call__(self, tape, loss, modules): 93 | assert loss.dtype is tf.float32, (self._name, loss.dtype) 94 | assert len(loss.shape) == 0, (self._name, loss.shape) 95 | metrics = {} 96 | 97 | # Find variables. 98 | modules = modules if hasattr(modules, '__len__') else (modules,) 99 | varibs = tf.nest.flatten([module.variables for module in modules]) 100 | count = sum(np.prod(x.shape) for x in varibs) 101 | if self._once: 102 | print(f'Found {count} {self._name} parameters.') 103 | self._once = False 104 | 105 | # Check loss. 106 | tf.debugging.check_numerics(loss, self._name + '_loss') 107 | metrics[f'{self._name}_loss'] = loss 108 | 109 | # Compute scaled gradient. 110 | if self._mixed: 111 | with tape: 112 | loss = self._opt.get_scaled_loss(loss) 113 | grads = tape.gradient(loss, varibs) 114 | if self._mixed: 115 | grads = self._opt.get_unscaled_gradients(grads) 116 | if self._mixed: 117 | metrics[f'{self._name}_loss_scale'] = self._opt.loss_scale 118 | 119 | # Distributed sync. 120 | context = tf.distribute.get_replica_context() 121 | if context: 122 | grads = context.all_reduce('mean', grads) 123 | 124 | # Gradient clipping. 125 | norm = tf.linalg.global_norm(grads) 126 | if not self._mixed: 127 | tf.debugging.check_numerics(norm, self._name + '_norm') 128 | if self._clip: 129 | grads, _ = tf.clip_by_global_norm(grads, self._clip, norm) 130 | metrics[f'{self._name}_grad_norm'] = norm 131 | 132 | # Weight decay. 133 | if self._wd: 134 | self._apply_weight_decay(varibs) 135 | 136 | # Apply gradients. 
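# The gradients were already averaged across replicas via all_reduce above,
# so the optimizer is told not to aggregate them again
# (experimental_aggregate_gradients=False). Under mixed precision they have
# also already been unscaled by get_unscaled_gradients.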
137 | self._opt.apply_gradients( 138 | zip(grads, varibs), 139 | experimental_aggregate_gradients=False) 140 | 141 | return metrics 142 | 143 | def _apply_weight_decay(self, varibs): 144 | nontrivial = (self._wd_pattern != r'.*') 145 | if nontrivial: 146 | print('Applied weight decay to variables:') 147 | for var in varibs: 148 | if re.search(self._wd_pattern, self._name + '/' + var.name): 149 | if nontrivial: 150 | print('- ' + self._name + '/' + var.name) 151 | var.assign((1 - self._wd) * var) 152 | -------------------------------------------------------------------------------- /dreamerv2/common/when.py: -------------------------------------------------------------------------------- 1 | class Every: 2 | 3 | def __init__(self, every): 4 | self._every = every 5 | self._last = None 6 | 7 | def __call__(self, step): 8 | step = int(step) 9 | if not self._every: 10 | return False 11 | if self._last is None: 12 | self._last = step 13 | return True 14 | if step >= self._last + self._every: 15 | self._last += self._every 16 | return True 17 | return False 18 | 19 | 20 | class Once: 21 | 22 | def __init__(self): 23 | self._once = True 24 | 25 | def __call__(self): 26 | if self._once: 27 | self._once = False 28 | return True 29 | return False 30 | 31 | 32 | class Until: 33 | 34 | def __init__(self, until): 35 | self._until = until 36 | 37 | def __call__(self, step): 38 | step = int(step) 39 | if not self._until: 40 | return True 41 | return step < self._until 42 | -------------------------------------------------------------------------------- /dreamerv2/configs.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | 3 | # Train Script 4 | logdir: /dev/null 5 | seed: 0 6 | task: dmc_walker_walk 7 | envs: 1 8 | envs_parallel: none 9 | render_size: [64, 64] 10 | dmc_camera: -1 11 | atari_grayscale: True 12 | time_limit: 0 13 | action_repeat: 1 14 | steps: 1e8 15 | log_every: 1e4 16 | eval_every: 1e5 17 | eval_eps: 1 18 | prefill: 10000 19 | pretrain: 1 20 | train_every: 5 21 | train_steps: 1 22 | expl_until: 0 23 | replay: {capacity: 2e6, ongoing: False, minlen: 50, maxlen: 50, prioritize_ends: True} 24 | dataset: {batch: 16, length: 50} 25 | log_keys_video: ['image'] 26 | log_keys_sum: '^$' 27 | log_keys_mean: '^$' 28 | log_keys_max: '^$' 29 | precision: 16 30 | jit: True 31 | 32 | # Agent 33 | clip_rewards: tanh 34 | expl_behavior: greedy 35 | expl_noise: 0.0 36 | eval_noise: 0.0 37 | eval_state_mean: False 38 | 39 | # World Model 40 | grad_heads: [decoder, reward, discount] 41 | pred_discount: True 42 | rssm: {ensemble: 1, hidden: 1024, deter: 1024, stoch: 32, discrete: 32, act: elu, norm: none, std_act: sigmoid2, min_std: 0.1} 43 | encoder: {mlp_keys: '.*', cnn_keys: '.*', act: elu, norm: none, cnn_depth: 48, cnn_kernels: [4, 4, 4, 4], mlp_layers: [400, 400, 400, 400]} 44 | decoder: {mlp_keys: '.*', cnn_keys: '.*', act: elu, norm: none, cnn_depth: 48, cnn_kernels: [5, 5, 6, 6], mlp_layers: [400, 400, 400, 400]} 45 | reward_head: {layers: 4, units: 400, act: elu, norm: none, dist: mse} 46 | discount_head: {layers: 4, units: 400, act: elu, norm: none, dist: binary} 47 | loss_scales: {kl: 1.0, reward: 1.0, discount: 1.0, proprio: 1.0} 48 | kl: {free: 0.0, forward: False, balance: 0.8, free_avg: True} 49 | model_opt: {opt: adam, lr: 1e-4, eps: 1e-5, clip: 100, wd: 1e-6} 50 | 51 | # Actor Critic 52 | actor: {layers: 4, units: 400, act: elu, norm: none, dist: auto, min_std: 0.1} 53 | critic: {layers: 4, units: 400, act: elu, norm: none, dist: mse} 54 | 
actor_opt: {opt: adam, lr: 8e-5, eps: 1e-5, clip: 100, wd: 1e-6} 55 | critic_opt: {opt: adam, lr: 2e-4, eps: 1e-5, clip: 100, wd: 1e-6} 56 | discount: 0.99 57 | discount_lambda: 0.95 58 | imag_horizon: 15 59 | actor_grad: auto 60 | actor_grad_mix: 0.1 61 | actor_ent: 2e-3 62 | slow_target: True 63 | slow_target_update: 100 64 | slow_target_fraction: 1 65 | slow_baseline: True 66 | reward_norm: {momentum: 1.0, scale: 1.0, eps: 1e-8} 67 | 68 | # Exploration 69 | expl_intr_scale: 1.0 70 | expl_extr_scale: 0.0 71 | expl_opt: {opt: adam, lr: 3e-4, eps: 1e-5, clip: 100, wd: 1e-6} 72 | expl_head: {layers: 4, units: 400, act: elu, norm: none, dist: mse} 73 | expl_reward_norm: {momentum: 1.0, scale: 1.0, eps: 1e-8} 74 | disag_target: stoch 75 | disag_log: False 76 | disag_models: 10 77 | disag_offset: 1 78 | disag_action_cond: True 79 | expl_model_loss: kl 80 | 81 | atari: 82 | 83 | task: atari_pong 84 | encoder: {mlp_keys: '$^', cnn_keys: 'image'} 85 | decoder: {mlp_keys: '$^', cnn_keys: 'image'} 86 | time_limit: 27000 87 | action_repeat: 4 88 | steps: 5e7 89 | eval_every: 2.5e5 90 | log_every: 1e4 91 | prefill: 50000 92 | train_every: 16 93 | clip_rewards: tanh 94 | rssm: {hidden: 600, deter: 600} 95 | model_opt.lr: 2e-4 96 | actor_opt.lr: 4e-5 97 | critic_opt.lr: 1e-4 98 | actor_ent: 1e-3 99 | discount: 0.999 100 | loss_scales.kl: 0.1 101 | loss_scales.discount: 5.0 102 | 103 | crafter: 104 | 105 | task: crafter_reward 106 | encoder: {mlp_keys: '$^', cnn_keys: 'image'} 107 | decoder: {mlp_keys: '$^', cnn_keys: 'image'} 108 | log_keys_max: '^log_achievement_.*' 109 | log_keys_sum: '^log_reward$' 110 | rssm: {hidden: 1024, deter: 1024} 111 | discount: 0.999 112 | model_opt.lr: 1e-4 113 | actor_opt.lr: 1e-4 114 | critic_opt.lr: 1e-4 115 | actor_ent: 3e-3 116 | .*\.norm: layer 117 | 118 | dmc_vision: 119 | 120 | task: dmc_walker_walk 121 | encoder: {mlp_keys: '$^', cnn_keys: 'image'} 122 | decoder: {mlp_keys: '$^', cnn_keys: 'image'} 123 | action_repeat: 2 124 | eval_every: 1e4 125 | prefill: 1000 126 | pretrain: 100 127 | clip_rewards: identity 128 | pred_discount: False 129 | replay.prioritize_ends: False 130 | grad_heads: [decoder, reward] 131 | rssm: {hidden: 200, deter: 200} 132 | model_opt.lr: 3e-4 133 | actor_opt.lr: 8e-5 134 | critic_opt.lr: 8e-5 135 | actor_ent: 1e-4 136 | kl.free: 1.0 137 | 138 | dmc_proprio: 139 | 140 | task: dmc_walker_walk 141 | encoder: {mlp_keys: '.*', cnn_keys: '$^'} 142 | decoder: {mlp_keys: '.*', cnn_keys: '$^'} 143 | action_repeat: 2 144 | eval_every: 1e4 145 | prefill: 1000 146 | pretrain: 100 147 | clip_rewards: identity 148 | pred_discount: False 149 | replay.prioritize_ends: False 150 | grad_heads: [decoder, reward] 151 | rssm: {hidden: 200, deter: 200} 152 | model_opt.lr: 3e-4 153 | actor_opt.lr: 8e-5 154 | critic_opt.lr: 8e-5 155 | actor_ent: 1e-4 156 | kl.free: 1.0 157 | 158 | debug: 159 | 160 | jit: False 161 | time_limit: 100 162 | eval_every: 300 163 | log_every: 300 164 | prefill: 100 165 | pretrain: 1 166 | train_steps: 1 167 | replay: {minlen: 10, maxlen: 30} 168 | dataset: {batch: 10, length: 10} 169 | -------------------------------------------------------------------------------- /dreamerv2/expl.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow_probability import distributions as tfd 3 | 4 | import agent 5 | import common 6 | 7 | 8 | class Random(common.Module): 9 | 10 | def __init__(self, config, act_space, wm, tfstep, reward): 11 | self.config = config 12 | 
self.act_space = act_space 13 | 14 | def actor(self, feat): 15 | shape = feat.shape[:-1] + self.act_space.shape 16 | if self.config.actor.dist == 'onehot': 17 | return common.OneHotDist(tf.zeros(shape)) 18 | else: 19 | dist = tfd.Uniform(-tf.ones(shape), tf.ones(shape)) 20 | return tfd.Independent(dist, 1) 21 | 22 | def train(self, start, context, data): 23 | return None, {} 24 | 25 | 26 | class Plan2Explore(common.Module): 27 | 28 | def __init__(self, config, act_space, wm, tfstep, reward): 29 | self.config = config 30 | self.reward = reward 31 | self.wm = wm 32 | self.ac = agent.ActorCritic(config, act_space, tfstep) 33 | self.actor = self.ac.actor 34 | stoch_size = config.rssm.stoch 35 | if config.rssm.discrete: 36 | stoch_size *= config.rssm.discrete 37 | size = { 38 | 'embed': 32 * config.encoder.cnn_depth, 39 | 'stoch': stoch_size, 40 | 'deter': config.rssm.deter, 41 | 'feat': config.rssm.stoch + config.rssm.deter, 42 | }[self.config.disag_target] 43 | self._networks = [ 44 | common.MLP(size, **config.expl_head) 45 | for _ in range(config.disag_models)] 46 | self.opt = common.Optimizer('expl', **config.expl_opt) 47 | self.extr_rewnorm = common.StreamNorm(**self.config.expl_reward_norm) 48 | self.intr_rewnorm = common.StreamNorm(**self.config.expl_reward_norm) 49 | 50 | def train(self, start, context, data): 51 | metrics = {} 52 | stoch = start['stoch'] 53 | if self.config.rssm.discrete: 54 | stoch = tf.reshape( 55 | stoch, stoch.shape[:-2] + [stoch.shape[-2] * stoch.shape[-1]]) 56 | target = { 57 | 'embed': context['embed'], 58 | 'stoch': stoch, 59 | 'deter': start['deter'], 60 | 'feat': context['feat'], 61 | }[self.config.disag_target] 62 | inputs = context['feat'] 63 | if self.config.disag_action_cond: 64 | action = tf.cast(data['action'], inputs.dtype) 65 | inputs = tf.concat([inputs, action], -1) 66 | metrics.update(self._train_ensemble(inputs, target)) 67 | metrics.update(self.ac.train( 68 | self.wm, start, data['is_terminal'], self._intr_reward)) 69 | return None, metrics 70 | 71 | def _intr_reward(self, seq): 72 | inputs = seq['feat'] 73 | if self.config.disag_action_cond: 74 | action = tf.cast(seq['action'], inputs.dtype) 75 | inputs = tf.concat([inputs, action], -1) 76 | preds = [head(inputs).mode() for head in self._networks] 77 | disag = tf.tensor(preds).std(0).mean(-1) 78 | if self.config.disag_log: 79 | disag = tf.math.log(disag) 80 | reward = self.config.expl_intr_scale * self.intr_rewnorm(disag)[0] 81 | if self.config.expl_extr_scale: 82 | reward += self.config.expl_extr_scale * self.extr_rewnorm( 83 | self.reward(seq))[0] 84 | return reward 85 | 86 | def _train_ensemble(self, inputs, targets): 87 | if self.config.disag_offset: 88 | targets = targets[:, self.config.disag_offset:] 89 | inputs = inputs[:, :-self.config.disag_offset] 90 | targets = tf.stop_gradient(targets) 91 | inputs = tf.stop_gradient(inputs) 92 | with tf.GradientTape() as tape: 93 | preds = [head(inputs) for head in self._networks] 94 | loss = -sum([pred.log_prob(targets).mean() for pred in preds]) 95 | metrics = self.opt(tape, loss, self._networks) 96 | return metrics 97 | 98 | 99 | class ModelLoss(common.Module): 100 | 101 | def __init__(self, config, act_space, wm, tfstep, reward): 102 | self.config = config 103 | self.reward = reward 104 | self.wm = wm 105 | self.ac = agent.ActorCritic(config, act_space, tfstep) 106 | self.actor = self.ac.actor 107 | self.head = common.MLP([], **self.config.expl_head) 108 | self.opt = common.Optimizer('expl', **self.config.expl_opt) 109 | 110 | def
train(self, start, context, data): 111 | metrics = {} 112 | target = tf.cast(context[self.config.expl_model_loss], tf.float32) 113 | with tf.GradientTape() as tape: 114 | loss = -self.head(context['feat']).log_prob(target).mean() 115 | metrics.update(self.opt(tape, loss, self.head)) 116 | metrics.update(self.ac.train( 117 | self.wm, start, data['is_terminal'], self._intr_reward)) 118 | return None, metrics 119 | 120 | def _intr_reward(self, seq): 121 | reward = self.config.expl_intr_scale * self.head(seq['feat']).mode() 122 | if self.config.expl_extr_scale: 123 | reward += self.config.expl_extr_scale * self.reward(seq) 124 | return reward 125 | -------------------------------------------------------------------------------- /dreamerv2/train.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import functools 3 | import logging 4 | import os 5 | import pathlib 6 | import re 7 | import sys 8 | import warnings 9 | 10 | try: 11 | import rich.traceback 12 | rich.traceback.install() 13 | except ImportError: 14 | pass 15 | 16 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 17 | logging.getLogger().setLevel('ERROR') 18 | warnings.filterwarnings('ignore', '.*box bound precision lowered.*') 19 | 20 | sys.path.append(str(pathlib.Path(__file__).parent)) 21 | sys.path.append(str(pathlib.Path(__file__).parent.parent)) 22 | 23 | import numpy as np 24 | import ruamel.yaml as yaml 25 | 26 | import agent 27 | import common 28 | 29 | 30 | def main(): 31 | 32 | configs = yaml.safe_load(( 33 | pathlib.Path(sys.argv[0]).parent / 'configs.yaml').read_text()) 34 | parsed, remaining = common.Flags(configs=['defaults']).parse(known_only=True) 35 | config = common.Config(configs['defaults']) 36 | for name in parsed.configs: 37 | config = config.update(configs[name]) 38 | config = common.Flags(config).parse(remaining) 39 | 40 | logdir = pathlib.Path(config.logdir).expanduser() 41 | logdir.mkdir(parents=True, exist_ok=True) 42 | config.save(logdir / 'config.yaml') 43 | print(config, '\n') 44 | print('Logdir', logdir) 45 | 46 | import tensorflow as tf 47 | tf.config.experimental_run_functions_eagerly(not config.jit) 48 | message = 'No GPU found. To actually train on CPU remove this assert.' 
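# Fail fast when TensorFlow cannot see a GPU (the assert below can be removed
# to run on CPU, as the message says). Memory growth keeps TF from reserving
# all GPU memory up front, and precision 16 enables the mixed_float16 Keras
# policy that the default configs expect.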
49 | assert tf.config.experimental.list_physical_devices('GPU'), message 50 | for gpu in tf.config.experimental.list_physical_devices('GPU'): 51 | tf.config.experimental.set_memory_growth(gpu, True) 52 | assert config.precision in (16, 32), config.precision 53 | if config.precision == 16: 54 | from tensorflow.keras.mixed_precision import experimental as prec 55 | prec.set_policy(prec.Policy('mixed_float16')) 56 | 57 | train_replay = common.Replay(logdir / 'train_episodes', **config.replay) 58 | eval_replay = common.Replay(logdir / 'eval_episodes', **dict( 59 | capacity=config.replay.capacity // 10, 60 | minlen=config.dataset.length, 61 | maxlen=config.dataset.length)) 62 | step = common.Counter(train_replay.stats['total_steps']) 63 | outputs = [ 64 | common.TerminalOutput(), 65 | common.JSONLOutput(logdir), 66 | common.TensorBoardOutput(logdir), 67 | ] 68 | logger = common.Logger(step, outputs, multiplier=config.action_repeat) 69 | metrics = collections.defaultdict(list) 70 | 71 | should_train = common.Every(config.train_every) 72 | should_log = common.Every(config.log_every) 73 | should_video_train = common.Every(config.eval_every) 74 | should_video_eval = common.Every(config.eval_every) 75 | should_expl = common.Until(config.expl_until // config.action_repeat) 76 | 77 | def make_env(mode): 78 | suite, task = config.task.split('_', 1) 79 | if suite == 'dmc': 80 | env = common.DMC( 81 | task, config.action_repeat, config.render_size, config.dmc_camera) 82 | env = common.NormalizeAction(env) 83 | elif suite == 'atari': 84 | env = common.Atari( 85 | task, config.action_repeat, config.render_size, 86 | config.atari_grayscale) 87 | env = common.OneHotAction(env) 88 | elif suite == 'crafter': 89 | assert config.action_repeat == 1 90 | outdir = logdir / 'crafter' if mode == 'train' else None 91 | reward = bool(['noreward', 'reward'].index(task)) or mode == 'eval' 92 | env = common.Crafter(outdir, reward) 93 | env = common.OneHotAction(env) 94 | else: 95 | raise NotImplementedError(suite) 96 | env = common.TimeLimit(env, config.time_limit) 97 | return env 98 | 99 | def per_episode(ep, mode): 100 | length = len(ep['reward']) - 1 101 | score = float(ep['reward'].astype(np.float64).sum()) 102 | print(f'{mode.title()} episode has {length} steps and return {score:.1f}.') 103 | logger.scalar(f'{mode}_return', score) 104 | logger.scalar(f'{mode}_length', length) 105 | for key, value in ep.items(): 106 | if re.match(config.log_keys_sum, key): 107 | logger.scalar(f'sum_{mode}_{key}', ep[key].sum()) 108 | if re.match(config.log_keys_mean, key): 109 | logger.scalar(f'mean_{mode}_{key}', ep[key].mean()) 110 | if re.match(config.log_keys_max, key): 111 | logger.scalar(f'max_{mode}_{key}', ep[key].max(0).mean()) 112 | should = {'train': should_video_train, 'eval': should_video_eval}[mode] 113 | if should(step): 114 | for key in config.log_keys_video: 115 | logger.video(f'{mode}_policy_{key}', ep[key]) 116 | replay = dict(train=train_replay, eval=eval_replay)[mode] 117 | logger.add(replay.stats, prefix=mode) 118 | logger.write() 119 | 120 | print('Create envs.') 121 | num_eval_envs = min(config.envs, config.eval_eps) 122 | if config.envs_parallel == 'none': 123 | train_envs = [make_env('train') for _ in range(config.envs)] 124 | eval_envs = [make_env('eval') for _ in range(num_eval_envs)] 125 | else: 126 | make_async_env = lambda mode: common.Async( 127 | functools.partial(make_env, mode), config.envs_parallel) 128 | train_envs = [make_async_env('train') for _ in range(config.envs)] 129 | eval_envs = 
[make_async_env('eval') for _ in range(num_eval_envs)] 130 | act_space = train_envs[0].act_space 131 | obs_space = train_envs[0].obs_space 132 | train_driver = common.Driver(train_envs) 133 | train_driver.on_episode(lambda ep: per_episode(ep, mode='train')) 134 | train_driver.on_step(lambda tran, worker: step.increment()) 135 | train_driver.on_step(train_replay.add_step) 136 | train_driver.on_reset(train_replay.add_step) 137 | eval_driver = common.Driver(eval_envs) 138 | eval_driver.on_episode(lambda ep: per_episode(ep, mode='eval')) 139 | eval_driver.on_episode(eval_replay.add_episode) 140 | 141 | prefill = max(0, config.prefill - train_replay.stats['total_steps']) 142 | if prefill: 143 | print(f'Prefill dataset ({prefill} steps).') 144 | random_agent = common.RandomAgent(act_space) 145 | train_driver(random_agent, steps=prefill, episodes=1) 146 | eval_driver(random_agent, episodes=1) 147 | train_driver.reset() 148 | eval_driver.reset() 149 | 150 | print('Create agent.') 151 | train_dataset = iter(train_replay.dataset(**config.dataset)) 152 | report_dataset = iter(train_replay.dataset(**config.dataset)) 153 | eval_dataset = iter(eval_replay.dataset(**config.dataset)) 154 | agnt = agent.Agent(config, obs_space, act_space, step) 155 | train_agent = common.CarryOverState(agnt.train) 156 | train_agent(next(train_dataset)) 157 | if (logdir / 'variables.pkl').exists(): 158 | agnt.load(logdir / 'variables.pkl') 159 | else: 160 | print('Pretrain agent.') 161 | for _ in range(config.pretrain): 162 | train_agent(next(train_dataset)) 163 | train_policy = lambda *args: agnt.policy( 164 | *args, mode='explore' if should_expl(step) else 'train') 165 | eval_policy = lambda *args: agnt.policy(*args, mode='eval') 166 | 167 | def train_step(tran, worker): 168 | if should_train(step): 169 | for _ in range(config.train_steps): 170 | mets = train_agent(next(train_dataset)) 171 | [metrics[key].append(value) for key, value in mets.items()] 172 | if should_log(step): 173 | for name, values in metrics.items(): 174 | logger.scalar(name, np.array(values, np.float64).mean()) 175 | metrics[name].clear() 176 | logger.add(agnt.report(next(report_dataset)), prefix='train') 177 | logger.write(fps=True) 178 | train_driver.on_step(train_step) 179 | 180 | while step < config.steps: 181 | logger.write() 182 | print('Start evaluation.') 183 | logger.add(agnt.report(next(eval_dataset)), prefix='eval') 184 | eval_driver(eval_policy, episodes=config.eval_eps) 185 | print('Start training.') 186 | train_driver(train_policy, steps=config.eval_every) 187 | agnt.save(logdir / 'variables.pkl') 188 | for env in train_envs + eval_envs: 189 | try: 190 | env.close() 191 | except Exception: 192 | pass 193 | 194 | 195 | if __name__ == '__main__': 196 | main() 197 | -------------------------------------------------------------------------------- /examples/minigrid.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import gym_minigrid 3 | import dreamerv2.api as dv2 4 | 5 | config = dv2.defaults.update({ 6 | 'logdir': '~/logdir/minigrid', 7 | 'log_every': 1e3, 8 | 'train_every': 10, 9 | 'prefill': 1e5, 10 | 'actor_ent': 3e-3, 11 | 'loss_scales.kl': 1.0, 12 | 'discount': 0.99, 13 | }).parse_flags() 14 | 15 | env = gym.make('MiniGrid-DoorKey-6x6-v0') 16 | env = gym_minigrid.wrappers.RGBImgPartialObsWrapper(env) 17 | dv2.train(env, config) 18 | -------------------------------------------------------------------------------- /scores/baselines.json:
-------------------------------------------------------------------------------- 1 | {"atari_alien": {"c51_sticky_2e8": 2474.88, "ddqn_determ_2e8": 3747.7, "dqn_determ_2e8": 1620.0, "dqn_sticky_2e8": 2484.49, "dueling_determ_2e8": 4461.4, "human_gamer": 7127.7, "human_record": 251916.0, "iqn_sticky_2e8": 4961.42, "prioritized_determ_2e8": 4203.8, "rainbow_determ_2e8": 9491.7, "rainbow_determ_4e5": 739.9, "rainbow_sticky_2e8": 3456.95, "random": 228.8, "reactor_determ_2e8": 4199.4, "simple_determ_4e5": 405.2}, "atari_amidar": {"c51_sticky_2e8": 1363.18, "ddqn_determ_2e8": 1793.3, "dqn_determ_2e8": 978.0, "dqn_sticky_2e8": 1207.74, "dueling_determ_2e8": 2354.5, "human_gamer": 1719.5, "human_record": 104159.0, "iqn_sticky_2e8": 2393.43, "prioritized_determ_2e8": 1838.9, "rainbow_determ_2e8": 5131.2, "rainbow_determ_4e5": 188.6, "rainbow_sticky_2e8": 2529.09, "random": 5.8, "reactor_determ_2e8": 1546.8, "simple_determ_4e5": 88.0}, "atari_assault": {"c51_sticky_2e8": 1800.12, "ddqn_determ_2e8": 5393.2, "dqn_determ_2e8": 4280.4, "dqn_sticky_2e8": 1525.24, "dueling_determ_2e8": 4621.0, "human_gamer": 742.0, "human_record": 8647.0, "iqn_sticky_2e8": 4884.78, "prioritized_determ_2e8": 7672.1, "rainbow_determ_2e8": 14198.5, "rainbow_determ_4e5": 431.2, "rainbow_sticky_2e8": 3228.82, "random": 222.4, "reactor_determ_2e8": 17543.8, "simple_determ_4e5": 369.3}, "atari_asterix": {"c51_sticky_2e8": 13453.37, "ddqn_determ_2e8": 17356.5, "dqn_determ_2e8": 4359.0, "dqn_sticky_2e8": 2711.41, "dueling_determ_2e8": 28188.0, "human_gamer": 8503.3, "human_record": 1000000.0, "iqn_sticky_2e8": 10374.44, "prioritized_determ_2e8": 31527.0, "rainbow_determ_2e8": 428200.3, "rainbow_determ_4e5": 470.8, "rainbow_sticky_2e8": 18366.58, "random": 210.0, "reactor_determ_2e8": 16121.0, "simple_determ_4e5": 1089.5}, "atari_asteroids": {"c51_sticky_2e8": 1079.58, "ddqn_determ_2e8": 734.7, "dqn_determ_2e8": 1364.5, "dqn_sticky_2e8": 698.37, "dueling_determ_2e8": 2837.7, "human_gamer": 47388.7, "human_record": 10506650.0, "iqn_sticky_2e8": 1584.51, "prioritized_determ_2e8": 2654.3, "rainbow_determ_2e8": 2712.8, "rainbow_sticky_2e8": 1483.54, "random": 719.1, "reactor_determ_2e8": 4467.4}, "atari_atlantis": {"c51_sticky_2e8": 824994.31, "ddqn_determ_2e8": 106056.0, "dqn_determ_2e8": 279987.0, "dqn_sticky_2e8": 853640.0, "dueling_determ_2e8": 382572.0, "human_gamer": 29028.1, "human_record": 10604840.0, "iqn_sticky_2e8": 890214.0, "prioritized_determ_2e8": 357324.0, "rainbow_determ_2e8": 826659.5, "rainbow_sticky_2e8": 802548.0, "random": 12850.0, "reactor_determ_2e8": 968179.5}, "atari_bank_heist": {"c51_sticky_2e8": 757.65, "ddqn_determ_2e8": 1030.6, "dqn_determ_2e8": 455.0, "dqn_sticky_2e8": 601.79, "dueling_determ_2e8": 1611.9, "human_gamer": 753.1, "human_record": 82058.0, "iqn_sticky_2e8": 1052.27, "prioritized_determ_2e8": 1054.6, "rainbow_determ_2e8": 1358.0, "rainbow_determ_4e5": 51.0, "rainbow_sticky_2e8": 1074.97, "random": 14.2, "reactor_determ_2e8": 1236.8, "simple_determ_4e5": 8.2}, "atari_battle_zone": {"c51_sticky_2e8": 25501.76, "ddqn_determ_2e8": 31700.0, "dqn_determ_2e8": 29900.0, "dqn_sticky_2e8": 17784.84, "dueling_determ_2e8": 37150.0, "human_gamer": 37187.5, "human_record": 801000.0, "iqn_sticky_2e8": 40953.44, "prioritized_determ_2e8": 31530.0, "rainbow_determ_2e8": 62010.0, "rainbow_determ_4e5": 10124.6, "rainbow_sticky_2e8": 40060.59, "random": 2360.0, "reactor_determ_2e8": 98235.0, "simple_determ_4e5": 5184.4}, "atari_beam_rider": {"c51_sticky_2e8": 5331.59, "ddqn_determ_2e8": 13772.8, 
"dqn_determ_2e8": 8627.5, "dqn_sticky_2e8": 5852.42, "dueling_determ_2e8": 12164.0, "human_gamer": 16926.5, "human_record": 999999.0, "iqn_sticky_2e8": 7130.45, "prioritized_determ_2e8": 23384.2, "rainbow_determ_2e8": 16850.2, "rainbow_sticky_2e8": 6290.48, "random": 363.9, "reactor_determ_2e8": 8811.8}, "atari_berzerk": {"c51_sticky_2e8": 516.69, "ddqn_determ_2e8": 1225.4, "dqn_determ_2e8": 585.6, "dqn_sticky_2e8": 487.48, "dueling_determ_2e8": 1472.6, "human_gamer": 2630.4, "human_record": 1057940.0, "iqn_sticky_2e8": 648.4, "prioritized_determ_2e8": 1305.6, "rainbow_determ_2e8": 2545.6, "rainbow_sticky_2e8": 833.41, "random": 123.7, "reactor_determ_2e8": 1515.7}, "atari_bowling": {"c51_sticky_2e8": 32.1, "ddqn_determ_2e8": 68.1, "dqn_determ_2e8": 50.4, "dqn_sticky_2e8": 30.12, "dueling_determ_2e8": 65.5, "human_gamer": 160.7, "human_record": 300.0, "iqn_sticky_2e8": 39.44, "prioritized_determ_2e8": 47.9, "rainbow_determ_2e8": 30.0, "rainbow_sticky_2e8": 42.86, "random": 23.1, "reactor_determ_2e8": 59.3}, "atari_boxing": {"c51_sticky_2e8": 81.4, "ddqn_determ_2e8": 91.6, "dqn_determ_2e8": 88.0, "dqn_sticky_2e8": 77.99, "dueling_determ_2e8": 99.4, "human_gamer": 12.1, "human_record": 100.0, "iqn_sticky_2e8": 97.91, "prioritized_determ_2e8": 95.6, "rainbow_determ_2e8": 99.6, "rainbow_determ_4e5": 0.2, "rainbow_sticky_2e8": 98.56, "random": 0.1, "reactor_determ_2e8": 99.7, "simple_determ_4e5": 9.1}, "atari_breakout": {"c51_sticky_2e8": 202.39, "ddqn_determ_2e8": 418.5, "dqn_determ_2e8": 385.5, "dqn_sticky_2e8": 96.23, "dueling_determ_2e8": 345.3, "human_gamer": 30.5, "human_record": 864.0, "iqn_sticky_2e8": 79.48, "prioritized_determ_2e8": 373.9, "rainbow_determ_2e8": 417.5, "rainbow_determ_4e5": 1.9, "rainbow_sticky_2e8": 120.07, "random": 1.7, "reactor_determ_2e8": 509.5, "simple_determ_4e5": 12.7}, "atari_centipede": {"c51_sticky_2e8": 7594.53, "ddqn_determ_2e8": 5409.4, "dqn_determ_2e8": 4657.7, "dqn_sticky_2e8": 2583.03, "dueling_determ_2e8": 7561.4, "human_gamer": 12017.0, "human_record": 1301709.0, "iqn_sticky_2e8": 3727.64, "prioritized_determ_2e8": 4463.2, "rainbow_determ_2e8": 8167.3, "rainbow_sticky_2e8": 6509.87, "random": 2090.9, "reactor_determ_2e8": 7267.2}, "atari_chopper_command": {"c51_sticky_2e8": 5718.16, "ddqn_determ_2e8": 5809.0, "dqn_determ_2e8": 6126.0, "dqn_sticky_2e8": 2690.61, "dueling_determ_2e8": 11215.0, "human_gamer": 7387.8, "human_record": 999999.0, "iqn_sticky_2e8": 9281.96, "prioritized_determ_2e8": 8600.0, "rainbow_determ_2e8": 16654.0, "rainbow_determ_4e5": 861.8, "rainbow_sticky_2e8": 12337.51, "random": 811.0, "reactor_determ_2e8": 19901.5, "simple_determ_4e5": 1246.9}, "atari_crazy_climber": {"c51_sticky_2e8": 136041.14, "ddqn_determ_2e8": 117282.0, "dqn_determ_2e8": 110763.0, "dqn_sticky_2e8": 104568.76, "dueling_determ_2e8": 143570.0, "human_gamer": 35829.4, "human_record": 219900.0, "iqn_sticky_2e8": 132738.13, "prioritized_determ_2e8": 141161.0, "rainbow_determ_2e8": 168788.5, "rainbow_determ_4e5": 16185.3, "rainbow_sticky_2e8": 145389.29, "random": 10780.5, "reactor_determ_2e8": 173274.0, "simple_determ_4e5": 39827.8}, "atari_defender": {"ddqn_determ_2e8": 35338.5, "dqn_determ_2e8": 23633.0, "dueling_determ_2e8": 42214.0, "human_gamer": 18688.9, "human_record": 6010500.0, "prioritized_determ_2e8": 31286.5, "rainbow_determ_2e8": 55105.0, "random": 2874.5, "reactor_determ_2e8": 181074.3}, "atari_demon_attack": {"c51_sticky_2e8": 7098.0, "ddqn_determ_2e8": 58044.2, "dqn_determ_2e8": 12149.4, "dqn_sticky_2e8": 6361.58, "dueling_determ_2e8": 60813.3, 
"human_gamer": 1971.0, "human_record": 1556345.0, "iqn_sticky_2e8": 15350.43, "prioritized_determ_2e8": 71846.4, "rainbow_determ_2e8": 111185.2, "rainbow_determ_4e5": 508.0, "rainbow_sticky_2e8": 17071.26, "random": 152.1, "reactor_determ_2e8": 122782.5, "simple_determ_4e5": 169.5}, "atari_double_dunk": {"c51_sticky_2e8": 1.28, "ddqn_determ_2e8": -5.5, "dqn_determ_2e8": -6.6, "dqn_sticky_2e8": -6.54, "dueling_determ_2e8": 0.1, "human_gamer": -16.4, "human_record": 22.0, "iqn_sticky_2e8": 21.12, "prioritized_determ_2e8": 18.5, "rainbow_determ_2e8": -0.3, "rainbow_sticky_2e8": 22.12, "random": -18.6, "reactor_determ_2e8": 23.0}, "atari_enduro": {"c51_sticky_2e8": 1283.8, "ddqn_determ_2e8": 1211.8, "dqn_determ_2e8": 729.0, "dqn_sticky_2e8": 628.91, "dueling_determ_2e8": 2258.2, "human_gamer": 860.5, "human_record": 9500.0, "iqn_sticky_2e8": 2202.98, "prioritized_determ_2e8": 2093.0, "rainbow_determ_2e8": 2125.9, "rainbow_sticky_2e8": 2200.16, "random": 0.0, "reactor_determ_2e8": 2211.3}, "atari_fishing_derby": {"c51_sticky_2e8": 10.4, "ddqn_determ_2e8": 15.5, "dqn_determ_2e8": -4.9, "dqn_sticky_2e8": 0.58, "dueling_determ_2e8": 46.4, "human_gamer": -38.7, "human_record": 71.0, "iqn_sticky_2e8": 45.33, "prioritized_determ_2e8": 39.5, "rainbow_determ_2e8": 31.3, "rainbow_sticky_2e8": 41.8, "random": -91.7, "reactor_determ_2e8": 33.1}, "atari_freeway": {"c51_sticky_2e8": 33.68, "ddqn_determ_2e8": 33.3, "dqn_determ_2e8": 30.8, "dqn_sticky_2e8": 26.27, "dueling_determ_2e8": 0.0, "human_gamer": 29.6, "human_record": 38.0, "iqn_sticky_2e8": 33.67, "prioritized_determ_2e8": 33.7, "rainbow_determ_2e8": 34.0, "rainbow_determ_4e5": 27.9, "rainbow_sticky_2e8": 33.69, "random": 0.0, "reactor_determ_2e8": 22.3, "simple_determ_4e5": 20.3}, "atari_frostbite": {"c51_sticky_2e8": 4284.99, "ddqn_determ_2e8": 1683.3, "dqn_determ_2e8": 797.4, "dqn_sticky_2e8": 367.07, "dueling_determ_2e8": 4672.8, "human_gamer": 4334.7, "human_record": 454830.0, "iqn_sticky_2e8": 7811.57, "prioritized_determ_2e8": 4380.1, "rainbow_determ_2e8": 9590.5, "rainbow_determ_4e5": 866.8, "rainbow_sticky_2e8": 8207.72, "random": 65.2, "reactor_determ_2e8": 7136.7, "simple_determ_4e5": 254.7}, "atari_gopher": {"c51_sticky_2e8": 7172.88, "ddqn_determ_2e8": 14840.8, "dqn_determ_2e8": 8777.4, "dqn_sticky_2e8": 5479.9, "dueling_determ_2e8": 15718.4, "human_gamer": 2412.5, "human_record": 355040.0, "iqn_sticky_2e8": 12107.61, "prioritized_determ_2e8": 32487.2, "rainbow_determ_2e8": 70354.6, "rainbow_determ_4e5": 349.5, "rainbow_sticky_2e8": 10641.07, "random": 257.6, "reactor_determ_2e8": 36279.1, "simple_determ_4e5": 771.0}, "atari_gravitar": {"c51_sticky_2e8": 686.0, "ddqn_determ_2e8": 412.0, "dqn_determ_2e8": 473.0, "dqn_sticky_2e8": 330.07, "dueling_determ_2e8": 588.0, "human_gamer": 3351.4, "human_record": 162850.0, "iqn_sticky_2e8": 1346.62, "prioritized_determ_2e8": 548.5, "rainbow_determ_2e8": 1419.3, "rainbow_sticky_2e8": 1271.8, "random": 173.0, "reactor_determ_2e8": 1804.8}, "atari_hero": {"c51_sticky_2e8": 33753.23, "ddqn_determ_2e8": 20130.2, "dqn_determ_2e8": 20437.8, "dqn_sticky_2e8": 17325.44, "dueling_determ_2e8": 20818.2, "human_gamer": 30826.4, "human_record": 1000000.0, "iqn_sticky_2e8": 36058.41, "prioritized_determ_2e8": 23037.7, "rainbow_determ_2e8": 55887.4, "rainbow_determ_4e5": 6857.0, "rainbow_sticky_2e8": 46675.16, "random": 1027.0, "reactor_determ_2e8": 27833.0, "simple_determ_4e5": 1295.1}, "atari_ice_hockey": {"c51_sticky_2e8": -4.36, "ddqn_determ_2e8": -2.7, "dqn_determ_2e8": -1.9, "dqn_sticky_2e8": -5.84, 
"dueling_determ_2e8": 0.5, "human_gamer": 0.9, "human_record": 36.0, "iqn_sticky_2e8": -5.11, "prioritized_determ_2e8": 1.3, "rainbow_determ_2e8": 1.1, "rainbow_sticky_2e8": -0.18, "random": -11.2, "reactor_determ_2e8": 15.7}, "atari_james_bond": {"c51_sticky_2e8": 725.56, "ddqn_determ_2e8": 768.5, "dqn_determ_2e8": 302.8, "dqn_sticky_2e8": 573.31, "dueling_determ_2e8": 1358.0, "human_gamer": 29.0, "human_record": 45550.0, "iqn_sticky_2e8": 3165.7, "prioritized_determ_2e8": 1312.5, "rainbow_determ_2e8": 812.0, "rainbow_determ_4e5": 301.6, "rainbow_sticky_2e8": 1097.12, "random": 7.0, "reactor_determ_2e8": 19809.0, "simple_determ_4e5": 125.3}, "atari_kangaroo": {"c51_sticky_2e8": 7687.54, "ddqn_determ_2e8": 12992.0, "dqn_determ_2e8": 7259.0, "dqn_sticky_2e8": 11485.98, "dueling_determ_2e8": 14854.0, "human_gamer": 3035.0, "human_record": 1424600.0, "iqn_sticky_2e8": 12602.27, "prioritized_determ_2e8": 16200.0, "rainbow_determ_2e8": 14637.5, "rainbow_determ_4e5": 779.3, "rainbow_sticky_2e8": 12748.3, "random": 52.0, "reactor_determ_2e8": 13349.0, "simple_determ_4e5": 323.1}, "atari_krull": {"c51_sticky_2e8": 6770.45, "ddqn_determ_2e8": 7920.5, "dqn_determ_2e8": 8422.3, "dqn_sticky_2e8": 6097.63, "dueling_determ_2e8": 11451.9, "human_gamer": 2665.5, "human_record": 104100.0, "iqn_sticky_2e8": 8844.17, "prioritized_determ_2e8": 9728.0, "rainbow_determ_2e8": 8741.5, "rainbow_determ_4e5": 2851.5, "rainbow_sticky_2e8": 4065.97, "random": 1598.0, "reactor_determ_2e8": 10237.8, "simple_determ_4e5": 4539.9}, "atari_kung_fu_master": {"c51_sticky_2e8": 24811.6, "ddqn_determ_2e8": 29710.0, "dqn_determ_2e8": 26059.0, "dqn_sticky_2e8": 23435.38, "dueling_determ_2e8": 34294.0, "human_gamer": 22736.3, "human_record": 1000000.0, "iqn_sticky_2e8": 31653.37, "prioritized_determ_2e8": 39581.0, "rainbow_determ_2e8": 52181.0, "rainbow_determ_4e5": 14346.1, "rainbow_sticky_2e8": 26475.07, "random": 258.5, "reactor_determ_2e8": 61621.5, "simple_determ_4e5": 17257.2}, "atari_montezuma_revenge": {"c51_sticky_2e8": 1044.67, "ddqn_determ_2e8": 0.0, "dqn_determ_2e8": 0.0, "dqn_sticky_2e8": 0.0, "dueling_determ_2e8": 0.0, "human_gamer": 4753.3, "human_record": 1219200.0, "iqn_sticky_2e8": 500.0, "prioritized_determ_2e8": 0.0, "rainbow_determ_2e8": 384.0, "rainbow_sticky_2e8": 500.0, "random": 0.0, "reactor_determ_2e8": 0.0}, "atari_ms_pacman": {"c51_sticky_2e8": 3693.53, "ddqn_determ_2e8": 2711.4, "dqn_determ_2e8": 3085.6, "dqn_sticky_2e8": 3402.4, "dueling_determ_2e8": 6283.5, "human_gamer": 6951.6, "human_record": 290090.0, "iqn_sticky_2e8": 5217.99, "prioritized_determ_2e8": 6518.7, "rainbow_determ_2e8": 5380.4, "rainbow_determ_4e5": 1204.1, "rainbow_sticky_2e8": 3861.01, "random": 307.3, "reactor_determ_2e8": 4416.9, "simple_determ_4e5": 762.8}, "atari_name_this_game": {"c51_sticky_2e8": 12594.19, "ddqn_determ_2e8": 10616.0, "dqn_determ_2e8": 8207.8, "dqn_sticky_2e8": 7278.65, "dueling_determ_2e8": 11971.1, "human_gamer": 8049.0, "human_record": 25220.0, "iqn_sticky_2e8": 6638.61, "prioritized_determ_2e8": 12270.5, "rainbow_determ_2e8": 13136.0, "rainbow_sticky_2e8": 9025.75, "random": 2292.3, "reactor_determ_2e8": 12636.5}, "atari_phoenix": {"c51_sticky_2e8": 5849.25, "ddqn_determ_2e8": 12252.5, "dqn_determ_2e8": 8485.2, "dqn_sticky_2e8": 4996.58, "dueling_determ_2e8": 23092.2, "human_gamer": 7242.6, "human_record": 4014440.0, "iqn_sticky_2e8": 5102.03, "prioritized_determ_2e8": 18992.7, "rainbow_determ_2e8": 108528.6, "rainbow_sticky_2e8": 8545.35, "random": 761.4, "reactor_determ_2e8": 10261.4}, "atari_pitfall": 
{"c51_sticky_2e8": -18.39, "ddqn_determ_2e8": -29.9, "dqn_determ_2e8": -286.1, "dqn_sticky_2e8": -73.81, "dueling_determ_2e8": 0.0, "human_gamer": 6463.7, "human_record": 114000.0, "iqn_sticky_2e8": -13.16, "prioritized_determ_2e8": -356.5, "rainbow_determ_2e8": 0.0, "rainbow_sticky_2e8": -19.76, "random": -229.4, "reactor_determ_2e8": -3.7}, "atari_pong": {"c51_sticky_2e8": 19.34, "ddqn_determ_2e8": 20.9, "dqn_determ_2e8": 19.5, "dqn_sticky_2e8": 16.61, "dueling_determ_2e8": 21.0, "human_gamer": 14.6, "human_record": 21.0, "iqn_sticky_2e8": 20.13, "prioritized_determ_2e8": 20.6, "rainbow_determ_2e8": 20.9, "rainbow_determ_4e5": -19.3, "rainbow_sticky_2e8": 20.18, "random": -20.7, "reactor_determ_2e8": 20.7, "simple_determ_4e5": 5.2}, "atari_private_eye": {"c51_sticky_2e8": 4207.11, "ddqn_determ_2e8": 129.7, "dqn_determ_2e8": 146.7, "dqn_sticky_2e8": -16.04, "dueling_determ_2e8": 103.0, "human_gamer": 69571.3, "human_record": 101800.0, "iqn_sticky_2e8": 4180.76, "prioritized_determ_2e8": 200.0, "rainbow_determ_2e8": 4234.0, "rainbow_determ_4e5": 97.8, "rainbow_sticky_2e8": 21333.56, "random": 24.9, "reactor_determ_2e8": 15198.0, "simple_determ_4e5": 58.3}, "atari_qbert": {"c51_sticky_2e8": 9761.91, "ddqn_determ_2e8": 15088.5, "dqn_determ_2e8": 13117.3, "dqn_sticky_2e8": 10117.5, "dueling_determ_2e8": 19220.3, "human_gamer": 13455.0, "human_record": 2400000.0, "iqn_sticky_2e8": 16729.89, "prioritized_determ_2e8": 16256.5, "rainbow_determ_2e8": 33817.5, "rainbow_determ_4e5": 1152.9, "rainbow_sticky_2e8": 17382.94, "random": 163.9, "reactor_determ_2e8": 21222.5, "simple_determ_4e5": 559.8}, "atari_riverraid": {"c51_sticky_2e8": 13514.28, "ddqn_determ_2e8": 14884.5, "dqn_determ_2e8": 7377.6, "dqn_sticky_2e8": 11638.93, "dueling_determ_2e8": 21162.6, "human_gamer": 17118.0, "human_record": 1000000.0, "iqn_sticky_2e8": 15183.38, "prioritized_determ_2e8": 14522.3, "rainbow_determ_2e8": 22920.8, "rainbow_sticky_2e8": 20755.91, "random": 1338.5, "reactor_determ_2e8": 16957.3}, "atari_road_runner": {"c51_sticky_2e8": 48213.46, "ddqn_determ_2e8": 44127.0, "dqn_determ_2e8": 39544.0, "dqn_sticky_2e8": 36925.47, "dueling_determ_2e8": 69524.0, "human_gamer": 7845.0, "human_record": 2038100.0, "iqn_sticky_2e8": 58965.95, "prioritized_determ_2e8": 57608.0, "rainbow_determ_2e8": 62041.0, "rainbow_determ_4e5": 9600.0, "rainbow_sticky_2e8": 54662.07, "random": 11.5, "reactor_determ_2e8": 66790.5, "simple_determ_4e5": 5169.4}, "atari_robotank": {"c51_sticky_2e8": 60.53, "ddqn_determ_2e8": 65.1, "dqn_determ_2e8": 63.9, "dqn_sticky_2e8": 59.77, "dueling_determ_2e8": 65.3, "human_gamer": 11.9, "human_record": 76.0, "iqn_sticky_2e8": 65.72, "prioritized_determ_2e8": 62.6, "rainbow_determ_2e8": 61.4, "rainbow_sticky_2e8": 65.52, "random": 2.2, "reactor_determ_2e8": 71.8}, "atari_seaquest": {"c51_sticky_2e8": 31025.85, "ddqn_determ_2e8": 16452.7, "dqn_determ_2e8": 5860.6, "dqn_sticky_2e8": 1600.66, "dueling_determ_2e8": 50254.2, "human_gamer": 42054.7, "human_record": 999999.0, "iqn_sticky_2e8": 17039.16, "prioritized_determ_2e8": 26357.8, "rainbow_determ_2e8": 15898.9, "rainbow_determ_4e5": 354.1, "rainbow_sticky_2e8": 9903.38, "random": 68.4, "reactor_determ_2e8": 5071.6, "simple_determ_4e5": 370.9}, "atari_skiing": {"c51_sticky_2e8": -22231.09, "ddqn_determ_2e8": -9021.8, "dqn_determ_2e8": -13062.3, "dqn_sticky_2e8": -15824.61, "dueling_determ_2e8": -8857.4, "human_gamer": -4336.9, "human_record": -3272.0, "iqn_sticky_2e8": -11161.85, "prioritized_determ_2e8": -9996.9, "rainbow_determ_2e8": -12957.8, 
"rainbow_sticky_2e8": -28707.56, "random": -17098.1, "reactor_determ_2e8": -10632.9}, "atari_solaris": {"c51_sticky_2e8": 1961.66, "ddqn_determ_2e8": 3067.8, "dqn_determ_2e8": 3482.8, "dqn_sticky_2e8": 1436.4, "dueling_determ_2e8": 2250.8, "human_gamer": 12326.7, "human_record": 111420.0, "iqn_sticky_2e8": 1684.4, "prioritized_determ_2e8": 4309.0, "rainbow_determ_2e8": 3560.3, "rainbow_sticky_2e8": 1582.69, "random": 1236.3, "reactor_determ_2e8": 2236.0}, "atari_space_invaders": {"c51_sticky_2e8": 3983.6, "ddqn_determ_2e8": 2525.5, "dqn_determ_2e8": 1692.3, "dqn_sticky_2e8": 1794.24, "dueling_determ_2e8": 6427.3, "human_gamer": 1668.7, "human_record": 621535.0, "iqn_sticky_2e8": 4530.03, "prioritized_determ_2e8": 2865.8, "rainbow_determ_2e8": 18789.0, "rainbow_sticky_2e8": 4130.88, "random": 148.0, "reactor_determ_2e8": 2387.1}, "atari_star_gunner": {"c51_sticky_2e8": 34804.97, "ddqn_determ_2e8": 60142.0, "dqn_determ_2e8": 54282.0, "dqn_sticky_2e8": 42165.22, "dueling_determ_2e8": 89238.0, "human_gamer": 10250.0, "human_record": 77400.0, "iqn_sticky_2e8": 80002.58, "prioritized_determ_2e8": 63302.0, "rainbow_determ_2e8": 127029.0, "rainbow_sticky_2e8": 57908.72, "random": 664.0, "reactor_determ_2e8": 48942.0}, "atari_surround": {"ddqn_determ_2e8": -2.9, "dqn_determ_2e8": -5.6, "dueling_determ_2e8": 4.4, "human_gamer": 6.5, "prioritized_determ_2e8": 8.9, "rainbow_determ_2e8": 9.7, "random": -10.0, "reactor_determ_2e8": 0.9}, "atari_tennis": {"c51_sticky_2e8": 21.7, "ddqn_determ_2e8": -22.8, "dqn_determ_2e8": 12.2, "dqn_sticky_2e8": -1.5, "dueling_determ_2e8": 5.1, "human_gamer": -8.3, "human_record": 21.0, "iqn_sticky_2e8": 22.54, "prioritized_determ_2e8": 0.0, "rainbow_determ_2e8": 0.0, "rainbow_sticky_2e8": -0.22, "random": -23.8, "reactor_determ_2e8": 23.4}, "atari_time_pilot": {"c51_sticky_2e8": 8117.05, "ddqn_determ_2e8": 8339.0, "dqn_determ_2e8": 4870.0, "dqn_sticky_2e8": 3654.37, "dueling_determ_2e8": 11666.0, "human_gamer": 5229.2, "human_record": 65300.0, "iqn_sticky_2e8": 11666.23, "prioritized_determ_2e8": 9197.0, "rainbow_determ_2e8": 12926.0, "rainbow_sticky_2e8": 12050.53, "random": 3568.0, "reactor_determ_2e8": 18871.5}, "atari_tutankham": {"c51_sticky_2e8": 240.72, "ddqn_determ_2e8": 218.4, "dqn_determ_2e8": 68.1, "dqn_sticky_2e8": 103.84, "dueling_determ_2e8": 211.4, "human_gamer": 167.6, "human_record": 5384.0, "iqn_sticky_2e8": 251.41, "prioritized_determ_2e8": 204.6, "rainbow_determ_2e8": 241.0, "rainbow_sticky_2e8": 239.07, "random": 11.4, "reactor_determ_2e8": 263.2}, "atari_up_n_down": {"c51_sticky_2e8": 8279.26, "ddqn_determ_2e8": 22972.2, "dqn_determ_2e8": 9989.9, "dqn_sticky_2e8": 8488.31, "dueling_determ_2e8": 44939.6, "human_gamer": 11693.2, "human_record": 82840.0, "iqn_sticky_2e8": 59943.87, "prioritized_determ_2e8": 16154.1, "rainbow_determ_2e8": 125754.6, "rainbow_determ_4e5": 2877.4, "rainbow_sticky_2e8": 34887.76, "random": 533.4, "reactor_determ_2e8": 194989.5, "simple_determ_4e5": 2152.6}, "atari_venture": {"c51_sticky_2e8": 1339.41, "ddqn_determ_2e8": 98.0, "dqn_determ_2e8": 163.0, "dqn_sticky_2e8": 39.13, "dueling_determ_2e8": 497.0, "human_gamer": 1187.5, "human_record": 38900.0, "iqn_sticky_2e8": 1312.84, "prioritized_determ_2e8": 54.0, "rainbow_determ_2e8": 5.5, "rainbow_sticky_2e8": 1528.87, "random": 0.0, "reactor_determ_2e8": 0.0}, "atari_video_pinball": {"c51_sticky_2e8": 405173.28, "ddqn_determ_2e8": 309941.9, "dqn_determ_2e8": 196760.4, "dqn_sticky_2e8": 63406.11, "dueling_determ_2e8": 98209.5, "human_gamer": 17667.9, "human_record": 
89218328.0, "iqn_sticky_2e8": 415833.17, "prioritized_determ_2e8": 282007.3, "rainbow_determ_2e8": 533936.5, "rainbow_sticky_2e8": 466894.97, "random": 16256.9, "reactor_determ_2e8": 261720.2}, "atari_wizard_of_wor": {"c51_sticky_2e8": 3229.23, "ddqn_determ_2e8": 7492.0, "dqn_determ_2e8": 2704.0, "dqn_sticky_2e8": 2065.8, "dueling_determ_2e8": 7855.0, "human_gamer": 4756.5, "human_record": 395300.0, "iqn_sticky_2e8": 5670.54, "prioritized_determ_2e8": 4802.0, "rainbow_determ_2e8": 17862.5, "rainbow_sticky_2e8": 7878.62, "random": 563.5, "reactor_determ_2e8": 18484.0}, "atari_yars_revenge": {"c51_sticky_2e8": 11746.61, "ddqn_determ_2e8": 11712.6, "dqn_determ_2e8": 18098.9, "dqn_sticky_2e8": 23909.38, "dueling_determ_2e8": 49622.1, "human_gamer": 54576.9, "human_record": 15000105.0, "iqn_sticky_2e8": 84144.01, "prioritized_determ_2e8": 11357.0, "rainbow_determ_2e8": 102557.0, "rainbow_sticky_2e8": 45542.03, "random": 3092.9, "reactor_determ_2e8": 109607.5}, "atari_zaxxon": {"c51_sticky_2e8": 6502.33, "ddqn_determ_2e8": 10163.0, "dqn_determ_2e8": 5363.0, "dqn_sticky_2e8": 4538.57, "dueling_determ_2e8": 12944.0, "human_gamer": 9173.3, "human_record": 83700.0, "iqn_sticky_2e8": 11022.93, "prioritized_determ_2e8": 10469.0, "rainbow_determ_2e8": 22209.5, "rainbow_sticky_2e8": 14603.02, "random": 32.5, "reactor_determ_2e8": 16525.0}, "atari_air_raid": {"c51_sticky_2e8": 8288.24, "dqn_sticky_2e8": 7479.5, "human_record": 23050.0, "iqn_sticky_2e8": 10658.59, "rainbow_sticky_2e8": 11668.75, "random": 579.25}, "atari_carnival": {"c51_sticky_2e8": 4809.81, "dqn_sticky_2e8": 4784.84, "human_record": 2541440.0, "iqn_sticky_2e8": 5616.85, "rainbow_sticky_2e8": 4533.1, "random": 700.8}, "atari_elevator_action": {"c51_sticky_2e8": 68332.0, "dqn_sticky_2e8": 439.77, "human_record": 156550.0, "iqn_sticky_2e8": 65484.0, "rainbow_sticky_2e8": 76652.0, "random": 4387.0}, "atari_journey_escape": {"c51_sticky_2e8": -2707.38, "dqn_sticky_2e8": -3671.09, "human_record": -4317804.0, "iqn_sticky_2e8": -1490.44, "rainbow_sticky_2e8": -1106.49, "random": -19977.0}, "atari_pooyan": {"c51_sticky_2e8": 2931.13, "dqn_sticky_2e8": 3211.96, "human_record": 13025.0, "iqn_sticky_2e8": 4595.22, "rainbow_sticky_2e8": 4494.43, "random": 371.2}} 2 | -------------------------------------------------------------------------------- /scores/montezuma-dreamerv2.json: -------------------------------------------------------------------------------- 1 | [{"task": "atari_montezuma_revenge", "method": "dreamerv2", "seed": "0", "xs": [1000000.0, 2000000.0, 3000000.0, 4000000.0, 5000000.0, 6000000.0, 7000000.0, 8000000.0, 9000000.0, 10000000.0, 11000000.0, 12000000.0, 13000000.0, 14000000.0, 15000000.0, 16000000.0, 17000000.0, 18000000.0, 19000000.0, 20000000.0, 21000000.0, 22000000.0, 23000000.0, 24000000.0, 25000000.0, 26000000.0, 27000000.0, 28000000.0, 29000000.0, 30000000.0, 31000000.0, 32000000.0, 33000000.0, 34000000.0, 35000000.0, 36000000.0, 37000000.0, 38000000.0, 39000000.0, 40000000.0, 41000000.0, 42000000.0, 43000000.0, 44000000.0, 45000000.0, 46000000.0, 47000000.0, 48000000.0, 49000000.0, 50000000.0, 51000000.0, 52000000.0, 53000000.0, 54000000.0, 55000000.0, 56000000.0, 57000000.0, 58000000.0, 59000000.0, 60000000.0, 61000000.0, 62000000.0, 63000000.0, 64000000.0, 65000000.0, 66000000.0, 67000000.0, 68000000.0, 69000000.0, 70000000.0, 71000000.0, 72000000.0, 73000000.0, 74000000.0, 75000000.0, 76000000.0, 77000000.0, 78000000.0, 79000000.0, 80000000.0, 81000000.0, 82000000.0, 83000000.0, 84000000.0, 85000000.0, 86000000.0, 
87000000.0, 88000000.0, 89000000.0, 90000000.0, 91000000.0, 92000000.0, 93000000.0, 94000000.0, 95000000.0, 96000000.0, 97000000.0, 98000000.0, 99000000.0, 100000000.0, 101000000.0, 102000000.0, 103000000.0, 104000000.0, 105000000.0, 106000000.0, 107000000.0, 108000000.0, 109000000.0, 110000000.0, 111000000.0, 112000000.0, 113000000.0, 114000000.0, 115000000.0, 116000000.0, 117000000.0, 118000000.0, 119000000.0, 120000000.0, 121000000.0, 122000000.0, 123000000.0, 124000000.0, 125000000.0, 126000000.0, 127000000.0, 128000000.0, 129000000.0, 130000000.0, 131000000.0, 132000000.0, 133000000.0, 134000000.0, 135000000.0, 136000000.0, 137000000.0, 138000000.0, 139000000.0, 140000000.0, 141000000.0, 142000000.0, 143000000.0, 144000000.0, 145000000.0, 146000000.0, 147000000.0, 148000000.0, 149000000.0, 150000000.0, 151000000.0, 152000000.0, 153000000.0, 154000000.0, 155000000.0, 156000000.0, 157000000.0, 158000000.0, 159000000.0, 160000000.0, 161000000.0, 162000000.0, 163000000.0, 164000000.0, 165000000.0, 166000000.0, 167000000.0, 168000000.0, 169000000.0, 170000000.0, 171000000.0, 172000000.0, 173000000.0, 174000000.0, 175000000.0, 176000000.0, 177000000.0, 178000000.0, 179000000.0, 180000000.0, 181000000.0, 182000000.0, 183000000.0, 184000000.0, 185000000.0, 186000000.0, 187000000.0, 188000000.0, 189000000.0, 190000000.0, 191000000.0, 192000000.0, 193000000.0, 194000000.0, 195000000.0, 196000000.0, 197000000.0, 198000000.0, 199000000.0], "ys": [0.0, 0.0, 20.0, 60.0, 160.0, 336.3636363636364, 400.0, 400.0, 400.0, 400.0, 400.0, 460.0, 500.0, 520.0, 400.0, 400.0, 420.0, 400.0, 430.0, 450.0, 490.0, 430.0, 560.0, 490.0, 460.0, 430.0, 410.0, 408.3333333333333, 520.0, 440.0, 440.0, 430.0, 430.0, 430.0, 460.0, 460.0, 460.0, 490.0, 450.0, 460.0, 430.0, 440.0, 390.0, 450.0, 400.0, 400.0, 400.0, 440.0, 490.0, 430.0, 500.0, 500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2250.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2300.0, 2500.0, 2520.0, 2300.0, 2520.0, 2520.0, 2580.0, 2560.0, 2580.0, 2540.0, 2540.0, 2550.0, 2560.0, 2520.0, 2560.0, 2520.0, 1660.0, 2270.0, 2580.0, 2320.0, 2540.0, 2540.0, 2520.0, 2250.0, 2280.0, 2250.0, 2570.0, 2540.0, 1930.0, 2520.0, 2520.0, 2520.0, 2530.0, 2550.0, 2550.0, 2540.0, 2500.0, 2500.0, 2540.0, 2530.0, 2570.0, 2540.0, 2350.0, 2550.0, 2570.0, 2540.0, 2530.0, 2580.0, 2530.0, 2590.0, 2580.0, 2590.0, 2590.0, 2600.0, 2540.0, 2560.0, 2590.0, 2570.0, 2550.0, 2140.0, 1560.0, 2590.0, 2290.0, 1930.0, 2550.0, 2550.0, 2280.0, 2520.0, 2530.0, 2530.0, 2520.0, 2530.0, 2570.0, 2560.0, 2590.0, 2580.0, 2520.0, 2550.0, 2590.0, 2570.0, 2560.0, 2590.0, 2560.0, 2120.0, 2390.0, 2600.0, 2310.0, 2500.0, 2500.0, 1720.0, 2500.0, 2510.0, 2500.0, 2300.0, 2580.0, 2570.0, 2570.0, 2580.0, 2570.0, 2560.0, 2570.0, 2520.0, 2500.0, 2530.0, 2560.0, 2530.0, 2580.0, 2590.0, 2600.0, 2540.0, 2590.0, 2600.0, 2560.0, 2550.0, 2590.0, 2150.0, 2540.0, 2580.0, 2530.0, 2590.0, 2600.0, 2600.0, 2590.0, 2580.0, 2570.0, 2500.0, 2530.0, 2580.0, 2130.0]}, {"task": "atari_montezuma_revenge", "method": "dreamerv2", "seed": "1", "xs": [1000000.0, 2000000.0, 3000000.0, 4000000.0, 5000000.0, 6000000.0, 7000000.0, 8000000.0, 9000000.0, 10000000.0, 11000000.0, 12000000.0, 13000000.0, 14000000.0, 15000000.0, 16000000.0, 17000000.0, 18000000.0, 19000000.0, 20000000.0, 21000000.0, 22000000.0, 23000000.0, 24000000.0, 25000000.0, 26000000.0, 27000000.0, 28000000.0, 29000000.0, 30000000.0, 31000000.0, 32000000.0, 33000000.0, 34000000.0, 35000000.0, 36000000.0, 37000000.0, 38000000.0, 
39000000.0, 40000000.0, 41000000.0, 42000000.0, 43000000.0, 44000000.0, 45000000.0, 46000000.0, 47000000.0, 48000000.0, 49000000.0, 50000000.0, 51000000.0, 52000000.0, 53000000.0, 54000000.0, 55000000.0, 56000000.0, 57000000.0, 58000000.0, 59000000.0, 60000000.0, 61000000.0, 62000000.0, 63000000.0, 64000000.0, 65000000.0, 66000000.0, 67000000.0, 68000000.0, 69000000.0, 70000000.0, 71000000.0, 72000000.0, 73000000.0, 74000000.0, 75000000.0, 76000000.0, 77000000.0, 78000000.0, 79000000.0, 80000000.0, 81000000.0, 82000000.0, 83000000.0, 84000000.0, 85000000.0, 86000000.0, 87000000.0, 88000000.0, 89000000.0, 90000000.0, 91000000.0, 92000000.0, 93000000.0, 94000000.0, 95000000.0, 96000000.0, 97000000.0, 98000000.0, 99000000.0, 100000000.0, 101000000.0, 102000000.0, 103000000.0, 104000000.0, 105000000.0, 106000000.0, 107000000.0, 108000000.0, 109000000.0, 110000000.0, 111000000.0, 112000000.0, 113000000.0, 114000000.0, 115000000.0, 116000000.0, 117000000.0, 118000000.0, 119000000.0, 120000000.0, 121000000.0, 122000000.0, 123000000.0, 124000000.0, 125000000.0, 126000000.0, 127000000.0, 128000000.0, 129000000.0, 130000000.0, 131000000.0, 132000000.0, 133000000.0, 134000000.0, 135000000.0, 136000000.0, 137000000.0, 138000000.0, 139000000.0, 140000000.0, 141000000.0, 142000000.0, 143000000.0, 144000000.0, 145000000.0, 146000000.0, 147000000.0, 148000000.0, 149000000.0, 150000000.0, 151000000.0, 152000000.0, 153000000.0, 154000000.0, 155000000.0, 156000000.0, 157000000.0, 158000000.0, 159000000.0, 160000000.0, 161000000.0, 162000000.0, 163000000.0, 164000000.0, 165000000.0, 166000000.0, 167000000.0, 168000000.0, 169000000.0, 170000000.0, 171000000.0, 172000000.0, 173000000.0, 174000000.0, 175000000.0, 176000000.0, 177000000.0, 178000000.0, 179000000.0, 180000000.0, 181000000.0, 182000000.0, 183000000.0, 184000000.0, 185000000.0, 186000000.0, 187000000.0, 188000000.0, 189000000.0, 190000000.0, 191000000.0, 192000000.0, 193000000.0, 194000000.0, 195000000.0, 196000000.0, 197000000.0, 198000000.0, 199000000.0], "ys": [0.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 40.0, 90.0, 160.0, 180.0, 130.0, 260.0, 400.0, 400.0, 440.0, 420.0, 430.0, 400.0, 400.0, 450.0, 470.0, 450.0, 470.0, 430.0, 400.0, 400.0, 400.0, 420.0, 410.0, 440.0, 420.0, 410.0, 430.0, 410.0, 460.0, 500.0, 500.0, 460.0, 450.0, 440.0, 470.0, 390.0, 500.0, 500.0, 500.0, 2100.0, 1870.0, 460.0, 1080.0, 2300.0, 1941.6666666666667, 2080.0, 2500.0, 2500.0, 2290.0, 2560.0, 2380.0, 2590.0, 2560.0, 2530.0, 2600.0, 2580.0, 2560.0, 2540.0, 2500.0, 2530.0, 2560.0, 2590.0, 1970.0, 2580.0, 2560.0, 2560.0, 2600.0, 2560.0, 2600.0, 2570.0, 2570.0, 2600.0, 2580.0, 2530.0, 2560.0, 2540.0, 2560.0, 2560.0, 2560.0, 2550.0, 2560.0, 2560.0, 2580.0, 2570.0, 2580.0, 2590.0, 2580.0, 2600.0, 2590.0, 2600.0, 2600.0, 2580.0, 2600.0, 2570.0, 2520.0, 2550.0, 2600.0, 2580.0, 2600.0, 2600.0, 2590.0, 2600.0, 2600.0, 2600.0, 2600.0, 2570.0, 2140.0, 2500.0, 2580.0, 2590.0, 2600.0, 2590.0, 2540.0, 2500.0, 2560.0, 2600.0, 2590.0, 2580.0, 2520.0, 2510.0, 2520.0, 2590.0, 2290.0, 2510.0, 2500.0, 2520.0, 2510.0, 2500.0, 2540.0, 2570.0, 2600.0, 2600.0, 2600.0, 2500.0, 2600.0, 2600.0, 2590.0, 2600.0, 2600.0, 2100.0, 2530.0, 2570.0, 2530.0, 2310.0, 2590.0, 2600.0, 2590.0, 2540.0, 2600.0, 2580.0, 2520.0, 2580.0, 2510.0, 2550.0, 2500.0, 2510.0, 2510.0, 2360.0, 2590.0, 2600.0, 1810.0, 2380.0, 2600.0, 2530.0, 2540.0, 2600.0, 2590.0, 2540.0, 2590.0, 2600.0, 2560.0, 2530.0, 2510.0, 2340.0, 2500.0, 2520.0, 2520.0]}, {"task": 
"atari_montezuma_revenge", "method": "dreamerv2", "seed": "2", "xs": [1000000.0, 2000000.0, 3000000.0, 4000000.0, 5000000.0, 6000000.0, 7000000.0, 8000000.0, 9000000.0, 10000000.0, 11000000.0, 12000000.0, 13000000.0, 14000000.0, 15000000.0, 16000000.0, 17000000.0, 18000000.0, 19000000.0, 20000000.0, 21000000.0, 22000000.0, 23000000.0, 24000000.0, 25000000.0, 26000000.0, 27000000.0, 28000000.0, 29000000.0, 30000000.0, 31000000.0, 32000000.0, 33000000.0, 34000000.0, 35000000.0, 36000000.0, 37000000.0, 38000000.0, 39000000.0, 40000000.0, 41000000.0, 42000000.0, 43000000.0, 44000000.0, 45000000.0, 46000000.0, 47000000.0, 48000000.0, 49000000.0, 50000000.0, 51000000.0, 52000000.0, 53000000.0, 54000000.0, 55000000.0, 56000000.0, 57000000.0, 58000000.0, 59000000.0, 60000000.0, 61000000.0, 62000000.0, 63000000.0, 64000000.0, 65000000.0, 66000000.0, 67000000.0, 68000000.0, 69000000.0, 70000000.0, 71000000.0, 72000000.0, 73000000.0, 74000000.0, 75000000.0, 76000000.0, 77000000.0, 78000000.0, 79000000.0, 80000000.0, 81000000.0, 82000000.0, 83000000.0, 84000000.0, 85000000.0, 86000000.0, 87000000.0, 88000000.0, 89000000.0, 90000000.0, 91000000.0, 92000000.0, 93000000.0, 94000000.0, 95000000.0, 96000000.0, 97000000.0, 98000000.0, 99000000.0, 100000000.0, 101000000.0, 102000000.0, 103000000.0, 104000000.0, 105000000.0, 106000000.0, 107000000.0, 108000000.0, 109000000.0, 110000000.0, 111000000.0, 112000000.0, 113000000.0, 114000000.0, 115000000.0, 116000000.0, 117000000.0, 118000000.0, 119000000.0, 120000000.0, 121000000.0, 122000000.0, 123000000.0, 124000000.0, 125000000.0, 126000000.0, 127000000.0, 128000000.0, 129000000.0, 130000000.0, 131000000.0, 132000000.0, 133000000.0, 134000000.0, 135000000.0, 136000000.0, 137000000.0, 138000000.0, 139000000.0, 140000000.0, 141000000.0, 142000000.0, 143000000.0, 144000000.0, 145000000.0, 146000000.0, 147000000.0, 148000000.0, 149000000.0, 150000000.0, 151000000.0, 152000000.0, 153000000.0, 154000000.0, 155000000.0, 156000000.0, 157000000.0, 158000000.0, 159000000.0, 160000000.0, 161000000.0, 162000000.0, 163000000.0, 164000000.0, 165000000.0, 166000000.0, 167000000.0, 168000000.0, 169000000.0, 170000000.0, 171000000.0, 172000000.0, 173000000.0, 174000000.0, 175000000.0, 176000000.0, 177000000.0, 178000000.0, 179000000.0, 180000000.0, 181000000.0, 182000000.0, 183000000.0, 184000000.0, 185000000.0, 186000000.0, 187000000.0, 188000000.0, 189000000.0, 190000000.0, 191000000.0, 192000000.0, 193000000.0, 194000000.0, 195000000.0, 196000000.0, 197000000.0, 198000000.0, 199000000.0], "ys": [0.0, 0.0, 0.0, 20.0, 120.0, 310.0, 400.0, 400.0, 400.0, 400.0, 400.0, 400.0, 400.0, 430.0, 400.0, 490.0, 460.0, 470.0, 460.0, 430.0, 430.0, 490.0, 420.0, 470.0, 470.0, 440.0, 440.0, 430.0, 360.0, 400.0, 400.0, 450.0, 420.0, 430.0, 480.0, 480.0, 500.0, 480.0, 490.0, 450.0, 480.0, 500.0, 490.0, 460.0, 470.0, 440.0, 450.0, 410.0, 410.0, 410.0, 400.0, 420.0, 410.0, 400.0, 430.0, 410.0, 460.0, 430.0, 460.0, 450.0, 410.0, 410.0, 400.0, 410.0, 420.0, 430.0, 400.0, 400.0, 400.0, 366.6666666666667, 410.0, 400.0, 410.0, 430.0, 410.0, 440.0, 420.0, 400.0, 440.0, 430.0, 430.0, 440.0, 450.0, 500.0, 500.0, 500.0, 500.0, 490.0, 490.0, 500.0, 650.0, 1280.0, 2300.0, 1700.0, 2300.0, 2320.0, 1920.0, 2360.0, 2310.0, 2310.0, 2510.0, 2500.0, 2500.0, 2520.0, 2500.0, 2500.0, 2530.0, 2500.0, 2500.0, 2530.0, 2550.0, 2540.0, 2540.0, 2530.0, 2500.0, 2550.0, 2550.0, 2550.0, 2530.0, 2540.0, 2530.0, 2530.0, 2560.0, 2560.0, 2500.0, 2550.0, 2550.0, 2570.0, 2570.0, 2500.0, 2550.0, 2320.0, 2570.0, 2580.0, 2530.0, 
2550.0, 2530.0, 2540.0, 2500.0, 2130.0, 2510.0, 2510.0, 2500.0, 2500.0, 2520.0, 2500.0, 2510.0, 2500.0, 2510.0, 2570.0, 2520.0, 2530.0, 2520.0, 2580.0, 2580.0, 2590.0, 2540.0, 2500.0, 2520.0, 2560.0, 2530.0, 2570.0, 2530.0, 1940.0, 2590.0, 2560.0, 2520.0, 2510.0, 2510.0, 2540.0, 2510.0, 2500.0, 2520.0, 2120.0, 2120.0, 2560.0, 2540.0, 2550.0, 2500.0, 2500.0, 2540.0, 2520.0, 2520.0, 2510.0, 2100.0, 2500.0, 2500.0, 2510.0, 2510.0, 2520.0, 2520.0, 2500.0, 2500.0, 2500.0, 2530.0, 2520.0, 2500.0, 2510.0, 1900.0]}, {"task": "atari_montezuma_revenge", "method": "dreamerv2", "seed": "3", "xs": [1000000.0, 2000000.0, 3000000.0, 4000000.0, 5000000.0, 6000000.0, 7000000.0, 8000000.0, 9000000.0, 10000000.0, 11000000.0, 12000000.0, 13000000.0, 14000000.0, 15000000.0, 16000000.0, 17000000.0, 18000000.0, 19000000.0, 20000000.0, 21000000.0, 22000000.0, 23000000.0, 24000000.0, 25000000.0, 26000000.0, 27000000.0, 28000000.0, 29000000.0, 30000000.0, 31000000.0, 32000000.0, 33000000.0, 34000000.0, 35000000.0, 36000000.0, 37000000.0, 38000000.0, 39000000.0, 40000000.0, 41000000.0, 42000000.0, 43000000.0, 44000000.0, 45000000.0, 46000000.0, 47000000.0, 48000000.0, 49000000.0, 50000000.0, 51000000.0, 52000000.0, 53000000.0, 54000000.0, 55000000.0, 56000000.0, 57000000.0, 58000000.0, 59000000.0, 60000000.0, 61000000.0, 62000000.0, 63000000.0, 64000000.0, 65000000.0, 66000000.0, 67000000.0, 68000000.0, 69000000.0, 70000000.0, 71000000.0, 72000000.0, 73000000.0, 74000000.0, 75000000.0, 76000000.0, 77000000.0, 78000000.0, 79000000.0, 80000000.0, 81000000.0, 82000000.0, 83000000.0, 84000000.0, 85000000.0, 86000000.0, 87000000.0, 88000000.0, 89000000.0, 90000000.0, 91000000.0, 92000000.0, 93000000.0, 94000000.0, 95000000.0, 96000000.0, 97000000.0, 98000000.0, 99000000.0, 100000000.0, 101000000.0, 102000000.0, 103000000.0, 104000000.0, 105000000.0, 106000000.0, 107000000.0, 108000000.0, 109000000.0, 110000000.0, 111000000.0, 112000000.0, 113000000.0, 114000000.0, 115000000.0, 116000000.0, 117000000.0, 118000000.0, 119000000.0, 120000000.0, 121000000.0, 122000000.0, 123000000.0, 124000000.0, 125000000.0, 126000000.0, 127000000.0, 128000000.0, 129000000.0, 130000000.0, 131000000.0, 132000000.0, 133000000.0, 134000000.0, 135000000.0, 136000000.0, 137000000.0, 138000000.0, 139000000.0, 140000000.0, 141000000.0, 142000000.0, 143000000.0, 144000000.0, 145000000.0, 146000000.0, 147000000.0, 148000000.0, 149000000.0, 150000000.0, 151000000.0, 152000000.0, 153000000.0, 154000000.0, 155000000.0, 156000000.0, 157000000.0, 158000000.0, 159000000.0, 160000000.0, 161000000.0, 162000000.0, 163000000.0, 164000000.0, 165000000.0, 166000000.0, 167000000.0, 168000000.0, 169000000.0, 170000000.0, 171000000.0, 172000000.0, 173000000.0, 174000000.0, 175000000.0, 176000000.0, 177000000.0, 178000000.0, 179000000.0, 180000000.0, 181000000.0, 182000000.0, 183000000.0, 184000000.0, 185000000.0, 186000000.0, 187000000.0, 188000000.0, 189000000.0, 190000000.0, 191000000.0, 192000000.0, 193000000.0, 194000000.0, 195000000.0, 196000000.0, 197000000.0, 198000000.0, 199000000.0], "ys": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 40.0, 90.0, 100.0, 100.0, 92.3076923076923, 91.66666666666667, 150.0, 50.0, 0.0, 0.0, 0.0, 330.0, 410.0, 400.0, 400.0, 230.0, 470.0, 
440.0, 490.0, 500.0, 460.0, 460.0, 410.0, 480.0, 490.0, 500.0, 480.0, 480.0, 480.0, 480.0, 480.0, 450.0, 480.0, 490.0, 500.0, 500.0, 500.0, 450.0, 500.0, 500.0, 500.0, 500.0, 500.0, 490.0, 470.0, 490.0, 470.0, 510.0, 420.0, 420.0, 490.0, 1680.0, 1030.0, 690.0, 680.0, 1230.0, 1660.0, 2110.0, 1290.0, 1680.0, 2560.0, 2600.0, 2590.0, 2530.0, 2500.0, 2500.0, 2320.0, 2540.0, 2580.0, 2540.0, 2550.0, 2550.0, 2500.0, 2510.0, 2520.0, 2560.0, 2530.0, 2550.0, 2570.0, 2600.0, 2590.0, 2580.0, 2590.0, 2590.0, 2590.0, 2510.0, 2600.0, 2590.0, 2530.0, 2520.0, 2590.0, 2580.0, 2590.0, 2590.0, 2300.0, 2500.0, 2550.0, 2560.0, 2540.0, 2600.0, 2580.0, 2590.0, 2600.0, 2600.0, 2600.0, 2510.0, 2570.0, 2170.0, 2600.0, 2570.0, 2590.0, 2550.0, 2520.0, 2570.0, 2600.0, 2560.0, 2530.0, 2530.0, 2550.0, 2600.0, 2570.0, 2560.0, 2550.0, 2590.0, 2600.0, 2580.0, 2380.0, 2600.0, 2530.0, 2530.0, 2520.0]}, {"task": "atari_montezuma_revenge", "method": "dreamerv2", "seed": "4", "xs": [1000000.0, 2000000.0, 3000000.0, 4000000.0, 5000000.0, 6000000.0, 7000000.0, 8000000.0, 9000000.0, 10000000.0, 11000000.0, 12000000.0, 13000000.0, 14000000.0, 15000000.0, 16000000.0, 17000000.0, 18000000.0, 19000000.0, 20000000.0, 21000000.0, 22000000.0, 23000000.0, 24000000.0, 25000000.0, 26000000.0, 27000000.0, 28000000.0, 29000000.0, 30000000.0, 31000000.0, 32000000.0, 33000000.0, 34000000.0, 35000000.0, 36000000.0, 37000000.0, 38000000.0, 39000000.0, 40000000.0, 41000000.0, 42000000.0, 43000000.0, 44000000.0, 45000000.0, 46000000.0, 47000000.0, 48000000.0, 49000000.0, 50000000.0, 51000000.0, 52000000.0, 53000000.0, 54000000.0, 55000000.0, 56000000.0, 57000000.0, 58000000.0, 59000000.0, 60000000.0, 61000000.0, 62000000.0, 63000000.0, 64000000.0, 65000000.0, 66000000.0, 67000000.0, 68000000.0, 69000000.0, 70000000.0, 71000000.0, 72000000.0, 73000000.0, 74000000.0, 75000000.0, 76000000.0, 77000000.0, 78000000.0, 79000000.0, 80000000.0, 81000000.0, 82000000.0, 83000000.0, 84000000.0, 85000000.0, 86000000.0, 87000000.0, 88000000.0, 89000000.0, 90000000.0, 91000000.0, 92000000.0, 93000000.0, 94000000.0, 95000000.0, 96000000.0, 97000000.0, 98000000.0, 99000000.0, 100000000.0, 101000000.0, 102000000.0, 103000000.0, 104000000.0, 105000000.0, 106000000.0, 107000000.0, 108000000.0, 109000000.0, 110000000.0, 111000000.0, 112000000.0, 113000000.0, 114000000.0, 115000000.0, 116000000.0, 117000000.0, 118000000.0, 119000000.0, 120000000.0, 121000000.0, 122000000.0, 123000000.0, 124000000.0, 125000000.0, 126000000.0, 127000000.0, 128000000.0, 129000000.0, 130000000.0, 131000000.0, 132000000.0, 133000000.0, 134000000.0, 135000000.0, 136000000.0, 137000000.0, 138000000.0, 139000000.0, 140000000.0, 141000000.0, 142000000.0, 143000000.0, 144000000.0, 145000000.0, 146000000.0, 147000000.0, 148000000.0, 149000000.0, 150000000.0, 151000000.0, 152000000.0, 153000000.0, 154000000.0, 155000000.0, 156000000.0, 157000000.0, 158000000.0, 159000000.0, 160000000.0, 161000000.0, 162000000.0, 163000000.0, 164000000.0, 165000000.0, 166000000.0, 167000000.0, 168000000.0, 169000000.0, 170000000.0, 171000000.0, 172000000.0, 173000000.0, 174000000.0, 175000000.0, 176000000.0, 177000000.0, 178000000.0, 179000000.0, 180000000.0, 181000000.0, 182000000.0, 183000000.0, 184000000.0, 185000000.0, 186000000.0, 187000000.0, 188000000.0, 189000000.0, 190000000.0, 191000000.0, 192000000.0, 193000000.0, 194000000.0, 195000000.0, 196000000.0, 197000000.0, 198000000.0, 199000000.0], "ys": [0.0, 0.0, 0.0, 100.0, 360.0, 440.0, 430.0, 400.0, 470.0, 450.0, 500.0, 500.0, 350.0, 500.0, 500.0, 490.0, 
500.0, 470.0, 470.0, 500.0, 450.0, 450.0, 480.0, 480.0, 500.0, 500.0, 500.0, 500.0, 500.0, 500.0, 500.0, 900.0, 1100.0, 1050.0, 700.0, 1690.0, 2320.0, 2510.0, 2500.0, 2520.0, 2540.0, 2510.0, 2530.0, 2540.0, 2550.0, 2550.0, 2550.0, 2550.0, 2510.0, 2500.0, 2510.0, 2500.0, 2510.0, 2550.0, 2510.0, 2550.0, 2510.0, 1320.0, 2510.0, 2580.0, 2540.0, 2520.0, 2570.0, 2570.0, 2530.0, 2540.0, 2540.0, 2272.7272727272725, 2580.0, 2310.0, 2341.6666666666665, 2500.0, 2510.0, 2100.0, 2540.0, 2500.0, 2510.0, 2500.0, 2520.0, 1900.0, 2530.0, 2540.0, 2500.0, 2270.0, 2310.0, 2500.0, 2500.0, 2500.0, 2500.0, 2510.0, 2510.0, 2500.0, 2510.0, 2500.0, 2540.0, 2550.0, 1710.0, 2550.0, 2540.0, 2520.0, 2510.0, 2520.0, 2550.0, 2570.0, 1720.0, 2300.0, 2560.0, 2560.0, 2550.0, 2550.0, 2570.0, 2560.0, 2530.0, 2510.0, 2570.0, 2590.0, 2570.0, 2310.0, 2530.0, 2550.0, 2320.0, 2510.0, 2550.0, 2530.0, 2250.0, 2540.0, 2530.0, 2580.0, 2550.0, 2540.0, 2570.0, 2540.0, 2580.0, 2300.0, 2260.0, 2540.0, 2530.0, 2530.0, 2580.0, 2590.0, 2540.0, 2520.0, 2520.0, 1510.0, 2340.0, 2530.0, 2370.0, 2560.0, 2360.0, 2560.0, 2020.0, 1560.0, 2590.0, 2600.0, 2580.0, 2560.0, 2590.0, 2590.0, 2600.0, 2600.0, 2510.0, 2500.0, 2580.0, 2560.0, 2580.0, 2080.0, 2580.0, 2600.0, 2180.0, 2600.0, 2600.0, 2600.0, 2540.0, 2600.0, 2600.0, 2380.0, 2160.0, 2600.0, 2390.0, 2570.0, 2590.0, 2590.0, 2580.0, 2570.0, 2590.0, 2550.0, 2570.0, 2590.0, 2590.0, 2600.0, 2550.0, 2550.0, 2600.0, 2600.0, 2600.0, 2570.0, 2030.0, 2160.0, 2310.0]}, {"task": "atari_montezuma_revenge", "method": "dreamerv2", "seed": "5", "xs": [1000000.0, 2000000.0, 3000000.0, 4000000.0, 5000000.0, 6000000.0, 7000000.0, 8000000.0, 9000000.0, 10000000.0, 11000000.0, 12000000.0, 13000000.0, 14000000.0, 15000000.0, 16000000.0, 17000000.0, 18000000.0, 19000000.0, 20000000.0, 21000000.0, 22000000.0, 23000000.0, 24000000.0, 25000000.0, 26000000.0, 27000000.0, 28000000.0, 29000000.0, 30000000.0, 31000000.0, 32000000.0, 33000000.0, 34000000.0, 35000000.0, 36000000.0, 37000000.0, 38000000.0, 39000000.0, 40000000.0, 41000000.0, 42000000.0, 43000000.0, 44000000.0, 45000000.0, 46000000.0, 47000000.0, 48000000.0, 49000000.0, 50000000.0, 51000000.0, 52000000.0, 53000000.0, 54000000.0, 55000000.0, 56000000.0, 57000000.0, 58000000.0, 59000000.0, 60000000.0, 61000000.0, 62000000.0, 63000000.0, 64000000.0, 65000000.0, 66000000.0, 67000000.0, 68000000.0, 69000000.0, 70000000.0, 71000000.0, 72000000.0, 73000000.0, 74000000.0, 75000000.0, 76000000.0, 77000000.0, 78000000.0, 79000000.0, 80000000.0, 81000000.0, 82000000.0, 83000000.0, 84000000.0, 85000000.0, 86000000.0, 87000000.0, 88000000.0, 89000000.0, 90000000.0, 91000000.0, 92000000.0, 93000000.0, 94000000.0, 95000000.0, 96000000.0, 97000000.0, 98000000.0, 99000000.0, 100000000.0, 101000000.0, 102000000.0, 103000000.0, 104000000.0, 105000000.0, 106000000.0, 107000000.0, 108000000.0, 109000000.0, 110000000.0, 111000000.0, 112000000.0, 113000000.0, 114000000.0, 115000000.0, 116000000.0, 117000000.0, 118000000.0, 119000000.0, 120000000.0, 121000000.0, 122000000.0, 123000000.0, 124000000.0, 125000000.0, 126000000.0, 127000000.0, 128000000.0, 129000000.0, 130000000.0, 131000000.0, 132000000.0, 133000000.0, 134000000.0, 135000000.0, 136000000.0, 137000000.0, 138000000.0, 139000000.0, 140000000.0, 141000000.0, 142000000.0, 143000000.0, 144000000.0, 145000000.0, 146000000.0, 147000000.0, 148000000.0, 149000000.0, 150000000.0, 151000000.0, 152000000.0, 153000000.0, 154000000.0, 155000000.0, 156000000.0, 157000000.0, 158000000.0, 159000000.0, 160000000.0, 161000000.0, 162000000.0, 
163000000.0, 164000000.0, 165000000.0, 166000000.0, 167000000.0, 168000000.0, 169000000.0, 170000000.0, 171000000.0, 172000000.0, 173000000.0, 174000000.0, 175000000.0, 176000000.0, 177000000.0, 178000000.0, 179000000.0, 180000000.0, 181000000.0, 182000000.0, 183000000.0, 184000000.0, 185000000.0, 186000000.0, 187000000.0, 188000000.0, 189000000.0, 190000000.0, 191000000.0, 192000000.0, 193000000.0, 194000000.0, 195000000.0, 196000000.0, 197000000.0, 198000000.0, 199000000.0], "ys": [0.0, 0.0, 80.0, 370.0, 430.0, 410.0, 410.0, 420.0, 450.0, 460.0, 500.0, 900.0, 900.0, 1900.0, 900.0, 700.0, 1300.0, 2260.0, 2020.0, 2500.0, 2300.0, 2500.0, 2300.0, 2500.0, 2090.0, 2500.0, 2500.0, 2080.0, 2500.0, 2500.0, 2500.0, 2500.0, 2250.0, 2500.0, 2300.0, 2500.0, 2300.0, 1700.0, 2500.0, 2272.7272727272725, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2300.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2510.0, 2500.0, 2290.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2100.0, 2500.0, 2080.0, 2080.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2250.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2520.0, 2500.0, 2500.0, 2500.0, 2250.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 1900.0, 2500.0, 2500.0, 2500.0, 2500.0, 2510.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2510.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2100.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2090.0, 2500.0, 2500.0, 2500.0, 2500.0, 1900.0, 1900.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0]}] -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | import pathlib 3 | 4 | 5 | setuptools.setup( 6 | name='dreamerv2', 7 | version='2.2.0', 8 | description='Mastering Atari with Discrete World Models', 9 | url='http://github.com/danijar/dreamerv2', 10 | long_description=pathlib.Path('README.md').read_text(), 11 | long_description_content_type='text/markdown', 12 | packages=['dreamerv2', 'dreamerv2.common'], 13 | package_data={'dreamerv2': ['configs.yaml']}, 14 | entry_points={'console_scripts': ['dreamerv2=dreamerv2.train:main']}, 15 | install_requires=[ 16 | 'gym[atari]', 'atari_py', 'crafter', 'dm_control', 'ruamel.yaml', 17 | 'tensorflow', 'tensorflow_probability'], 18 | classifiers=[ 19 | 'Intended Audience :: Science/Research', 20 | 'License :: OSI Approved :: MIT License', 21 | 'Programming Language :: Python :: 3', 22 | 'Topic :: Games/Entertainment', 23 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 24 | ], 25 | ) 26 | --------------------------------------------------------------------------------