├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── dreamerv2
│   ├── agent.py
│   ├── api.py
│   ├── common
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── counter.py
│   │   ├── dists.py
│   │   ├── driver.py
│   │   ├── envs.py
│   │   ├── flags.py
│   │   ├── logger.py
│   │   ├── nets.py
│   │   ├── other.py
│   │   ├── plot.py
│   │   ├── replay.py
│   │   ├── tfutils.py
│   │   └── when.py
│   ├── configs.yaml
│   ├── expl.py
│   └── train.py
├── examples
│   └── minigrid.py
├── scores
│   ├── atari-dopamine.json
│   ├── atari-dreamerv2-schedules.json
│   ├── atari-dreamerv2.json
│   ├── baselines.json
│   ├── dmc-vision-dreamerv2.json
│   ├── humanoid-dreamerv2.json
│   └── montezuma-dreamerv2.json
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.py[cod]
3 | *.egg-info
4 | dist
5 | MUJOCO_LOG.TXT
6 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM tensorflow/tensorflow:2.4.2-gpu
2 |
3 | # System packages.
4 | RUN apt-get update && apt-get install -y \
5 | ffmpeg \
6 | libgl1-mesa-dev \
7 | python3-pip \
8 | unrar \
9 | wget \
10 | && apt-get clean
11 |
12 | # MuJoCo.
13 | ENV MUJOCO_GL egl
14 | RUN mkdir -p /root/.mujoco && \
15 | wget -nv https://www.roboti.us/download/mujoco200_linux.zip -O mujoco.zip && \
16 | unzip mujoco.zip -d /root/.mujoco && \
17 | rm mujoco.zip
18 |
19 | # Python packages.
20 | RUN pip3 install --no-cache-dir \
21 | 'gym[atari]' \
22 | atari_py \
23 | crafter \
24 | dm_control \
25 | ruamel.yaml \
26 | tensorflow_probability==0.12.2
27 |
28 | # Atari ROMS.
29 | RUN wget -L -nv http://www.atarimania.com/roms/Roms.rar && \
30 | unrar x Roms.rar && \
31 | unzip ROMS.zip && \
32 | python3 -m atari_py.import_roms ROMS && \
33 | rm -rf Roms.rar ROMS.zip ROMS
34 |
35 | # MuJoCo key.
36 | ARG MUJOCO_KEY=""
37 | RUN echo "$MUJOCO_KEY" > /root/.mujoco/mjkey.txt
38 | RUN cat /root/.mujoco/mjkey.txt
39 |
40 | # DreamerV2.
41 | ENV TF_XLA_FLAGS --tf_xla_auto_jit=2
42 | COPY . /app
43 | WORKDIR /app
44 | CMD [ \
45 | "python3", "dreamerv2/train.py", \
46 | "--logdir", "/logdir/$(date +%Y%m%d-%H%M%S)", \
47 | "--configs", "defaults", "atari", \
48 | "--task", "atari_pong" \
49 | ]
50 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2020 Danijar Hafner
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in all
11 | copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | SOFTWARE.
20 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | **Status:** Stable release
2 |
3 | [PyPI](https://pypi.python.org/pypi/dreamerv2/#history)
4 |
5 | # Mastering Atari with Discrete World Models
6 |
7 | Implementation of the [DreamerV2][website] agent in TensorFlow 2. Training
8 | curves for all 55 games are included.
9 |
10 |
11 |
12 |
13 |
14 | If you find this code useful, please reference it in your paper:
15 |
16 | ```
17 | @article{hafner2020dreamerv2,
18 | title={Mastering Atari with Discrete World Models},
19 | author={Hafner, Danijar and Lillicrap, Timothy and Norouzi, Mohammad and Ba, Jimmy},
20 | journal={arXiv preprint arXiv:2010.02193},
21 | year={2020}
22 | }
23 | ```
24 |
25 | [website]: https://danijar.com/dreamerv2
26 |
27 | ## Method
28 |
29 | DreamerV2 is the first world model agent that achieves human-level performance
30 | on the Atari benchmark. DreamerV2 also outperforms the final performance of the
31 | top model-free agents Rainbow and IQN using the same amount of experience and
32 | computation. The implementation in this repository alternates between training
33 | the world model, training the policy, and collecting experience, and it runs on a
34 | single GPU.
35 |
36 | 
37 |
38 | DreamerV2 learns a model of the environment directly from high-dimensional
39 | input images. For this, it predicts ahead using compact learned states. The
40 | states consist of a deterministic part and several categorical variables that
41 | are sampled. The prior for these categoricals is learned through a KL loss. The
42 | world model is learned end-to-end via straight-through gradients, meaning that
43 | the gradient of the density is set to the gradient of the sample.
44 |
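As a rough illustration of this straight-through trick, here is a minimal, self-contained sketch (not the repository's exact code; the real implementation is `OneHotDist.sample()` in `common/dists.py` below):

```python
import tensorflow as tf
import tensorflow_probability as tfp

logits = tf.Variable([[1.0, 0.5, -0.5]])
with tf.GradientTape() as tape:
  dist = tfp.distributions.OneHotCategorical(logits=logits)
  sample = tf.cast(dist.sample(), tf.float32)         # Discrete draw, no gradient on its own.
  probs = dist.probs_parameter()                      # Differentiable class probabilities.
  sample = sample + probs - tf.stop_gradient(probs)   # Forward value stays the sample, the
                                                      # gradient flows through the probabilities.
  loss = tf.reduce_sum(sample * tf.constant([[1.0, 2.0, 3.0]]))
print(tape.gradient(loss, logits))  # Non-zero gradient despite the discrete sample.
```
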
45 | 
46 |
47 | DreamerV2 learns actor and critic networks from imagined trajectories of latent
48 | states. The trajectories start at encoded states of previously encountered
49 | sequences. The world model then predicts ahead using the selected actions and
50 | its learned state prior. The critic is trained using temporal difference
51 | learning and the actor is trained to maximize the value function via reinforce
52 | and straight-through gradients.
53 |
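The λ-return targets for the critic follow the standard TD(λ) recursion. Here is a NumPy sketch of that recursion (the repository computes it in `common.lambda_return`, which is not part of this excerpt, so treat this as an illustration of the idea rather than the exact code):

```python
import numpy as np

def lambda_return(reward, value, discount, bootstrap, lambda_=0.95):
  # reward, value, discount: arrays of shape [horizon]; bootstrap: value estimate
  # for the state after the last step. Targets are accumulated backwards in time.
  next_value = np.concatenate([value[1:], [bootstrap]])
  inputs = reward + discount * next_value * (1 - lambda_)
  returns = np.empty_like(reward)
  last = bootstrap
  for t in reversed(range(len(reward))):
    last = inputs[t] + discount[t] * lambda_ * last
    returns[t] = last
  return returns

# Three imagined steps with reward 1, zero value estimates, and discount 0.99.
print(lambda_return(np.ones(3), np.zeros(3), np.full(3, 0.99), bootstrap=0.0))
```
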
54 | For more information:
55 |
56 | - [Google AI Blog post](https://ai.googleblog.com/2021/02/mastering-atari-with-discrete-world.html)
57 | - [Project website](https://danijar.com/dreamerv2/)
58 | - [Research paper](https://arxiv.org/pdf/2010.02193.pdf)
59 |
60 | ## Using the Package
61 |
62 | The easiest way to run DreamerV2 on new environments is to install the package
63 | via `pip3 install dreamerv2`. The code automatically detects whether the
64 | environment uses discrete or continuous actions. Here is a usage example that
65 | trains DreamerV2 on the MiniGrid environment:
66 |
67 | ```python
68 | import gym
69 | import gym_minigrid
70 | import dreamerv2.api as dv2
71 |
72 | config = dv2.defaults.update({
73 | 'logdir': '~/logdir/minigrid',
74 | 'log_every': 1e3,
75 | 'train_every': 10,
76 | 'prefill': 1e5,
77 | 'actor_ent': 3e-3,
78 | 'loss_scales.kl': 1.0,
79 | 'discount': 0.99,
80 | }).parse_flags()
81 |
82 | env = gym.make('MiniGrid-DoorKey-6x6-v0')
83 | env = gym_minigrid.wrappers.RGBImgPartialObsWrapper(env)
84 | dv2.train(env, config)
85 | ```
86 |
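Because the config is finalized with `parse_flags()`, the same script also accepts command-line overrides for any config key. For example, assuming the snippet above is saved as a hypothetical `train_minigrid.py` (the exact flag syntax is defined in `common/flags.py`, which is not shown in this excerpt):

```sh
python3 train_minigrid.py --logdir ~/logdir/minigrid2 --prefill 2e4
```
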
87 | ## Manual Instructions
88 |
89 | To modify the DreamerV2 agent, clone the repository and follow the instructions
90 | below. There is also a Dockerfile available, in case you do not want to install
91 | the dependencies on your system.
92 |
93 | Get dependencies:
94 |
95 | ```sh
96 | pip3 install tensorflow==2.6.0 tensorflow_probability ruamel.yaml 'gym[atari]' dm_control
97 | ```
98 |
99 | Train on Atari:
100 |
101 | ```sh
102 | python3 dreamerv2/train.py --logdir ~/logdir/atari_pong/dreamerv2/1 \
103 | --configs atari --task atari_pong
104 | ```
105 |
106 | Train on DM Control:
107 |
108 | ```sh
109 | python3 dreamerv2/train.py --logdir ~/logdir/dmc_walker_walk/dreamerv2/1 \
110 | --configs dmc_vision --task dmc_walker_walk
111 | ```
112 |
113 | Monitor results:
114 |
115 | ```sh
116 | tensorboard --logdir ~/logdir
117 | ```
118 |
119 | Generate plots:
120 |
121 | ```sh
122 | python3 common/plot.py --indir ~/logdir --outdir ~/plots \
123 | --xaxis step --yaxis eval_return --bins 1e6
124 | ```
125 |
126 | ## Docker Instructions
127 |
128 | The [Dockerfile](https://github.com/danijar/dreamerv2/blob/main/Dockerfile)
129 | lets you run DreamerV2 without installing its dependencies in your system. This
130 | requires you to have Docker with GPU access set up.
131 |
132 | Check your setup:
133 |
134 | ```sh
135 | docker run -it --rm --gpus all tensorflow/tensorflow:2.4.2-gpu nvidia-smi
136 | ```
137 |
138 | Train on Atari:
139 |
140 | ```sh
141 | docker build -t dreamerv2 .
142 | docker run -it --rm --gpus all -v ~/logdir:/logdir dreamerv2 \
143 | python3 dreamerv2/train.py --logdir /logdir/atari_pong/dreamerv2/1 \
144 | --configs atari --task atari_pong
145 | ```
146 |
147 | Train on DM Control:
148 |
149 | ```sh
150 | docker build -t dreamerv2 . --build-arg MUJOCO_KEY="$(cat ~/.mujoco/mjkey.txt)"
151 | docker run -it --rm --gpus all -v ~/logdir:/logdir dreamerv2 \
152 | python3 dreamerv2/train.py --logdir /logdir/dmc_walker_walk/dreamerv2/1 \
153 | --configs dmc_vision --task dmc_walker_walk
154 | ```
155 |
156 | ## Tips
157 |
158 | - **Efficient debugging.** You can use the `debug` config as in `--configs
159 | atari debug`. This reduces the batch size, increases the evaluation
160 | frequency, and disables `tf.function` graph compilation for easy line-by-line
161 | debugging.
162 |
163 | - **Infinite gradient norms.** This is normal and described under loss scaling in
164 | the [mixed precision][mixed] guide. You can disable mixed precision by passing
165 | `--precision 32` to the training script. Mixed precision is faster but can in
166 | principle cause numerical instabilities.
167 |
168 | - **Accessing logged metrics.** The metrics are stored in both TensorBoard and
169 | JSON lines format. You can directly load them using `pandas.read_json()`. The
170 | plotting script also stores the binned and aggregated metrics of multiple runs
171 | into a single JSON file for easy manual plotting.
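
  For example, a minimal sketch of loading the metrics with pandas (the file name
  `metrics.jsonl` and the column names are assumptions about the logger output, so
  adjust them to whatever appears in your log directory):

  ```python
  import pathlib
  import pandas as pd

  # Each line of the file is one JSON object of logged scalars.
  path = pathlib.Path('~/logdir/atari_pong/dreamerv2/1/metrics.jsonl').expanduser()
  df = pd.read_json(path, lines=True)
  print(df.columns)
  print(df[['step', 'return']].dropna().tail())
  ```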
172 |
173 | [mixed]: https://www.tensorflow.org/guide/mixed_precision
174 |
--------------------------------------------------------------------------------
/dreamerv2/agent.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras import mixed_precision as prec
3 |
4 | import common
5 | import expl
6 |
7 |
8 | class Agent(common.Module):
9 |
10 | def __init__(self, config, obs_space, act_space, step):
11 | self.config = config
12 | self.obs_space = obs_space
13 | self.act_space = act_space['action']
14 | self.step = step
15 | self.tfstep = tf.Variable(int(self.step), tf.int64)
16 | self.wm = WorldModel(config, obs_space, self.tfstep)
17 | self._task_behavior = ActorCritic(config, self.act_space, self.tfstep)
18 | if config.expl_behavior == 'greedy':
19 | self._expl_behavior = self._task_behavior
20 | else:
21 | self._expl_behavior = getattr(expl, config.expl_behavior)(
22 | self.config, self.act_space, self.wm, self.tfstep,
23 | lambda seq: self.wm.heads['reward'](seq['feat']).mode())
24 |
25 | @tf.function
26 | def policy(self, obs, state=None, mode='train'):
27 | obs = tf.nest.map_structure(tf.tensor, obs)
28 | tf.py_function(lambda: self.tfstep.assign(
29 | int(self.step), read_value=False), [], [])
30 | if state is None:
31 | latent = self.wm.rssm.initial(len(obs['reward']))
32 | action = tf.zeros((len(obs['reward']),) + self.act_space.shape)
33 | state = latent, action
34 | latent, action = state
35 | embed = self.wm.encoder(self.wm.preprocess(obs))
36 | sample = (mode == 'train') or not self.config.eval_state_mean
37 | latent, _ = self.wm.rssm.obs_step(
38 | latent, action, embed, obs['is_first'], sample)
39 | feat = self.wm.rssm.get_feat(latent)
40 | if mode == 'eval':
41 | actor = self._task_behavior.actor(feat)
42 | action = actor.mode()
43 | noise = self.config.eval_noise
44 | elif mode == 'explore':
45 | actor = self._expl_behavior.actor(feat)
46 | action = actor.sample()
47 | noise = self.config.expl_noise
48 | elif mode == 'train':
49 | actor = self._task_behavior.actor(feat)
50 | action = actor.sample()
51 | noise = self.config.expl_noise
52 | action = common.action_noise(action, noise, self.act_space)
53 | outputs = {'action': action}
54 | state = (latent, action)
55 | return outputs, state
56 |
57 | @tf.function
58 | def train(self, data, state=None):
59 | metrics = {}
60 | state, outputs, mets = self.wm.train(data, state)
61 | metrics.update(mets)
62 | start = outputs['post']
63 | reward = lambda seq: self.wm.heads['reward'](seq['feat']).mode()
64 | metrics.update(self._task_behavior.train(
65 | self.wm, start, data['is_terminal'], reward))
66 | if self.config.expl_behavior != 'greedy':
67 | mets = self._expl_behavior.train(start, outputs, data)[-1]
68 | metrics.update({'expl_' + key: value for key, value in mets.items()})
69 | return state, metrics
70 |
71 | @tf.function
72 | def report(self, data):
73 | report = {}
74 | data = self.wm.preprocess(data)
75 | for key in self.wm.heads['decoder'].cnn_keys:
76 | name = key.replace('/', '_')
77 | report[f'openl_{name}'] = self.wm.video_pred(data, key)
78 | return report
79 |
80 |
81 | class WorldModel(common.Module):
82 |
83 | def __init__(self, config, obs_space, tfstep):
84 | shapes = {k: tuple(v.shape) for k, v in obs_space.items()}
85 | self.config = config
86 | self.tfstep = tfstep
87 | self.rssm = common.EnsembleRSSM(**config.rssm)
88 | self.encoder = common.Encoder(shapes, **config.encoder)
89 | self.heads = {}
90 | self.heads['decoder'] = common.Decoder(shapes, **config.decoder)
91 | self.heads['reward'] = common.MLP([], **config.reward_head)
92 | if config.pred_discount:
93 | self.heads['discount'] = common.MLP([], **config.discount_head)
94 | for name in config.grad_heads:
95 | assert name in self.heads, name
96 | self.model_opt = common.Optimizer('model', **config.model_opt)
97 |
98 | def train(self, data, state=None):
99 | with tf.GradientTape() as model_tape:
100 | model_loss, state, outputs, metrics = self.loss(data, state)
101 | modules = [self.encoder, self.rssm, *self.heads.values()]
102 | metrics.update(self.model_opt(model_tape, model_loss, modules))
103 | return state, outputs, metrics
104 |
105 | def loss(self, data, state=None):
106 | data = self.preprocess(data)
107 | embed = self.encoder(data)
108 | post, prior = self.rssm.observe(
109 | embed, data['action'], data['is_first'], state)
110 | kl_loss, kl_value = self.rssm.kl_loss(post, prior, **self.config.kl)
111 | assert len(kl_loss.shape) == 0
112 | likes = {}
113 | losses = {'kl': kl_loss}
114 | feat = self.rssm.get_feat(post)
115 | for name, head in self.heads.items():
116 | grad_head = (name in self.config.grad_heads)
117 | inp = feat if grad_head else tf.stop_gradient(feat)
118 | out = head(inp)
119 | dists = out if isinstance(out, dict) else {name: out}
120 | for key, dist in dists.items():
121 | like = tf.cast(dist.log_prob(data[key]), tf.float32)
122 | likes[key] = like
123 | losses[key] = -like.mean()
124 | model_loss = sum(
125 | self.config.loss_scales.get(k, 1.0) * v for k, v in losses.items())
126 | outs = dict(
127 | embed=embed, feat=feat, post=post,
128 | prior=prior, likes=likes, kl=kl_value)
129 | metrics = {f'{name}_loss': value for name, value in losses.items()}
130 | metrics['model_kl'] = kl_value.mean()
131 | metrics['prior_ent'] = self.rssm.get_dist(prior).entropy().mean()
132 | metrics['post_ent'] = self.rssm.get_dist(post).entropy().mean()
133 | last_state = {k: v[:, -1] for k, v in post.items()}
134 | return model_loss, last_state, outs, metrics
135 |
136 | def imagine(self, policy, start, is_terminal, horizon):
137 | flatten = lambda x: x.reshape([-1] + list(x.shape[2:]))
138 | start = {k: flatten(v) for k, v in start.items()}
139 | start['feat'] = self.rssm.get_feat(start)
140 | start['action'] = tf.zeros_like(policy(start['feat']).mode())
141 | seq = {k: [v] for k, v in start.items()}
142 | for _ in range(horizon):
143 | action = policy(tf.stop_gradient(seq['feat'][-1])).sample()
144 | state = self.rssm.img_step({k: v[-1] for k, v in seq.items()}, action)
145 | feat = self.rssm.get_feat(state)
146 | for key, value in {**state, 'action': action, 'feat': feat}.items():
147 | seq[key].append(value)
148 | seq = {k: tf.stack(v, 0) for k, v in seq.items()}
149 | if 'discount' in self.heads:
150 | disc = self.heads['discount'](seq['feat']).mean()
151 | if is_terminal is not None:
152 | # Override discount prediction for the first step with the true
153 | # discount factor from the replay buffer.
154 | true_first = 1.0 - flatten(is_terminal).astype(disc.dtype)
155 | true_first *= self.config.discount
156 | disc = tf.concat([true_first[None], disc[1:]], 0)
157 | else:
158 | disc = self.config.discount * tf.ones(seq['feat'].shape[:-1])
159 | seq['discount'] = disc
160 | # Shift discount factors because they imply whether the following state
161 | # will be valid, not whether the current state is valid.
162 | seq['weight'] = tf.math.cumprod(
163 | tf.concat([tf.ones_like(disc[:1]), disc[:-1]], 0), 0)
164 | return seq
165 |
166 | @tf.function
167 | def preprocess(self, obs):
168 | dtype = prec.global_policy().compute_dtype
169 | obs = obs.copy()
170 | for key, value in obs.items():
171 | if key.startswith('log_'):
172 | continue
173 | if value.dtype == tf.int32:
174 | value = value.astype(dtype)
175 | if value.dtype == tf.uint8:
176 | value = value.astype(dtype) / 255.0 - 0.5
177 | obs[key] = value
178 | obs['reward'] = {
179 | 'identity': tf.identity,
180 | 'sign': tf.sign,
181 | 'tanh': tf.tanh,
182 | }[self.config.clip_rewards](obs['reward'])
183 | obs['discount'] = 1.0 - obs['is_terminal'].astype(dtype)
184 | obs['discount'] *= self.config.discount
185 | return obs
186 |
187 | @tf.function
188 | def video_pred(self, data, key):
189 | decoder = self.heads['decoder']
190 | truth = data[key][:6] + 0.5
191 | embed = self.encoder(data)
192 | states, _ = self.rssm.observe(
193 | embed[:6, :5], data['action'][:6, :5], data['is_first'][:6, :5])
194 | recon = decoder(self.rssm.get_feat(states))[key].mode()[:6]
195 | init = {k: v[:, -1] for k, v in states.items()}
196 | prior = self.rssm.imagine(data['action'][:6, 5:], init)
197 | openl = decoder(self.rssm.get_feat(prior))[key].mode()
198 | model = tf.concat([recon[:, :5] + 0.5, openl + 0.5], 1)
199 | error = (model - truth + 1) / 2
200 | video = tf.concat([truth, model, error], 2)
201 | B, T, H, W, C = video.shape
202 | return video.transpose((1, 2, 0, 3, 4)).reshape((T, H, B * W, C))
203 |
204 |
205 | class ActorCritic(common.Module):
206 |
207 | def __init__(self, config, act_space, tfstep):
208 | self.config = config
209 | self.act_space = act_space
210 | self.tfstep = tfstep
211 | discrete = hasattr(act_space, 'n')
212 | if self.config.actor.dist == 'auto':
213 | self.config = self.config.update({
214 | 'actor.dist': 'onehot' if discrete else 'trunc_normal'})
215 | if self.config.actor_grad == 'auto':
216 | self.config = self.config.update({
217 | 'actor_grad': 'reinforce' if discrete else 'dynamics'})
218 | self.actor = common.MLP(act_space.shape[0], **self.config.actor)
219 | self.critic = common.MLP([], **self.config.critic)
220 | if self.config.slow_target:
221 | self._target_critic = common.MLP([], **self.config.critic)
222 | self._updates = tf.Variable(0, tf.int64)
223 | else:
224 | self._target_critic = self.critic
225 | self.actor_opt = common.Optimizer('actor', **self.config.actor_opt)
226 | self.critic_opt = common.Optimizer('critic', **self.config.critic_opt)
227 | self.rewnorm = common.StreamNorm(**self.config.reward_norm)
228 |
229 | def train(self, world_model, start, is_terminal, reward_fn):
230 | metrics = {}
231 | hor = self.config.imag_horizon
232 | # The weights are is_terminal flags for the imagination start states.
233 | # Technically, they should multiply the losses from the second trajectory
234 | # step onwards, which is the first imagined step. However, we are not
235 | # training the action that led into the first step anyway, so we can use
236 | # them to scale the whole sequence.
237 | with tf.GradientTape() as actor_tape:
238 | seq = world_model.imagine(self.actor, start, is_terminal, hor)
239 | reward = reward_fn(seq)
240 | seq['reward'], mets1 = self.rewnorm(reward)
241 | mets1 = {f'reward_{k}': v for k, v in mets1.items()}
242 | target, mets2 = self.target(seq)
243 | actor_loss, mets3 = self.actor_loss(seq, target)
244 | with tf.GradientTape() as critic_tape:
245 | critic_loss, mets4 = self.critic_loss(seq, target)
246 | metrics.update(self.actor_opt(actor_tape, actor_loss, self.actor))
247 | metrics.update(self.critic_opt(critic_tape, critic_loss, self.critic))
248 | metrics.update(**mets1, **mets2, **mets3, **mets4)
249 | self.update_slow_target() # Variables exist after first forward pass.
250 | return metrics
251 |
252 | def actor_loss(self, seq, target):
253 | # Actions: 0 [a1] [a2] a3
254 | # ^ | ^ | ^ |
255 | # / v / v / v
256 | # States: [z0]->[z1]-> z2 -> z3
257 | # Targets: t0 [t1] [t2]
258 | # Baselines: [v0] [v1] v2 v3
259 | # Entropies: [e1] [e2]
260 | # Weights: [ 1] [w1] w2 w3
261 | # Loss: l1 l2
262 | metrics = {}
263 |     # Two states are lost at the end of the trajectory, one for the bootstrap
264 | # value prediction and one because the corresponding action does not lead
265 | # anywhere anymore. One target is lost at the start of the trajectory
266 | # because the initial state comes from the replay buffer.
267 | policy = self.actor(tf.stop_gradient(seq['feat'][:-2]))
268 | if self.config.actor_grad == 'dynamics':
269 | objective = target[1:]
270 | elif self.config.actor_grad == 'reinforce':
271 | baseline = self._target_critic(seq['feat'][:-2]).mode()
272 | advantage = tf.stop_gradient(target[1:] - baseline)
273 | action = tf.stop_gradient(seq['action'][1:-1])
274 | objective = policy.log_prob(action) * advantage
275 | elif self.config.actor_grad == 'both':
276 | baseline = self._target_critic(seq['feat'][:-2]).mode()
277 | advantage = tf.stop_gradient(target[1:] - baseline)
278 | objective = policy.log_prob(seq['action'][1:-1]) * advantage
279 | mix = common.schedule(self.config.actor_grad_mix, self.tfstep)
280 | objective = mix * target[1:] + (1 - mix) * objective
281 | metrics['actor_grad_mix'] = mix
282 | else:
283 | raise NotImplementedError(self.config.actor_grad)
284 | ent = policy.entropy()
285 | ent_scale = common.schedule(self.config.actor_ent, self.tfstep)
286 | objective += ent_scale * ent
287 | weight = tf.stop_gradient(seq['weight'])
288 | actor_loss = -(weight[:-2] * objective).mean()
289 | metrics['actor_ent'] = ent.mean()
290 | metrics['actor_ent_scale'] = ent_scale
291 | return actor_loss, metrics
292 |
293 | def critic_loss(self, seq, target):
294 | # States: [z0] [z1] [z2] z3
295 | # Rewards: [r0] [r1] [r2] r3
296 | # Values: [v0] [v1] [v2] v3
297 | # Weights: [ 1] [w1] [w2] w3
298 | # Targets: [t0] [t1] [t2]
299 | # Loss: l0 l1 l2
300 | dist = self.critic(seq['feat'][:-1])
301 | target = tf.stop_gradient(target)
302 | weight = tf.stop_gradient(seq['weight'])
303 | critic_loss = -(dist.log_prob(target) * weight[:-1]).mean()
304 | metrics = {'critic': dist.mode().mean()}
305 | return critic_loss, metrics
306 |
307 | def target(self, seq):
308 | # States: [z0] [z1] [z2] [z3]
309 | # Rewards: [r0] [r1] [r2] r3
310 | # Values: [v0] [v1] [v2] [v3]
311 | # Discount: [d0] [d1] [d2] d3
312 | # Targets: t0 t1 t2
313 | reward = tf.cast(seq['reward'], tf.float32)
314 | disc = tf.cast(seq['discount'], tf.float32)
315 | value = self._target_critic(seq['feat']).mode()
316 | # Skipping last time step because it is used for bootstrapping.
317 | target = common.lambda_return(
318 | reward[:-1], value[:-1], disc[:-1],
319 | bootstrap=value[-1],
320 | lambda_=self.config.discount_lambda,
321 | axis=0)
322 | metrics = {}
323 | metrics['critic_slow'] = value.mean()
324 | metrics['critic_target'] = target.mean()
325 | return target, metrics
326 |
327 | def update_slow_target(self):
328 | if self.config.slow_target:
329 | if self._updates % self.config.slow_target_update == 0:
330 | mix = 1.0 if self._updates == 0 else float(
331 | self.config.slow_target_fraction)
332 | for s, d in zip(self.critic.variables, self._target_critic.variables):
333 | d.assign(mix * s + (1 - mix) * d)
334 | self._updates.assign_add(1)
335 |
--------------------------------------------------------------------------------
/dreamerv2/api.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import logging
3 | import os
4 | import pathlib
5 | import re
6 | import sys
7 | import warnings
8 |
9 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
10 | logging.getLogger().setLevel('ERROR')
11 | warnings.filterwarnings('ignore', '.*box bound precision lowered.*')
12 |
13 | sys.path.append(str(pathlib.Path(__file__).parent))
14 | sys.path.append(str(pathlib.Path(__file__).parent.parent))
15 |
16 | import numpy as np
17 | import ruamel.yaml as yaml
18 |
19 | import agent
20 | import common
21 |
22 | from common import Config
23 | from common import GymWrapper
24 | from common import RenderImage
25 | from common import TerminalOutput
26 | from common import JSONLOutput
27 | from common import TensorBoardOutput
28 |
29 | configs = yaml.safe_load(
30 | (pathlib.Path(__file__).parent / 'configs.yaml').read_text())
31 | defaults = common.Config(configs.pop('defaults'))
32 |
33 |
34 | def train(env, config, outputs=None):
35 |
36 | logdir = pathlib.Path(config.logdir).expanduser()
37 | logdir.mkdir(parents=True, exist_ok=True)
38 | config.save(logdir / 'config.yaml')
39 | print(config, '\n')
40 | print('Logdir', logdir)
41 |
42 | outputs = outputs or [
43 | common.TerminalOutput(),
44 | common.JSONLOutput(config.logdir),
45 | common.TensorBoardOutput(config.logdir),
46 | ]
47 | replay = common.Replay(logdir / 'train_episodes', **config.replay)
48 | step = common.Counter(replay.stats['total_steps'])
49 | logger = common.Logger(step, outputs, multiplier=config.action_repeat)
50 | metrics = collections.defaultdict(list)
51 |
52 | should_train = common.Every(config.train_every)
53 | should_log = common.Every(config.log_every)
54 | should_video = common.Every(config.log_every)
55 | should_expl = common.Until(config.expl_until)
56 |
57 | def per_episode(ep):
58 | length = len(ep['reward']) - 1
59 | score = float(ep['reward'].astype(np.float64).sum())
60 | print(f'Episode has {length} steps and return {score:.1f}.')
61 | logger.scalar('return', score)
62 | logger.scalar('length', length)
63 | for key, value in ep.items():
64 | if re.match(config.log_keys_sum, key):
65 | logger.scalar(f'sum_{key}', ep[key].sum())
66 | if re.match(config.log_keys_mean, key):
67 | logger.scalar(f'mean_{key}', ep[key].mean())
68 | if re.match(config.log_keys_max, key):
69 | logger.scalar(f'max_{key}', ep[key].max(0).mean())
70 | if should_video(step):
71 | for key in config.log_keys_video:
72 | logger.video(f'policy_{key}', ep[key])
73 | logger.add(replay.stats)
74 | logger.write()
75 |
76 | env = common.GymWrapper(env)
77 | env = common.ResizeImage(env)
78 | if hasattr(env.act_space['action'], 'n'):
79 | env = common.OneHotAction(env)
80 | else:
81 | env = common.NormalizeAction(env)
82 | env = common.TimeLimit(env, config.time_limit)
83 |
84 | driver = common.Driver([env])
85 | driver.on_episode(per_episode)
86 | driver.on_step(lambda tran, worker: step.increment())
87 | driver.on_step(replay.add_step)
88 | driver.on_reset(replay.add_step)
89 |
90 | prefill = max(0, config.prefill - replay.stats['total_steps'])
91 | if prefill:
92 | print(f'Prefill dataset ({prefill} steps).')
93 | random_agent = common.RandomAgent(env.act_space)
94 | driver(random_agent, steps=prefill, episodes=1)
95 | driver.reset()
96 |
97 | print('Create agent.')
98 | agnt = agent.Agent(config, env.obs_space, env.act_space, step)
99 | dataset = iter(replay.dataset(**config.dataset))
100 | train_agent = common.CarryOverState(agnt.train)
101 | train_agent(next(dataset))
102 | if (logdir / 'variables.pkl').exists():
103 | agnt.load(logdir / 'variables.pkl')
104 | else:
105 | print('Pretrain agent.')
106 | for _ in range(config.pretrain):
107 | train_agent(next(dataset))
108 | policy = lambda *args: agnt.policy(
109 | *args, mode='explore' if should_expl(step) else 'train')
110 |
111 | def train_step(tran, worker):
112 | if should_train(step):
113 | for _ in range(config.train_steps):
114 | mets = train_agent(next(dataset))
115 | [metrics[key].append(value) for key, value in mets.items()]
116 | if should_log(step):
117 | for name, values in metrics.items():
118 | logger.scalar(name, np.array(values, np.float64).mean())
119 | metrics[name].clear()
120 | logger.add(agnt.report(next(dataset)))
121 | logger.write(fps=True)
122 | driver.on_step(train_step)
123 |
124 | while step < config.steps:
125 | logger.write()
126 | driver(policy, steps=config.eval_every)
127 | agnt.save(logdir / 'variables.pkl')
128 |
--------------------------------------------------------------------------------
/dreamerv2/common/__init__.py:
--------------------------------------------------------------------------------
1 | # General tools.
2 | from .config import *
3 | from .counter import *
4 | from .flags import *
5 | from .logger import *
6 | from .when import *
7 |
8 | # RL tools.
9 | from .other import *
10 | from .driver import *
11 | from .envs import *
12 | from .replay import *
13 |
14 | # TensorFlow tools.
15 | from .tfutils import *
16 | from .dists import *
17 | from .nets import *
18 |
--------------------------------------------------------------------------------
/dreamerv2/common/config.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pathlib
3 | import re
4 |
5 |
6 | class Config(dict):
7 |
8 | SEP = '.'
9 | IS_PATTERN = re.compile(r'.*[^A-Za-z0-9_.-].*')
10 |
11 | def __init__(self, *args, **kwargs):
12 | mapping = dict(*args, **kwargs)
13 | mapping = self._flatten(mapping)
14 | mapping = self._ensure_keys(mapping)
15 | mapping = self._ensure_values(mapping)
16 | self._flat = mapping
17 | self._nested = self._nest(mapping)
18 | # Need to assign the values to the base class dictionary so that
19 | # conversion to dict does not lose the content.
20 | super().__init__(self._nested)
21 |
22 | @property
23 | def flat(self):
24 | return self._flat.copy()
25 |
26 | def save(self, filename):
27 | filename = pathlib.Path(filename)
28 | if filename.suffix == '.json':
29 | filename.write_text(json.dumps(dict(self)))
30 | elif filename.suffix in ('.yml', '.yaml'):
31 | import ruamel.yaml as yaml
32 | with filename.open('w') as f:
33 | yaml.safe_dump(dict(self), f)
34 | else:
35 | raise NotImplementedError(filename.suffix)
36 |
37 | @classmethod
38 | def load(cls, filename):
39 | filename = pathlib.Path(filename)
40 | if filename.suffix == '.json':
41 | return cls(json.loads(filename.read_text()))
42 | elif filename.suffix in ('.yml', '.yaml'):
43 | import ruamel.yaml as yaml
44 | return cls(yaml.safe_load(filename.read_text()))
45 | else:
46 | raise NotImplementedError(filename.suffix)
47 |
48 | def parse_flags(self, argv=None, known_only=False, help_exists=None):
49 | from . import flags
50 | return flags.Flags(self).parse(argv, known_only, help_exists)
51 |
52 | def __contains__(self, name):
53 | try:
54 | self[name]
55 | return True
56 | except KeyError:
57 | return False
58 |
59 | def __getattr__(self, name):
60 | if name.startswith('_'):
61 | return super().__getattr__(name)
62 | try:
63 | return self[name]
64 | except KeyError:
65 | raise AttributeError(name)
66 |
67 | def __getitem__(self, name):
68 | result = self._nested
69 | for part in name.split(self.SEP):
70 | result = result[part]
71 | if isinstance(result, dict):
72 | result = type(self)(result)
73 | return result
74 |
75 | def __setattr__(self, key, value):
76 | if key.startswith('_'):
77 | return super().__setattr__(key, value)
78 | message = f"Tried to set key '{key}' on immutable config. Use update()."
79 | raise AttributeError(message)
80 |
81 | def __setitem__(self, key, value):
82 | if key.startswith('_'):
83 | return super().__setitem__(key, value)
84 | message = f"Tried to set key '{key}' on immutable config. Use update()."
85 | raise AttributeError(message)
86 |
87 | def __reduce__(self):
88 | return (type(self), (dict(self),))
89 |
90 | def __str__(self):
91 | lines = ['\nConfig:']
92 | keys, vals, typs = [], [], []
93 | for key, val in self.flat.items():
94 | keys.append(key + ':')
95 | vals.append(self._format_value(val))
96 | typs.append(self._format_type(val))
97 | max_key = max(len(k) for k in keys) if keys else 0
98 | max_val = max(len(v) for v in vals) if vals else 0
99 | for key, val, typ in zip(keys, vals, typs):
100 | key = key.ljust(max_key)
101 | val = val.ljust(max_val)
102 | lines.append(f'{key} {val} ({typ})')
103 | return '\n'.join(lines)
104 |
105 | def update(self, *args, **kwargs):
106 | result = self._flat.copy()
107 | inputs = self._flatten(dict(*args, **kwargs))
108 | for key, new in inputs.items():
109 | if self.IS_PATTERN.match(key):
110 | pattern = re.compile(key)
111 | keys = {k for k in result if pattern.match(k)}
112 | else:
113 | keys = [key]
114 | if not keys:
115 | raise KeyError(f'Unknown key or pattern {key}.')
116 | for key in keys:
117 | old = result[key]
118 | try:
119 | if isinstance(old, int) and isinstance(new, float):
120 | if float(int(new)) != new:
121 | message = f"Cannot convert fractional float {new} to int."
122 | raise ValueError(message)
123 | result[key] = type(old)(new)
124 | except (ValueError, TypeError):
125 | raise TypeError(
126 | f"Cannot convert '{new}' to type '{type(old).__name__}' " +
127 | f"of value '{old}' for key '{key}'.")
128 | return type(self)(result)
129 |
130 | def _flatten(self, mapping):
131 | result = {}
132 | for key, value in mapping.items():
133 | if isinstance(value, dict):
134 | for k, v in self._flatten(value).items():
135 | if self.IS_PATTERN.match(key) or self.IS_PATTERN.match(k):
136 | combined = f'{key}\\{self.SEP}{k}'
137 | else:
138 | combined = f'{key}{self.SEP}{k}'
139 | result[combined] = v
140 | else:
141 | result[key] = value
142 | return result
143 |
144 | def _nest(self, mapping):
145 | result = {}
146 | for key, value in mapping.items():
147 | parts = key.split(self.SEP)
148 | node = result
149 | for part in parts[:-1]:
150 | if part not in node:
151 | node[part] = {}
152 | node = node[part]
153 | node[parts[-1]] = value
154 | return result
155 |
156 | def _ensure_keys(self, mapping):
157 | for key in mapping:
158 | assert not self.IS_PATTERN.match(key), key
159 | return mapping
160 |
161 | def _ensure_values(self, mapping):
162 | result = json.loads(json.dumps(mapping))
163 | for key, value in result.items():
164 | if isinstance(value, list):
165 | value = tuple(value)
166 | if isinstance(value, tuple):
167 | if len(value) == 0:
168 | message = 'Empty lists are disallowed because their type is unclear.'
169 | raise TypeError(message)
170 | if not isinstance(value[0], (str, float, int, bool)):
171 | message = 'Lists can only contain strings, floats, ints, bools'
172 | message += f' but not {type(value[0])}'
173 | raise TypeError(message)
174 | if not all(isinstance(x, type(value[0])) for x in value[1:]):
175 | message = 'Elements of a list must all be of the same type.'
176 | raise TypeError(message)
177 | result[key] = value
178 | return result
179 |
180 | def _format_value(self, value):
181 | if isinstance(value, (list, tuple)):
182 | return '[' + ', '.join(self._format_value(x) for x in value) + ']'
183 | return str(value)
184 |
185 | def _format_type(self, value):
186 | if isinstance(value, (list, tuple)):
187 | assert len(value) > 0, value
188 | return self._format_type(value[0]) + 's'
189 | return str(type(value).__name__)
190 |
--------------------------------------------------------------------------------
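
A brief usage sketch of the `Config` class above, as exposed through `dreamerv2.api` (assuming the pip package is installed; dotted names address nested keys, as in the README's `'loss_scales.kl'` example):

```python
import dreamerv2.api as dv2

config = dv2.defaults                        # Config instance loaded from configs.yaml.
print(config.logdir)                         # Keys are reachable via attribute access.
config = config.update({'log_every': 1e4})   # Configs are immutable; update() returns a copy.
config.save('/tmp/config.json')              # Serializes the nested dict (JSON or YAML).
```
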
/dreamerv2/common/counter.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 |
4 | @functools.total_ordering
5 | class Counter:
6 |
7 | def __init__(self, initial=0):
8 | self.value = initial
9 |
10 | def __int__(self):
11 | return int(self.value)
12 |
13 | def __eq__(self, other):
14 | return int(self) == other
15 |
16 | def __ne__(self, other):
17 | return int(self) != other
18 |
19 | def __lt__(self, other):
20 | return int(self) < other
21 |
22 | def __add__(self, other):
23 | return int(self) + other
24 |
25 | def increment(self, amount=1):
26 | self.value += amount
27 |
--------------------------------------------------------------------------------
/dreamerv2/common/dists.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorflow_probability as tfp
3 | from tensorflow_probability import distributions as tfd
4 |
5 |
6 | # Patch to ignore seed to avoid synchronization across GPUs.
7 | _orig_random_categorical = tf.random.categorical
8 | def random_categorical(*args, **kwargs):
9 | kwargs['seed'] = None
10 | return _orig_random_categorical(*args, **kwargs)
11 | tf.random.categorical = random_categorical
12 |
13 | # Patch to ignore seed to avoid synchronization across GPUs.
14 | _orig_random_normal = tf.random.normal
15 | def random_normal(*args, **kwargs):
16 | kwargs['seed'] = None
17 | return _orig_random_normal(*args, **kwargs)
18 | tf.random.normal = random_normal
19 |
20 |
21 | class SampleDist:
22 |
23 | def __init__(self, dist, samples=100):
24 | self._dist = dist
25 | self._samples = samples
26 |
27 | @property
28 | def name(self):
29 | return 'SampleDist'
30 |
31 | def __getattr__(self, name):
32 | return getattr(self._dist, name)
33 |
34 | def mean(self):
35 | samples = self._dist.sample(self._samples)
36 | return samples.mean(0)
37 |
38 | def mode(self):
39 | sample = self._dist.sample(self._samples)
40 | logprob = self._dist.log_prob(sample)
41 | return tf.gather(sample, tf.argmax(logprob))[0]
42 |
43 | def entropy(self):
44 | sample = self._dist.sample(self._samples)
45 | logprob = self.log_prob(sample)
46 | return -logprob.mean(0)
47 |
48 |
49 | class OneHotDist(tfd.OneHotCategorical):
50 |
51 | def __init__(self, logits=None, probs=None, dtype=None):
52 | self._sample_dtype = dtype or tf.float32
53 | super().__init__(logits=logits, probs=probs)
54 |
55 | def mode(self):
56 | return tf.cast(super().mode(), self._sample_dtype)
57 |
58 | def sample(self, sample_shape=(), seed=None):
59 | # Straight through biased gradient estimator.
60 | sample = tf.cast(super().sample(sample_shape, seed), self._sample_dtype)
61 | probs = self._pad(super().probs_parameter(), sample.shape)
62 | sample += tf.cast(probs - tf.stop_gradient(probs), self._sample_dtype)
63 | return sample
64 |
65 | def _pad(self, tensor, shape):
66 | tensor = super().probs_parameter()
67 | while len(tensor.shape) < len(shape):
68 | tensor = tensor[None]
69 | return tensor
70 |
71 |
72 | class TruncNormalDist(tfd.TruncatedNormal):
73 |
74 | def __init__(self, loc, scale, low, high, clip=1e-6, mult=1):
75 | super().__init__(loc, scale, low, high)
76 | self._clip = clip
77 | self._mult = mult
78 |
79 | def sample(self, *args, **kwargs):
80 | event = super().sample(*args, **kwargs)
81 | if self._clip:
82 | clipped = tf.clip_by_value(
83 | event, self.low + self._clip, self.high - self._clip)
84 | event = event - tf.stop_gradient(event) + tf.stop_gradient(clipped)
85 | if self._mult:
86 | event *= self._mult
87 | return event
88 |
89 |
90 | class TanhBijector(tfp.bijectors.Bijector):
91 |
92 | def __init__(self, validate_args=False, name='tanh'):
93 | super().__init__(
94 | forward_min_event_ndims=0,
95 | validate_args=validate_args,
96 | name=name)
97 |
98 | def _forward(self, x):
99 | return tf.nn.tanh(x)
100 |
101 | def _inverse(self, y):
102 | dtype = y.dtype
103 | y = tf.cast(y, tf.float32)
104 | y = tf.where(
105 | tf.less_equal(tf.abs(y), 1.),
106 | tf.clip_by_value(y, -0.99999997, 0.99999997), y)
107 | y = tf.atanh(y)
108 | y = tf.cast(y, dtype)
109 | return y
110 |
111 | def _forward_log_det_jacobian(self, x):
112 | log2 = tf.math.log(tf.constant(2.0, dtype=x.dtype))
113 | return 2.0 * (log2 - x - tf.nn.softplus(-2.0 * x))
114 |
--------------------------------------------------------------------------------
/dreamerv2/common/driver.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class Driver:
5 |
6 | def __init__(self, envs, **kwargs):
7 | self._envs = envs
8 | self._kwargs = kwargs
9 | self._on_steps = []
10 | self._on_resets = []
11 | self._on_episodes = []
12 | self._act_spaces = [env.act_space for env in envs]
13 | self.reset()
14 |
15 | def on_step(self, callback):
16 | self._on_steps.append(callback)
17 |
18 | def on_reset(self, callback):
19 | self._on_resets.append(callback)
20 |
21 | def on_episode(self, callback):
22 | self._on_episodes.append(callback)
23 |
24 | def reset(self):
25 | self._obs = [None] * len(self._envs)
26 | self._eps = [None] * len(self._envs)
27 | self._state = None
28 |
29 | def __call__(self, policy, steps=0, episodes=0):
30 | step, episode = 0, 0
31 | while step < steps or episode < episodes:
32 | obs = {
33 | i: self._envs[i].reset()
34 | for i, ob in enumerate(self._obs) if ob is None or ob['is_last']}
35 | for i, ob in obs.items():
36 | self._obs[i] = ob() if callable(ob) else ob
37 | act = {k: np.zeros(v.shape) for k, v in self._act_spaces[i].items()}
38 | tran = {k: self._convert(v) for k, v in {**ob, **act}.items()}
39 | [fn(tran, worker=i, **self._kwargs) for fn in self._on_resets]
40 | self._eps[i] = [tran]
41 | obs = {k: np.stack([o[k] for o in self._obs]) for k in self._obs[0]}
42 | actions, self._state = policy(obs, self._state, **self._kwargs)
43 | actions = [
44 | {k: np.array(actions[k][i]) for k in actions}
45 | for i in range(len(self._envs))]
46 | assert len(actions) == len(self._envs)
47 | obs = [e.step(a) for e, a in zip(self._envs, actions)]
48 | obs = [ob() if callable(ob) else ob for ob in obs]
49 | for i, (act, ob) in enumerate(zip(actions, obs)):
50 | tran = {k: self._convert(v) for k, v in {**ob, **act}.items()}
51 | [fn(tran, worker=i, **self._kwargs) for fn in self._on_steps]
52 | self._eps[i].append(tran)
53 | step += 1
54 | if ob['is_last']:
55 | ep = self._eps[i]
56 | ep = {k: self._convert([t[k] for t in ep]) for k in ep[0]}
57 | [fn(ep, **self._kwargs) for fn in self._on_episodes]
58 | episode += 1
59 | self._obs = obs
60 |
61 | def _convert(self, value):
62 | value = np.array(value)
63 | if np.issubdtype(value.dtype, np.floating):
64 | return value.astype(np.float32)
65 | elif np.issubdtype(value.dtype, np.signedinteger):
66 | return value.astype(np.int32)
67 | elif np.issubdtype(value.dtype, np.uint8):
68 | return value.astype(np.uint8)
69 | return value
70 |
--------------------------------------------------------------------------------
/dreamerv2/common/envs.py:
--------------------------------------------------------------------------------
1 | import atexit
2 | import os
3 | import sys
4 | import threading
5 | import traceback
6 |
7 | import cloudpickle
8 | import gym
9 | import numpy as np
10 |
11 |
12 | class GymWrapper:
13 |
14 | def __init__(self, env, obs_key='image', act_key='action'):
15 | self._env = env
16 | self._obs_is_dict = hasattr(self._env.observation_space, 'spaces')
17 | self._act_is_dict = hasattr(self._env.action_space, 'spaces')
18 | self._obs_key = obs_key
19 | self._act_key = act_key
20 |
21 | def __getattr__(self, name):
22 | if name.startswith('__'):
23 | raise AttributeError(name)
24 | try:
25 | return getattr(self._env, name)
26 | except AttributeError:
27 | raise ValueError(name)
28 |
29 | @property
30 | def obs_space(self):
31 | if self._obs_is_dict:
32 | spaces = self._env.observation_space.spaces.copy()
33 | else:
34 | spaces = {self._obs_key: self._env.observation_space}
35 | return {
36 | **spaces,
37 | 'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32),
38 | 'is_first': gym.spaces.Box(0, 1, (), dtype=np.bool),
39 | 'is_last': gym.spaces.Box(0, 1, (), dtype=np.bool),
40 | 'is_terminal': gym.spaces.Box(0, 1, (), dtype=np.bool),
41 | }
42 |
43 | @property
44 | def act_space(self):
45 | if self._act_is_dict:
46 | return self._env.action_space.spaces.copy()
47 | else:
48 | return {self._act_key: self._env.action_space}
49 |
50 | def step(self, action):
51 | if not self._act_is_dict:
52 | action = action[self._act_key]
53 | obs, reward, done, info = self._env.step(action)
54 | if not self._obs_is_dict:
55 | obs = {self._obs_key: obs}
56 | obs['reward'] = float(reward)
57 | obs['is_first'] = False
58 | obs['is_last'] = done
59 | obs['is_terminal'] = info.get('is_terminal', done)
60 | return obs
61 |
62 | def reset(self):
63 | obs = self._env.reset()
64 | if not self._obs_is_dict:
65 | obs = {self._obs_key: obs}
66 | obs['reward'] = 0.0
67 | obs['is_first'] = True
68 | obs['is_last'] = False
69 | obs['is_terminal'] = False
70 | return obs
71 |
72 |
73 | class DMC:
74 |
75 | def __init__(self, name, action_repeat=1, size=(64, 64), camera=None):
76 | os.environ['MUJOCO_GL'] = 'egl'
77 | domain, task = name.split('_', 1)
78 | if domain == 'cup': # Only domain with multiple words.
79 | domain = 'ball_in_cup'
80 | if domain == 'manip':
81 | from dm_control import manipulation
82 | self._env = manipulation.load(task + '_vision')
83 | elif domain == 'locom':
84 | from dm_control.locomotion.examples import basic_rodent_2020
85 | self._env = getattr(basic_rodent_2020, task)()
86 | else:
87 | from dm_control import suite
88 | self._env = suite.load(domain, task)
89 | self._action_repeat = action_repeat
90 | self._size = size
91 | if camera in (-1, None):
92 | camera = dict(
93 | quadruped_walk=2, quadruped_run=2, quadruped_escape=2,
94 | quadruped_fetch=2, locom_rodent_maze_forage=1,
95 | locom_rodent_two_touch=1,
96 | ).get(name, 0)
97 | self._camera = camera
98 | self._ignored_keys = []
99 | for key, value in self._env.observation_spec().items():
100 | if value.shape == (0,):
101 | print(f"Ignoring empty observation key '{key}'.")
102 | self._ignored_keys.append(key)
103 |
104 | @property
105 | def obs_space(self):
106 | spaces = {
107 | 'image': gym.spaces.Box(0, 255, self._size + (3,), dtype=np.uint8),
108 | 'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32),
109 | 'is_first': gym.spaces.Box(0, 1, (), dtype=np.bool),
110 | 'is_last': gym.spaces.Box(0, 1, (), dtype=np.bool),
111 | 'is_terminal': gym.spaces.Box(0, 1, (), dtype=np.bool),
112 | }
113 | for key, value in self._env.observation_spec().items():
114 | if key in self._ignored_keys:
115 | continue
116 | if value.dtype == np.float64:
117 | spaces[key] = gym.spaces.Box(-np.inf, np.inf, value.shape, np.float32)
118 | elif value.dtype == np.uint8:
119 | spaces[key] = gym.spaces.Box(0, 255, value.shape, np.uint8)
120 | else:
121 | raise NotImplementedError(value.dtype)
122 | return spaces
123 |
124 | @property
125 | def act_space(self):
126 | spec = self._env.action_spec()
127 | action = gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32)
128 | return {'action': action}
129 |
130 | def step(self, action):
131 | assert np.isfinite(action['action']).all(), action['action']
132 | reward = 0.0
133 | for _ in range(self._action_repeat):
134 | time_step = self._env.step(action['action'])
135 | reward += time_step.reward or 0.0
136 | if time_step.last():
137 | break
138 | assert time_step.discount in (0, 1)
139 | obs = {
140 | 'reward': reward,
141 | 'is_first': False,
142 | 'is_last': time_step.last(),
143 | 'is_terminal': time_step.discount == 0,
144 | 'image': self._env.physics.render(*self._size, camera_id=self._camera),
145 | }
146 | obs.update({
147 | k: v for k, v in dict(time_step.observation).items()
148 | if k not in self._ignored_keys})
149 | return obs
150 |
151 | def reset(self):
152 | time_step = self._env.reset()
153 | obs = {
154 | 'reward': 0.0,
155 | 'is_first': True,
156 | 'is_last': False,
157 | 'is_terminal': False,
158 | 'image': self._env.physics.render(*self._size, camera_id=self._camera),
159 | }
160 | obs.update({
161 | k: v for k, v in dict(time_step.observation).items()
162 | if k not in self._ignored_keys})
163 | return obs
164 |
165 |
166 | class Atari:
167 |
168 | LOCK = threading.Lock()
169 |
170 | def __init__(
171 | self, name, action_repeat=4, size=(84, 84), grayscale=True, noops=30,
172 | life_done=False, sticky=True, all_actions=False):
173 | assert size[0] == size[1]
174 | import gym.wrappers
175 | import gym.envs.atari
176 | if name == 'james_bond':
177 | name = 'jamesbond'
178 | with self.LOCK:
179 | env = gym.envs.atari.AtariEnv(
180 | game=name, obs_type='image', frameskip=1,
181 | repeat_action_probability=0.25 if sticky else 0.0,
182 | full_action_space=all_actions)
183 | # Avoid unnecessary rendering in inner env.
184 | env._get_obs = lambda: None
185 | # Tell wrapper that the inner env has no action repeat.
186 | env.spec = gym.envs.registration.EnvSpec('NoFrameskip-v0')
187 | self._env = gym.wrappers.AtariPreprocessing(
188 | env, noops, action_repeat, size[0], life_done, grayscale)
189 | self._size = size
190 | self._grayscale = grayscale
191 |
192 | @property
193 | def obs_space(self):
194 | shape = self._size + (1 if self._grayscale else 3,)
195 | return {
196 | 'image': gym.spaces.Box(0, 255, shape, np.uint8),
197 | 'ram': gym.spaces.Box(0, 255, (128,), np.uint8),
198 | 'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32),
199 | 'is_first': gym.spaces.Box(0, 1, (), dtype=np.bool),
200 | 'is_last': gym.spaces.Box(0, 1, (), dtype=np.bool),
201 | 'is_terminal': gym.spaces.Box(0, 1, (), dtype=np.bool),
202 | }
203 |
204 | @property
205 | def act_space(self):
206 | return {'action': self._env.action_space}
207 |
208 | def step(self, action):
209 | image, reward, done, info = self._env.step(action['action'])
210 | if self._grayscale:
211 | image = image[..., None]
212 | return {
213 | 'image': image,
214 | 'ram': self._env.env._get_ram(),
215 | 'reward': reward,
216 | 'is_first': False,
217 | 'is_last': done,
218 | 'is_terminal': done,
219 | }
220 |
221 | def reset(self):
222 | with self.LOCK:
223 | image = self._env.reset()
224 | if self._grayscale:
225 | image = image[..., None]
226 | return {
227 | 'image': image,
228 | 'ram': self._env.env._get_ram(),
229 | 'reward': 0.0,
230 | 'is_first': True,
231 | 'is_last': False,
232 | 'is_terminal': False,
233 | }
234 |
235 | def close(self):
236 | return self._env.close()
237 |
238 |
239 | class Crafter:
240 |
241 | def __init__(self, outdir=None, reward=True, seed=None):
242 | import crafter
243 | self._env = crafter.Env(reward=reward, seed=seed)
244 | self._env = crafter.Recorder(
245 | self._env, outdir,
246 | save_stats=True,
247 | save_video=False,
248 | save_episode=False,
249 | )
250 | self._achievements = crafter.constants.achievements.copy()
251 |
252 | @property
253 | def obs_space(self):
254 | spaces = {
255 | 'image': self._env.observation_space,
256 | 'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32),
257 | 'is_first': gym.spaces.Box(0, 1, (), dtype=np.bool),
258 | 'is_last': gym.spaces.Box(0, 1, (), dtype=np.bool),
259 | 'is_terminal': gym.spaces.Box(0, 1, (), dtype=np.bool),
260 | 'log_reward': gym.spaces.Box(-np.inf, np.inf, (), np.float32),
261 | }
262 | spaces.update({
263 | f'log_achievement_{k}': gym.spaces.Box(0, 2 ** 31 - 1, (), np.int32)
264 | for k in self._achievements})
265 | return spaces
266 |
267 | @property
268 | def act_space(self):
269 | return {'action': self._env.action_space}
270 |
271 | def step(self, action):
272 | image, reward, done, info = self._env.step(action['action'])
273 | obs = {
274 | 'image': image,
275 | 'reward': reward,
276 | 'is_first': False,
277 | 'is_last': done,
278 | 'is_terminal': info['discount'] == 0,
279 | 'log_reward': info['reward'],
280 | }
281 | obs.update({
282 | f'log_achievement_{k}': v
283 | for k, v in info['achievements'].items()})
284 | return obs
285 |
286 | def reset(self):
287 | obs = {
288 | 'image': self._env.reset(),
289 | 'reward': 0.0,
290 | 'is_first': True,
291 | 'is_last': False,
292 | 'is_terminal': False,
293 | 'log_reward': 0.0,
294 | }
295 | obs.update({
296 | f'log_achievement_{k}': 0
297 | for k in self._achievements})
298 | return obs
299 |
300 |
301 | class Dummy:
302 |
303 | def __init__(self):
304 | pass
305 |
306 | @property
307 | def obs_space(self):
308 | return {
309 | 'image': gym.spaces.Box(0, 255, (64, 64, 3), dtype=np.uint8),
310 | 'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32),
311 | 'is_first': gym.spaces.Box(0, 1, (), dtype=np.bool),
312 | 'is_last': gym.spaces.Box(0, 1, (), dtype=np.bool),
313 | 'is_terminal': gym.spaces.Box(0, 1, (), dtype=np.bool),
314 | }
315 |
316 | @property
317 | def act_space(self):
318 | return {'action': gym.spaces.Box(-1, 1, (6,), dtype=np.float32)}
319 |
320 | def step(self, action):
321 | return {
322 | 'image': np.zeros((64, 64, 3)),
323 | 'reward': 0.0,
324 | 'is_first': False,
325 | 'is_last': False,
326 | 'is_terminal': False,
327 | }
328 |
329 | def reset(self):
330 | return {
331 | 'image': np.zeros((64, 64, 3)),
332 | 'reward': 0.0,
333 | 'is_first': True,
334 | 'is_last': False,
335 | 'is_terminal': False,
336 | }
337 |
338 |
339 | class TimeLimit:
340 |
341 | def __init__(self, env, duration):
342 | self._env = env
343 | self._duration = duration
344 | self._step = None
345 |
346 | def __getattr__(self, name):
347 | if name.startswith('__'):
348 | raise AttributeError(name)
349 | try:
350 | return getattr(self._env, name)
351 | except AttributeError:
352 | raise ValueError(name)
353 |
354 | def step(self, action):
355 | assert self._step is not None, 'Must reset environment.'
356 | obs = self._env.step(action)
357 | self._step += 1
358 | if self._duration and self._step >= self._duration:
359 | obs['is_last'] = True
360 | self._step = None
361 | return obs
362 |
363 | def reset(self):
364 | self._step = 0
365 | return self._env.reset()
366 |
367 |
368 | class NormalizeAction:
369 |
370 | def __init__(self, env, key='action'):
371 | self._env = env
372 | self._key = key
373 | space = env.act_space[key]
374 | self._mask = np.isfinite(space.low) & np.isfinite(space.high)
375 | self._low = np.where(self._mask, space.low, -1)
376 | self._high = np.where(self._mask, space.high, 1)
377 |
378 | def __getattr__(self, name):
379 | if name.startswith('__'):
380 | raise AttributeError(name)
381 | try:
382 | return getattr(self._env, name)
383 | except AttributeError:
384 | raise ValueError(name)
385 |
386 | @property
387 | def act_space(self):
388 | low = np.where(self._mask, -np.ones_like(self._low), self._low)
389 | high = np.where(self._mask, np.ones_like(self._low), self._high)
390 | space = gym.spaces.Box(low, high, dtype=np.float32)
391 | return {**self._env.act_space, self._key: space}
392 |
393 | def step(self, action):
394 | orig = (action[self._key] + 1) / 2 * (self._high - self._low) + self._low
395 | orig = np.where(self._mask, orig, action[self._key])
396 | return self._env.step({**action, self._key: orig})
397 |
398 |
399 | class OneHotAction:
400 |
401 | def __init__(self, env, key='action'):
402 | assert hasattr(env.act_space[key], 'n')
403 | self._env = env
404 | self._key = key
405 | self._random = np.random.RandomState()
406 |
407 | def __getattr__(self, name):
408 | if name.startswith('__'):
409 | raise AttributeError(name)
410 | try:
411 | return getattr(self._env, name)
412 | except AttributeError:
413 | raise ValueError(name)
414 |
415 | @property
416 | def act_space(self):
417 | shape = (self._env.act_space[self._key].n,)
418 | space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
419 | space.sample = self._sample_action
420 | space.n = shape[0]
421 | return {**self._env.act_space, self._key: space}
422 |
423 | def step(self, action):
424 | index = np.argmax(action[self._key]).astype(int)
425 | reference = np.zeros_like(action[self._key])
426 | reference[index] = 1
427 | if not np.allclose(reference, action[self._key]):
428 | raise ValueError(f'Invalid one-hot action:\n{action}')
429 | return self._env.step({**action, self._key: index})
430 |
431 | def reset(self):
432 | return self._env.reset()
433 |
434 | def _sample_action(self):
435 | actions = self._env.act_space.n
436 | index = self._random.randint(0, actions)
437 | reference = np.zeros(actions, dtype=np.float32)
438 | reference[index] = 1.0
439 | return reference
440 |
441 |
442 | class ResizeImage:
443 |
444 | def __init__(self, env, size=(64, 64)):
445 | self._env = env
446 | self._size = size
447 | self._keys = [
448 | k for k, v in env.obs_space.items()
449 | if len(v.shape) > 1 and v.shape[:2] != size]
450 | print(f'Resizing keys {",".join(self._keys)} to {self._size}.')
451 | if self._keys:
452 | from PIL import Image
453 | self._Image = Image
454 |
455 | def __getattr__(self, name):
456 | if name.startswith('__'):
457 | raise AttributeError(name)
458 | try:
459 | return getattr(self._env, name)
460 | except AttributeError:
461 | raise ValueError(name)
462 |
463 | @property
464 | def obs_space(self):
465 | spaces = self._env.obs_space
466 | for key in self._keys:
467 | shape = self._size + spaces[key].shape[2:]
468 | spaces[key] = gym.spaces.Box(0, 255, shape, np.uint8)
469 | return spaces
470 |
471 | def step(self, action):
472 | obs = self._env.step(action)
473 | for key in self._keys:
474 | obs[key] = self._resize(obs[key])
475 | return obs
476 |
477 | def reset(self):
478 | obs = self._env.reset()
479 | for key in self._keys:
480 | obs[key] = self._resize(obs[key])
481 | return obs
482 |
483 | def _resize(self, image):
484 | image = self._Image.fromarray(image)
485 | image = image.resize(self._size, self._Image.NEAREST)
486 | image = np.array(image)
487 | return image
488 |
489 |
490 | class RenderImage:
491 |
492 | def __init__(self, env, key='image'):
493 | self._env = env
494 | self._key = key
495 | self._shape = self._env.render().shape
496 |
497 | def __getattr__(self, name):
498 | if name.startswith('__'):
499 | raise AttributeError(name)
500 | try:
501 | return getattr(self._env, name)
502 | except AttributeError:
503 | raise ValueError(name)
504 |
505 | @property
506 | def obs_space(self):
507 | spaces = self._env.obs_space
508 | spaces[self._key] = gym.spaces.Box(0, 255, self._shape, np.uint8)
509 | return spaces
510 |
511 | def step(self, action):
512 | obs = self._env.step(action)
513 | obs[self._key] = self._env.render('rgb_array')
514 | return obs
515 |
516 | def reset(self):
517 | obs = self._env.reset()
518 | obs[self._key] = self._env.render('rgb_array')
519 | return obs
520 |
521 |
522 | class Async:
523 |
524 | # Message types for communication via the pipe.
525 | _ACCESS = 1
526 | _CALL = 2
527 | _RESULT = 3
528 | _CLOSE = 4
529 | _EXCEPTION = 5
530 |
531 | def __init__(self, constructor, strategy='thread'):
532 | self._pickled_ctor = cloudpickle.dumps(constructor)
533 | if strategy == 'process':
534 | import multiprocessing as mp
535 | context = mp.get_context('spawn')
536 | elif strategy == 'thread':
537 | import multiprocessing.dummy as context
538 | else:
539 | raise NotImplementedError(strategy)
540 | self._strategy = strategy
541 | self._conn, conn = context.Pipe()
542 | self._process = context.Process(target=self._worker, args=(conn,))
543 | atexit.register(self.close)
544 | self._process.start()
545 | self._receive() # Ready.
546 | self._obs_space = None
547 | self._act_space = None
548 |
549 | def access(self, name):
550 | self._conn.send((self._ACCESS, name))
551 | return self._receive
552 |
553 | def call(self, name, *args, **kwargs):
554 | payload = name, args, kwargs
555 | self._conn.send((self._CALL, payload))
556 | return self._receive
557 |
558 | def close(self):
559 | try:
560 | self._conn.send((self._CLOSE, None))
561 | self._conn.close()
562 | except IOError:
563 | pass # The connection was already closed.
564 | self._process.join(5)
565 |
566 | @property
567 | def obs_space(self):
568 | if not self._obs_space:
569 | self._obs_space = self.access('obs_space')()
570 | return self._obs_space
571 |
572 | @property
573 | def act_space(self):
574 | if not self._act_space:
575 | self._act_space = self.access('act_space')()
576 | return self._act_space
577 |
578 | def step(self, action, blocking=False):
579 | promise = self.call('step', action)
580 | if blocking:
581 | return promise()
582 | else:
583 | return promise
584 |
585 | def reset(self, blocking=False):
586 | promise = self.call('reset')
587 | if blocking:
588 | return promise()
589 | else:
590 | return promise
591 |
592 | def _receive(self):
593 | try:
594 | message, payload = self._conn.recv()
595 | except (OSError, EOFError):
596 | raise RuntimeError('Lost connection to environment worker.')
597 | # Re-raise exceptions in the main process.
598 | if message == self._EXCEPTION:
599 | stacktrace = payload
600 | raise Exception(stacktrace)
601 | if message == self._RESULT:
602 | return payload
603 | raise KeyError('Received message of unexpected type {}'.format(message))
604 |
605 | def _worker(self, conn):
606 | try:
607 | ctor = cloudpickle.loads(self._pickled_ctor)
608 | env = ctor()
609 | conn.send((self._RESULT, None)) # Ready.
610 | while True:
611 | try:
612 | # Only block for short times to have keyboard exceptions be raised.
613 | if not conn.poll(0.1):
614 | continue
615 | message, payload = conn.recv()
616 | except (EOFError, KeyboardInterrupt):
617 | break
618 | if message == self._ACCESS:
619 | name = payload
620 | result = getattr(env, name)
621 | conn.send((self._RESULT, result))
622 | continue
623 | if message == self._CALL:
624 | name, args, kwargs = payload
625 | result = getattr(env, name)(*args, **kwargs)
626 | conn.send((self._RESULT, result))
627 | continue
628 | if message == self._CLOSE:
629 | break
630 | raise KeyError('Received message of unknown type {}'.format(message))
631 | except Exception:
632 | stacktrace = ''.join(traceback.format_exception(*sys.exc_info()))
633 | print('Error in environment process: {}'.format(stacktrace))
634 | conn.send((self._EXCEPTION, stacktrace))
635 | finally:
636 | try:
637 | conn.close()
638 | except IOError:
639 | pass # The connection was already closed.
640 |
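641 | # Usage sketch (illustrative only; `make_env` is a hypothetical factory that
642 | # returns a dict-based environment with a discrete action space):
643 | #
644 | #   env = OneHotAction(ResizeImage(make_env()), key='action')
645 | #   obs = env.reset()
646 | #   obs = env.step({'action': env.act_space['action'].sample()})
647 | #
648 | # Async(lambda: make_env(), strategy='process') runs the same constructor in a
649 | # background worker and returns promises from step() and reset().
650 |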
--------------------------------------------------------------------------------
/dreamerv2/common/flags.py:
--------------------------------------------------------------------------------
1 | import re
2 | import sys
3 |
4 |
5 | class Flags:
6 |
7 | def __init__(self, *args, **kwargs):
8 | from .config import Config
9 | self._config = Config(*args, **kwargs)
10 |
11 | def parse(self, argv=None, known_only=False, help_exists=None):
12 | if help_exists is None:
13 | help_exists = not known_only
14 | if argv is None:
15 | argv = sys.argv[1:]
16 | if '--help' in argv:
17 | print('\nHelp:')
18 | lines = str(self._config).split('\n')[2:]
19 | print('\n'.join('--' + re.sub(r'[:,\[\]]', '', x) for x in lines))
20 | help_exists and sys.exit()
21 | parsed = {}
22 | remaining = []
23 | key = None
24 | vals = None
25 | for arg in argv:
26 | if arg.startswith('--'):
27 | if key:
28 | self._submit_entry(key, vals, parsed, remaining)
29 | if '=' in arg:
30 | key, val = arg.split('=', 1)
31 | vals = [val]
32 | else:
33 | key, vals = arg, []
34 | else:
35 | if key:
36 | vals.append(arg)
37 | else:
38 | remaining.append(arg)
39 | self._submit_entry(key, vals, parsed, remaining)
40 | parsed = self._config.update(parsed)
41 | if known_only:
42 | return parsed, remaining
43 | else:
44 | for flag in remaining:
45 | if flag.startswith('--'):
46 | raise ValueError(f"Flag '{flag}' did not match any config keys.")
47 | assert not remaining, remaining
48 | return parsed
49 |
50 | def _submit_entry(self, key, vals, parsed, remaining):
51 | if not key and not vals:
52 | return
53 | if not key:
54 | vals = ', '.join(f"'{x}'" for x in vals)
55 |       raise ValueError(f"Values {vals} were not preceded by any flag.")
56 | name = key[len('--'):]
57 | if '=' in name:
58 | remaining.extend([key] + vals)
59 | return
60 | if self._config.IS_PATTERN.match(name):
61 | pattern = re.compile(name)
62 | keys = {k for k in self._config.flat if pattern.match(k)}
63 | elif name in self._config:
64 | keys = [name]
65 | else:
66 | keys = []
67 | if not keys:
68 | remaining.extend([key] + vals)
69 | return
70 | if not vals:
71 | raise ValueError(f"Flag '{key}' was not followed by any values.")
72 | for key in keys:
73 | parsed[key] = self._parse_flag_value(self._config[key], vals, key)
74 |
75 | def _parse_flag_value(self, default, value, key):
76 | value = value if isinstance(value, (tuple, list)) else (value,)
77 | if isinstance(default, (tuple, list)):
78 | if len(value) == 1 and ',' in value[0]:
79 | value = value[0].split(',')
80 | return tuple(self._parse_flag_value(default[0], [x], key) for x in value)
81 | assert len(value) == 1, value
82 | value = str(value[0])
83 | if default is None:
84 | return value
85 | if isinstance(default, bool):
86 | try:
87 | return bool(['False', 'True'].index(value))
88 | except ValueError:
89 | message = f"Expected bool but got '{value}' for key '{key}'."
90 | raise TypeError(message)
91 | if isinstance(default, int):
92 | value = float(value) # Allow scientific notation for integers.
93 | if float(int(value)) != value:
94 | message = f"Expected int but got float '{value}' for key '{key}'."
95 | raise TypeError(message)
96 | return int(value)
97 | return type(default)(value)
98 |
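99 | # Usage sketch (illustrative only; `defaults` is a hypothetical dict of
100 | # default values handed to the underlying Config):
101 | #
102 | #   flags = Flags(defaults)
103 | #   config = flags.parse(['--seed', '3'])                # strict parsing
104 | #   config, other = flags.parse(argv, known_only=True)   # keep unknown flags
105 |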
--------------------------------------------------------------------------------
/dreamerv2/common/logger.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import pathlib
4 | import time
5 |
6 | import numpy as np
7 |
8 |
9 | class Logger:
10 |
11 | def __init__(self, step, outputs, multiplier=1):
12 | self._step = step
13 | self._outputs = outputs
14 | self._multiplier = multiplier
15 | self._last_step = None
16 | self._last_time = None
17 | self._metrics = []
18 |
19 | def add(self, mapping, prefix=None):
20 | step = int(self._step) * self._multiplier
21 | for name, value in dict(mapping).items():
22 | name = f'{prefix}_{name}' if prefix else name
23 | value = np.array(value)
24 | if len(value.shape) not in (0, 2, 3, 4):
25 | raise ValueError(
26 | f"Shape {value.shape} for name '{name}' cannot be "
27 | "interpreted as scalar, image, or video.")
28 | self._metrics.append((step, name, value))
29 |
30 | def scalar(self, name, value):
31 | self.add({name: value})
32 |
33 | def image(self, name, value):
34 | self.add({name: value})
35 |
36 | def video(self, name, value):
37 | self.add({name: value})
38 |
39 | def write(self, fps=False):
40 | fps and self.scalar('fps', self._compute_fps())
41 | if not self._metrics:
42 | return
43 | for output in self._outputs:
44 | output(self._metrics)
45 | self._metrics.clear()
46 |
47 | def _compute_fps(self):
48 | step = int(self._step) * self._multiplier
49 | if self._last_step is None:
50 | self._last_time = time.time()
51 | self._last_step = step
52 | return 0
53 | steps = step - self._last_step
54 | duration = time.time() - self._last_time
55 | self._last_time += duration
56 | self._last_step = step
57 | return steps / duration
58 |
59 |
60 | class TerminalOutput:
61 |
62 | def __call__(self, summaries):
63 |     step = max(s for s, _, _ in summaries)
64 | scalars = {k: float(v) for _, k, v in summaries if len(v.shape) == 0}
65 | formatted = {k: self._format_value(v) for k, v in scalars.items()}
66 | print(f'[{step}]', ' / '.join(f'{k} {v}' for k, v in formatted.items()))
67 |
68 | def _format_value(self, value):
69 | if value == 0:
70 | return '0'
71 | elif 0.01 < abs(value) < 10000:
72 | value = f'{value:.2f}'
73 | value = value.rstrip('0')
74 | value = value.rstrip('0')
75 | value = value.rstrip('.')
76 | return value
77 | else:
78 | value = f'{value:.1e}'
79 | value = value.replace('.0e', 'e')
80 | value = value.replace('+0', '')
81 | value = value.replace('+', '')
82 | value = value.replace('-0', '-')
83 | return value
84 |
85 |
86 | class JSONLOutput:
87 |
88 | def __init__(self, logdir):
89 | self._logdir = pathlib.Path(logdir).expanduser()
90 |
91 | def __call__(self, summaries):
92 | scalars = {k: float(v) for _, k, v in summaries if len(v.shape) == 0}
93 |     step = max(s for s, _, _ in summaries)
94 | with (self._logdir / 'metrics.jsonl').open('a') as f:
95 | f.write(json.dumps({'step': step, **scalars}) + '\n')
96 |
97 |
98 | class TensorBoardOutput:
99 |
100 | def __init__(self, logdir, fps=20):
101 | # The TensorFlow summary writer supports file protocols like gs://. We use
102 | # os.path over pathlib here to preserve those prefixes.
103 | self._logdir = os.path.expanduser(logdir)
104 | self._writer = None
105 | self._fps = fps
106 |
107 | def __call__(self, summaries):
108 | import tensorflow as tf
109 | self._ensure_writer()
110 | self._writer.set_as_default()
111 | for step, name, value in summaries:
112 | if len(value.shape) == 0:
113 | tf.summary.scalar('scalars/' + name, value, step)
114 | elif len(value.shape) == 2:
115 | tf.summary.image(name, value, step)
116 | elif len(value.shape) == 3:
117 | tf.summary.image(name, value, step)
118 | elif len(value.shape) == 4:
119 | self._video_summary(name, value, step)
120 | self._writer.flush()
121 |
122 | def _ensure_writer(self):
123 | if not self._writer:
124 | import tensorflow as tf
125 | self._writer = tf.summary.create_file_writer(
126 | self._logdir, max_queue=1000)
127 |
128 | def _video_summary(self, name, video, step):
129 | import tensorflow as tf
130 | import tensorflow.compat.v1 as tf1
131 | name = name if isinstance(name, str) else name.decode('utf-8')
132 | if np.issubdtype(video.dtype, np.floating):
133 | video = np.clip(255 * video, 0, 255).astype(np.uint8)
134 | try:
135 | T, H, W, C = video.shape
136 | summary = tf1.Summary()
137 | image = tf1.Summary.Image(height=H, width=W, colorspace=C)
138 | image.encoded_image_string = encode_gif(video, self._fps)
139 | summary.value.add(tag=name, image=image)
140 | tf.summary.experimental.write_raw_pb(summary.SerializeToString(), step)
141 | except (IOError, OSError) as e:
142 | print('GIF summaries require ffmpeg in $PATH.', e)
143 | tf.summary.image(name, video, step)
144 |
145 |
146 | def encode_gif(frames, fps):
147 | from subprocess import Popen, PIPE
148 | h, w, c = frames[0].shape
149 | pxfmt = {1: 'gray', 3: 'rgb24'}[c]
150 | cmd = ' '.join([
151 | 'ffmpeg -y -f rawvideo -vcodec rawvideo',
152 | f'-r {fps:.02f} -s {w}x{h} -pix_fmt {pxfmt} -i - -filter_complex',
153 | '[0:v]split[x][z];[z]palettegen[y];[x]fifo[x];[x][y]paletteuse',
154 | f'-r {fps:.02f} -f gif -'])
155 | proc = Popen(cmd.split(' '), stdin=PIPE, stdout=PIPE, stderr=PIPE)
156 | for image in frames:
157 | proc.stdin.write(image.tobytes())
158 | out, err = proc.communicate()
159 | if proc.returncode:
160 |     raise IOError('\n'.join([cmd, err.decode('utf8')]))
161 | del proc
162 | return out
163 |
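164 | # Usage sketch (illustrative only; `step` is any int-like counter and the log
165 | # directory is a placeholder):
166 | #
167 | #   outputs = [TerminalOutput(), JSONLOutput('/tmp/run'), TensorBoardOutput('/tmp/run')]
168 | #   logger = Logger(step, outputs, multiplier=1)
169 | #   logger.scalar('train_loss', 0.5)
170 | #   logger.write(fps=True)
171 |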
--------------------------------------------------------------------------------
/dreamerv2/common/nets.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | import numpy as np
4 | import tensorflow as tf
5 | from tensorflow.keras import layers as tfkl
6 | from tensorflow_probability import distributions as tfd
7 | from tensorflow.keras.mixed_precision import experimental as prec
8 |
9 | import common
10 |
11 |
12 | class EnsembleRSSM(common.Module):
13 |
14 | def __init__(
15 | self, ensemble=5, stoch=30, deter=200, hidden=200, discrete=False,
16 | act='elu', norm='none', std_act='softplus', min_std=0.1):
17 | super().__init__()
18 | self._ensemble = ensemble
19 | self._stoch = stoch
20 | self._deter = deter
21 | self._hidden = hidden
22 | self._discrete = discrete
23 | self._act = get_act(act)
24 | self._norm = norm
25 | self._std_act = std_act
26 | self._min_std = min_std
27 | self._cell = GRUCell(self._deter, norm=True)
28 | self._cast = lambda x: tf.cast(x, prec.global_policy().compute_dtype)
29 |
30 | def initial(self, batch_size):
31 | dtype = prec.global_policy().compute_dtype
32 | if self._discrete:
33 | state = dict(
34 | logit=tf.zeros([batch_size, self._stoch, self._discrete], dtype),
35 | stoch=tf.zeros([batch_size, self._stoch, self._discrete], dtype),
36 | deter=self._cell.get_initial_state(None, batch_size, dtype))
37 | else:
38 | state = dict(
39 | mean=tf.zeros([batch_size, self._stoch], dtype),
40 | std=tf.zeros([batch_size, self._stoch], dtype),
41 | stoch=tf.zeros([batch_size, self._stoch], dtype),
42 | deter=self._cell.get_initial_state(None, batch_size, dtype))
43 | return state
44 |
45 | @tf.function
46 | def observe(self, embed, action, is_first, state=None):
47 | swap = lambda x: tf.transpose(x, [1, 0] + list(range(2, len(x.shape))))
48 | if state is None:
49 | state = self.initial(tf.shape(action)[0])
50 | post, prior = common.static_scan(
51 | lambda prev, inputs: self.obs_step(prev[0], *inputs),
52 | (swap(action), swap(embed), swap(is_first)), (state, state))
53 | post = {k: swap(v) for k, v in post.items()}
54 | prior = {k: swap(v) for k, v in prior.items()}
55 | return post, prior
56 |
57 | @tf.function
58 | def imagine(self, action, state=None):
59 | swap = lambda x: tf.transpose(x, [1, 0] + list(range(2, len(x.shape))))
60 | if state is None:
61 | state = self.initial(tf.shape(action)[0])
62 | assert isinstance(state, dict), state
63 | action = swap(action)
64 | prior = common.static_scan(self.img_step, action, state)
65 | prior = {k: swap(v) for k, v in prior.items()}
66 | return prior
67 |
68 | def get_feat(self, state):
69 | stoch = self._cast(state['stoch'])
70 | if self._discrete:
71 | shape = stoch.shape[:-2] + [self._stoch * self._discrete]
72 | stoch = tf.reshape(stoch, shape)
73 | return tf.concat([stoch, state['deter']], -1)
74 |
75 | def get_dist(self, state, ensemble=False):
76 | if ensemble:
77 | state = self._suff_stats_ensemble(state['deter'])
78 | if self._discrete:
79 | logit = state['logit']
80 | logit = tf.cast(logit, tf.float32)
81 | dist = tfd.Independent(common.OneHotDist(logit), 1)
82 | else:
83 | mean, std = state['mean'], state['std']
84 | mean = tf.cast(mean, tf.float32)
85 | std = tf.cast(std, tf.float32)
86 | dist = tfd.MultivariateNormalDiag(mean, std)
87 | return dist
88 |
89 | @tf.function
90 | def obs_step(self, prev_state, prev_action, embed, is_first, sample=True):
91 | # if is_first.any():
92 | prev_state, prev_action = tf.nest.map_structure(
93 | lambda x: tf.einsum(
94 | 'b,b...->b...', 1.0 - is_first.astype(x.dtype), x),
95 | (prev_state, prev_action))
96 | prior = self.img_step(prev_state, prev_action, sample)
97 | x = tf.concat([prior['deter'], embed], -1)
98 | x = self.get('obs_out', tfkl.Dense, self._hidden)(x)
99 | x = self.get('obs_out_norm', NormLayer, self._norm)(x)
100 | x = self._act(x)
101 | stats = self._suff_stats_layer('obs_dist', x)
102 | dist = self.get_dist(stats)
103 | stoch = dist.sample() if sample else dist.mode()
104 | post = {'stoch': stoch, 'deter': prior['deter'], **stats}
105 | return post, prior
106 |
107 | @tf.function
108 | def img_step(self, prev_state, prev_action, sample=True):
109 | prev_stoch = self._cast(prev_state['stoch'])
110 | prev_action = self._cast(prev_action)
111 | if self._discrete:
112 | shape = prev_stoch.shape[:-2] + [self._stoch * self._discrete]
113 | prev_stoch = tf.reshape(prev_stoch, shape)
114 | x = tf.concat([prev_stoch, prev_action], -1)
115 | x = self.get('img_in', tfkl.Dense, self._hidden)(x)
116 | x = self.get('img_in_norm', NormLayer, self._norm)(x)
117 | x = self._act(x)
118 | deter = prev_state['deter']
119 | x, deter = self._cell(x, [deter])
120 | deter = deter[0] # Keras wraps the state in a list.
121 | stats = self._suff_stats_ensemble(x)
122 | index = tf.random.uniform((), 0, self._ensemble, tf.int32)
123 | stats = {k: v[index] for k, v in stats.items()}
124 | dist = self.get_dist(stats)
125 | stoch = dist.sample() if sample else dist.mode()
126 | prior = {'stoch': stoch, 'deter': deter, **stats}
127 | return prior
128 |
129 | def _suff_stats_ensemble(self, inp):
130 | bs = list(inp.shape[:-1])
131 | inp = inp.reshape([-1, inp.shape[-1]])
132 | stats = []
133 | for k in range(self._ensemble):
134 | x = self.get(f'img_out_{k}', tfkl.Dense, self._hidden)(inp)
135 | x = self.get(f'img_out_norm_{k}', NormLayer, self._norm)(x)
136 | x = self._act(x)
137 | stats.append(self._suff_stats_layer(f'img_dist_{k}', x))
138 | stats = {
139 | k: tf.stack([x[k] for x in stats], 0)
140 | for k, v in stats[0].items()}
141 | stats = {
142 | k: v.reshape([v.shape[0]] + bs + list(v.shape[2:]))
143 | for k, v in stats.items()}
144 | return stats
145 |
146 | def _suff_stats_layer(self, name, x):
147 | if self._discrete:
148 | x = self.get(name, tfkl.Dense, self._stoch * self._discrete, None)(x)
149 | logit = tf.reshape(x, x.shape[:-1] + [self._stoch, self._discrete])
150 | return {'logit': logit}
151 | else:
152 | x = self.get(name, tfkl.Dense, 2 * self._stoch, None)(x)
153 | mean, std = tf.split(x, 2, -1)
154 | std = {
155 | 'softplus': lambda: tf.nn.softplus(std),
156 | 'sigmoid': lambda: tf.nn.sigmoid(std),
157 | 'sigmoid2': lambda: 2 * tf.nn.sigmoid(std / 2),
158 | }[self._std_act]()
159 | std = std + self._min_std
160 | return {'mean': mean, 'std': std}
161 |
162 | def kl_loss(self, post, prior, forward, balance, free, free_avg):
163 | kld = tfd.kl_divergence
164 | sg = lambda x: tf.nest.map_structure(tf.stop_gradient, x)
165 | lhs, rhs = (prior, post) if forward else (post, prior)
166 | mix = balance if forward else (1 - balance)
167 | if balance == 0.5:
168 | value = kld(self.get_dist(lhs), self.get_dist(rhs))
169 | loss = tf.maximum(value, free).mean()
170 | else:
171 | value_lhs = value = kld(self.get_dist(lhs), self.get_dist(sg(rhs)))
172 | value_rhs = kld(self.get_dist(sg(lhs)), self.get_dist(rhs))
173 | if free_avg:
174 | loss_lhs = tf.maximum(value_lhs.mean(), free)
175 | loss_rhs = tf.maximum(value_rhs.mean(), free)
176 | else:
177 | loss_lhs = tf.maximum(value_lhs, free).mean()
178 | loss_rhs = tf.maximum(value_rhs, free).mean()
179 | loss = mix * loss_lhs + (1 - mix) * loss_rhs
180 | return loss, value
181 |
182 |
183 | class Encoder(common.Module):
184 |
185 | def __init__(
186 | self, shapes, cnn_keys=r'.*', mlp_keys=r'.*', act='elu', norm='none',
187 | cnn_depth=48, cnn_kernels=(4, 4, 4, 4), mlp_layers=[400, 400, 400, 400]):
188 | self.shapes = shapes
189 | self.cnn_keys = [
190 | k for k, v in shapes.items() if re.match(cnn_keys, k) and len(v) == 3]
191 | self.mlp_keys = [
192 | k for k, v in shapes.items() if re.match(mlp_keys, k) and len(v) == 1]
193 | print('Encoder CNN inputs:', list(self.cnn_keys))
194 | print('Encoder MLP inputs:', list(self.mlp_keys))
195 | self._act = get_act(act)
196 | self._norm = norm
197 | self._cnn_depth = cnn_depth
198 | self._cnn_kernels = cnn_kernels
199 | self._mlp_layers = mlp_layers
200 |
201 | @tf.function
202 | def __call__(self, data):
203 | key, shape = list(self.shapes.items())[0]
204 | batch_dims = data[key].shape[:-len(shape)]
205 | data = {
206 | k: tf.reshape(v, (-1,) + tuple(v.shape)[len(batch_dims):])
207 | for k, v in data.items()}
208 | outputs = []
209 | if self.cnn_keys:
210 | outputs.append(self._cnn({k: data[k] for k in self.cnn_keys}))
211 | if self.mlp_keys:
212 | outputs.append(self._mlp({k: data[k] for k in self.mlp_keys}))
213 | output = tf.concat(outputs, -1)
214 | return output.reshape(batch_dims + output.shape[1:])
215 |
216 | def _cnn(self, data):
217 | x = tf.concat(list(data.values()), -1)
218 | x = x.astype(prec.global_policy().compute_dtype)
219 | for i, kernel in enumerate(self._cnn_kernels):
220 | depth = 2 ** i * self._cnn_depth
221 | x = self.get(f'conv{i}', tfkl.Conv2D, depth, kernel, 2)(x)
222 | x = self.get(f'convnorm{i}', NormLayer, self._norm)(x)
223 | x = self._act(x)
224 | return x.reshape(tuple(x.shape[:-3]) + (-1,))
225 |
226 | def _mlp(self, data):
227 | x = tf.concat(list(data.values()), -1)
228 | x = x.astype(prec.global_policy().compute_dtype)
229 | for i, width in enumerate(self._mlp_layers):
230 | x = self.get(f'dense{i}', tfkl.Dense, width)(x)
231 | x = self.get(f'densenorm{i}', NormLayer, self._norm)(x)
232 | x = self._act(x)
233 | return x
234 |
235 |
236 | class Decoder(common.Module):
237 |
238 | def __init__(
239 | self, shapes, cnn_keys=r'.*', mlp_keys=r'.*', act='elu', norm='none',
240 | cnn_depth=48, cnn_kernels=(4, 4, 4, 4), mlp_layers=[400, 400, 400, 400]):
241 | self._shapes = shapes
242 | self.cnn_keys = [
243 | k for k, v in shapes.items() if re.match(cnn_keys, k) and len(v) == 3]
244 | self.mlp_keys = [
245 | k for k, v in shapes.items() if re.match(mlp_keys, k) and len(v) == 1]
246 | print('Decoder CNN outputs:', list(self.cnn_keys))
247 | print('Decoder MLP outputs:', list(self.mlp_keys))
248 | self._act = get_act(act)
249 | self._norm = norm
250 | self._cnn_depth = cnn_depth
251 | self._cnn_kernels = cnn_kernels
252 | self._mlp_layers = mlp_layers
253 |
254 | def __call__(self, features):
255 | features = tf.cast(features, prec.global_policy().compute_dtype)
256 | outputs = {}
257 | if self.cnn_keys:
258 | outputs.update(self._cnn(features))
259 | if self.mlp_keys:
260 | outputs.update(self._mlp(features))
261 | return outputs
262 |
263 | def _cnn(self, features):
264 | channels = {k: self._shapes[k][-1] for k in self.cnn_keys}
265 | ConvT = tfkl.Conv2DTranspose
266 | x = self.get('convin', tfkl.Dense, 32 * self._cnn_depth)(features)
267 | x = tf.reshape(x, [-1, 1, 1, 32 * self._cnn_depth])
268 | for i, kernel in enumerate(self._cnn_kernels):
269 | depth = 2 ** (len(self._cnn_kernels) - i - 2) * self._cnn_depth
270 | act, norm = self._act, self._norm
271 | if i == len(self._cnn_kernels) - 1:
272 | depth, act, norm = sum(channels.values()), tf.identity, 'none'
273 | x = self.get(f'conv{i}', ConvT, depth, kernel, 2)(x)
274 | x = self.get(f'convnorm{i}', NormLayer, norm)(x)
275 | x = act(x)
276 | x = x.reshape(features.shape[:-1] + x.shape[1:])
277 | means = tf.split(x, list(channels.values()), -1)
278 | dists = {
279 | key: tfd.Independent(tfd.Normal(mean, 1), 3)
280 | for (key, shape), mean in zip(channels.items(), means)}
281 | return dists
282 |
283 | def _mlp(self, features):
284 | shapes = {k: self._shapes[k] for k in self.mlp_keys}
285 | x = features
286 | for i, width in enumerate(self._mlp_layers):
287 | x = self.get(f'dense{i}', tfkl.Dense, width)(x)
288 | x = self.get(f'densenorm{i}', NormLayer, self._norm)(x)
289 | x = self._act(x)
290 | dists = {}
291 | for key, shape in shapes.items():
292 | dists[key] = self.get(f'dense_{key}', DistLayer, shape)(x)
293 | return dists
294 |
295 |
296 | class MLP(common.Module):
297 |
298 | def __init__(self, shape, layers, units, act='elu', norm='none', **out):
299 | self._shape = (shape,) if isinstance(shape, int) else shape
300 | self._layers = layers
301 | self._units = units
302 | self._norm = norm
303 | self._act = get_act(act)
304 | self._out = out
305 |
306 | def __call__(self, features):
307 | x = tf.cast(features, prec.global_policy().compute_dtype)
308 | x = x.reshape([-1, x.shape[-1]])
309 | for index in range(self._layers):
310 | x = self.get(f'dense{index}', tfkl.Dense, self._units)(x)
311 | x = self.get(f'norm{index}', NormLayer, self._norm)(x)
312 | x = self._act(x)
313 | x = x.reshape(features.shape[:-1] + [x.shape[-1]])
314 | return self.get('out', DistLayer, self._shape, **self._out)(x)
315 |
316 |
317 | class GRUCell(tf.keras.layers.AbstractRNNCell):
318 |
319 | def __init__(self, size, norm=False, act='tanh', update_bias=-1, **kwargs):
320 | super().__init__()
321 | self._size = size
322 | self._act = get_act(act)
323 | self._norm = norm
324 | self._update_bias = update_bias
325 | self._layer = tfkl.Dense(3 * size, use_bias=norm is not None, **kwargs)
326 | if norm:
327 | self._norm = tfkl.LayerNormalization(dtype=tf.float32)
328 |
329 | @property
330 | def state_size(self):
331 | return self._size
332 |
333 | @tf.function
334 | def call(self, inputs, state):
335 | state = state[0] # Keras wraps the state in a list.
336 | parts = self._layer(tf.concat([inputs, state], -1))
337 | if self._norm:
338 | dtype = parts.dtype
339 | parts = tf.cast(parts, tf.float32)
340 | parts = self._norm(parts)
341 | parts = tf.cast(parts, dtype)
342 | reset, cand, update = tf.split(parts, 3, -1)
343 | reset = tf.nn.sigmoid(reset)
344 | cand = self._act(reset * cand)
345 | update = tf.nn.sigmoid(update + self._update_bias)
346 | output = update * cand + (1 - update) * state
347 | return output, [output]
348 |
349 |
350 | class DistLayer(common.Module):
351 |
352 | def __init__(
353 | self, shape, dist='mse', min_std=0.1, init_std=0.0):
354 | self._shape = shape
355 | self._dist = dist
356 | self._min_std = min_std
357 | self._init_std = init_std
358 |
359 | def __call__(self, inputs):
360 | out = self.get('out', tfkl.Dense, np.prod(self._shape))(inputs)
361 | out = tf.reshape(out, tf.concat([tf.shape(inputs)[:-1], self._shape], 0))
362 | out = tf.cast(out, tf.float32)
363 | if self._dist in ('normal', 'tanh_normal', 'trunc_normal'):
364 | std = self.get('std', tfkl.Dense, np.prod(self._shape))(inputs)
365 | std = tf.reshape(std, tf.concat([tf.shape(inputs)[:-1], self._shape], 0))
366 | std = tf.cast(std, tf.float32)
367 | if self._dist == 'mse':
368 | dist = tfd.Normal(out, 1.0)
369 | return tfd.Independent(dist, len(self._shape))
370 | if self._dist == 'normal':
371 | dist = tfd.Normal(out, std)
372 | return tfd.Independent(dist, len(self._shape))
373 | if self._dist == 'binary':
374 | dist = tfd.Bernoulli(out)
375 | return tfd.Independent(dist, len(self._shape))
376 | if self._dist == 'tanh_normal':
377 | mean = 5 * tf.tanh(out / 5)
378 | std = tf.nn.softplus(std + self._init_std) + self._min_std
379 | dist = tfd.Normal(mean, std)
380 | dist = tfd.TransformedDistribution(dist, common.TanhBijector())
381 | dist = tfd.Independent(dist, len(self._shape))
382 | return common.SampleDist(dist)
383 | if self._dist == 'trunc_normal':
384 | std = 2 * tf.nn.sigmoid((std + self._init_std) / 2) + self._min_std
385 | dist = common.TruncNormalDist(tf.tanh(out), std, -1, 1)
386 | return tfd.Independent(dist, 1)
387 | if self._dist == 'onehot':
388 | return common.OneHotDist(out)
389 | raise NotImplementedError(self._dist)
390 |
391 |
392 | class NormLayer(common.Module):
393 |
394 | def __init__(self, name):
395 | if name == 'none':
396 | self._layer = None
397 | elif name == 'layer':
398 | self._layer = tfkl.LayerNormalization()
399 | else:
400 | raise NotImplementedError(name)
401 |
402 | def __call__(self, features):
403 | if not self._layer:
404 | return features
405 | return self._layer(features)
406 |
407 |
408 | def get_act(name):
409 | if name == 'none':
410 | return tf.identity
411 | if name == 'mish':
412 | return lambda x: x * tf.math.tanh(tf.nn.softplus(x))
413 | elif hasattr(tf.nn, name):
414 | return getattr(tf.nn, name)
415 | elif hasattr(tf, name):
416 | return getattr(tf, name)
417 | else:
418 | raise NotImplementedError(name)
419 |
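420 | # Usage sketch (illustrative shapes and hyperparameters, not the values used
421 | # by the agent configs):
422 | #
423 | #   rssm = EnsembleRSSM(ensemble=5, stoch=32, deter=200, discrete=32)
424 | #   state = rssm.initial(batch_size=16)
425 | #   head = MLP(shape=1, layers=4, units=400, act='elu', dist='mse')
426 | #   dist = head(rssm.get_feat(state))   # tfp distribution over a length-1 vector
427 |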
--------------------------------------------------------------------------------
/dreamerv2/common/other.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import contextlib
3 | import re
4 | import time
5 |
6 | import numpy as np
7 | import tensorflow as tf
8 | from tensorflow_probability import distributions as tfd
9 |
10 | from . import dists
11 | from . import tfutils
12 |
13 |
14 | class RandomAgent:
15 |
16 | def __init__(self, act_space, logprob=False):
17 | self.act_space = act_space['action']
18 | self.logprob = logprob
19 | if hasattr(self.act_space, 'n'):
20 | self._dist = dists.OneHotDist(tf.zeros(self.act_space.n))
21 | else:
22 | dist = tfd.Uniform(self.act_space.low, self.act_space.high)
23 | self._dist = tfd.Independent(dist, 1)
24 |
25 | def __call__(self, obs, state=None, mode=None):
26 | action = self._dist.sample(len(obs['is_first']))
27 | output = {'action': action}
28 | if self.logprob:
29 | output['logprob'] = self._dist.log_prob(action)
30 | return output, None
31 |
32 |
33 | def static_scan(fn, inputs, start, reverse=False):
34 | last = start
35 | outputs = [[] for _ in tf.nest.flatten(start)]
36 | indices = range(tf.nest.flatten(inputs)[0].shape[0])
37 | if reverse:
38 | indices = reversed(indices)
39 | for index in indices:
40 | inp = tf.nest.map_structure(lambda x: x[index], inputs)
41 | last = fn(last, inp)
42 | [o.append(l) for o, l in zip(outputs, tf.nest.flatten(last))]
43 | if reverse:
44 | outputs = [list(reversed(x)) for x in outputs]
45 | outputs = [tf.stack(x, 0) for x in outputs]
46 | return tf.nest.pack_sequence_as(start, outputs)
47 |
48 |
49 | def schedule(string, step):
50 | try:
51 | return float(string)
52 | except ValueError:
53 | step = tf.cast(step, tf.float32)
54 | match = re.match(r'linear\((.+),(.+),(.+)\)', string)
55 | if match:
56 | initial, final, duration = [float(group) for group in match.groups()]
57 | mix = tf.clip_by_value(step / duration, 0, 1)
58 | return (1 - mix) * initial + mix * final
59 | match = re.match(r'warmup\((.+),(.+)\)', string)
60 | if match:
61 | warmup, value = [float(group) for group in match.groups()]
62 | scale = tf.clip_by_value(step / warmup, 0, 1)
63 | return scale * value
64 | match = re.match(r'exp\((.+),(.+),(.+)\)', string)
65 | if match:
66 | initial, final, halflife = [float(group) for group in match.groups()]
67 | return (initial - final) * 0.5 ** (step / halflife) + final
68 | match = re.match(r'horizon\((.+),(.+),(.+)\)', string)
69 | if match:
70 | initial, final, duration = [float(group) for group in match.groups()]
71 | mix = tf.clip_by_value(step / duration, 0, 1)
72 | horizon = (1 - mix) * initial + mix * final
73 | return 1 - 1 / horizon
74 | raise NotImplementedError(string)
75 |
76 |
77 | def lambda_return(
78 | reward, value, pcont, bootstrap, lambda_, axis):
79 |   # Computes the recursion V_t = r_t + pcont_t * ((1 - lambda) * v_{t+1} + lambda * V_{t+1}).
80 |   # Setting lambda=1 gives a discounted Monte Carlo return; lambda=0 gives a fixed 1-step return.
81 | assert reward.shape.ndims == value.shape.ndims, (reward.shape, value.shape)
82 | if isinstance(pcont, (int, float)):
83 | pcont = pcont * tf.ones_like(reward)
84 | dims = list(range(reward.shape.ndims))
85 | dims = [axis] + dims[1:axis] + [0] + dims[axis + 1:]
86 | if axis != 0:
87 | reward = tf.transpose(reward, dims)
88 | value = tf.transpose(value, dims)
89 | pcont = tf.transpose(pcont, dims)
90 | if bootstrap is None:
91 | bootstrap = tf.zeros_like(value[-1])
92 | next_values = tf.concat([value[1:], bootstrap[None]], 0)
93 | inputs = reward + pcont * next_values * (1 - lambda_)
94 | returns = static_scan(
95 | lambda agg, cur: cur[0] + cur[1] * lambda_ * agg,
96 | (inputs, pcont), bootstrap, reverse=True)
97 | if axis != 0:
98 | returns = tf.transpose(returns, dims)
99 | return returns
100 |
101 |
102 | def action_noise(action, amount, act_space):
103 | if amount == 0:
104 | return action
105 | amount = tf.cast(amount, action.dtype)
106 | if hasattr(act_space, 'n'):
107 | probs = amount / action.shape[-1] + (1 - amount) * action
108 | return dists.OneHotDist(probs=probs).sample()
109 | else:
110 | return tf.clip_by_value(tfd.Normal(action, amount).sample(), -1, 1)
111 |
112 |
113 | class StreamNorm(tfutils.Module):
114 |
115 | def __init__(self, shape=(), momentum=0.99, scale=1.0, eps=1e-8):
116 | # Momentum of 0 normalizes only based on the current batch.
117 | # Momentum of 1 disables normalization.
118 | self._shape = tuple(shape)
119 | self._momentum = momentum
120 | self._scale = scale
121 | self._eps = eps
122 | self.mag = tf.Variable(tf.ones(shape, tf.float64), False)
123 |
124 | def __call__(self, inputs):
125 | metrics = {}
126 | self.update(inputs)
127 | metrics['mean'] = inputs.mean()
128 | metrics['std'] = inputs.std()
129 | outputs = self.transform(inputs)
130 | metrics['normed_mean'] = outputs.mean()
131 | metrics['normed_std'] = outputs.std()
132 | return outputs, metrics
133 |
134 | def reset(self):
135 | self.mag.assign(tf.ones_like(self.mag))
136 |
137 | def update(self, inputs):
138 | batch = inputs.reshape((-1,) + self._shape)
139 | mag = tf.abs(batch).mean(0).astype(tf.float64)
140 | self.mag.assign(self._momentum * self.mag + (1 - self._momentum) * mag)
141 |
142 | def transform(self, inputs):
143 | values = inputs.reshape((-1,) + self._shape)
144 | values /= self.mag.astype(inputs.dtype)[None] + self._eps
145 | values *= self._scale
146 | return values.reshape(inputs.shape)
147 |
148 |
149 | class Timer:
150 |
151 | def __init__(self):
152 | self._indurs = collections.defaultdict(list)
153 | self._outdurs = collections.defaultdict(list)
154 | self._start_times = {}
155 | self._end_times = {}
156 |
157 | @contextlib.contextmanager
158 | def section(self, name):
159 | self.start(name)
160 | yield
161 | self.end(name)
162 |
163 | def wrap(self, function, name):
164 | def wrapped(*args, **kwargs):
165 | with self.section(name):
166 | return function(*args, **kwargs)
167 | return wrapped
168 |
169 | def start(self, name):
170 | now = time.time()
171 | self._start_times[name] = now
172 | if name in self._end_times:
173 | last = self._end_times[name]
174 | self._outdurs[name].append(now - last)
175 |
176 | def end(self, name):
177 | now = time.time()
178 | self._end_times[name] = now
179 | self._indurs[name].append(now - self._start_times[name])
180 |
181 | def result(self):
182 | metrics = {}
183 | for key in self._indurs:
184 | indurs = self._indurs[key]
185 | outdurs = self._outdurs[key]
186 | metrics[f'timer_count_{key}'] = len(indurs)
187 | metrics[f'timer_inside_{key}'] = np.sum(indurs)
188 | metrics[f'timer_outside_{key}'] = np.sum(outdurs)
189 | indurs.clear()
190 | outdurs.clear()
191 | return metrics
192 |
193 |
194 | class CarryOverState:
195 |
196 | def __init__(self, fn):
197 | self._fn = fn
198 | self._state = None
199 |
200 | def __call__(self, *args):
201 | self._state, out = self._fn(*args, self._state)
202 | return out
203 |
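204 | # Usage sketch (illustrative values): `schedule` parses annealing strings that
205 | # are evaluated at the current training step.
206 | #
207 | #   amount = schedule('linear(1.0,0.1,1e6)', step)   # anneal 1.0 -> 0.1 over 1M steps
208 | #   scale = schedule('warmup(1e5,1.0)', step)        # ramp 0 -> 1.0 over 100k steps
209 | #   eps = schedule('exp(1.0,0.1,5e5)', step)         # decay 1.0 -> 0.1 with 5e5-step half-life
210 |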
--------------------------------------------------------------------------------
/dreamerv2/common/plot.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import collections
3 | import functools
4 | import itertools
5 | import json
6 | import multiprocessing as mp
7 | import os
8 | import pathlib
9 | import re
10 | import subprocess
11 | import warnings
12 |
13 | os.environ['NO_AT_BRIDGE'] = '1' # Hide X org false warning.
14 |
15 | import matplotlib
16 | matplotlib.use('Agg')
17 | import matplotlib.pyplot as plt
18 | import matplotlib.ticker as ticker
19 | import numpy as np
20 | import pandas as pd
21 |
22 | np.set_string_function(lambda x: f'<np.array shape={x.shape} dtype={x.dtype}>')
23 |
24 | Run = collections.namedtuple('Run', 'task method seed xs ys')
25 |
26 | PALETTES = dict(
27 | discrete=(
28 | '#377eb8', '#4daf4a', '#984ea3', '#e41a1c', '#ff7f00', '#a65628',
29 | '#f781bf', '#888888', '#a6cee3', '#b2df8a', '#cab2d6', '#fb9a99',
30 | ),
31 | contrast=(
32 | '#0022ff', '#33aa00', '#ff0011', '#ddaa00', '#cc44dd', '#0088aa',
33 | '#001177', '#117700', '#990022', '#885500', '#553366', '#006666',
34 | ),
35 | gradient=(
36 | '#fde725', '#a0da39', '#4ac16d', '#1fa187', '#277f8e', '#365c8d',
37 | '#46327e', '#440154',
38 | ),
39 | baselines=(
40 | '#222222', '#666666', '#aaaaaa', '#cccccc',
41 | ),
42 | )
43 |
44 | LEGEND = dict(
45 | fontsize='medium', numpoints=1, labelspacing=0, columnspacing=1.2,
46 | handlelength=1.5, handletextpad=0.5, loc='lower center')
47 |
48 | DEFAULT_BASELINES = ['d4pg', 'rainbow_sticky', 'human_gamer', 'impala']
49 |
50 |
51 | def find_keys(args):
52 | filenames = []
53 | for indir in args.indir:
54 | task = next(indir.iterdir()) # First only.
55 | for method in task.iterdir():
56 |       seed = next(method.iterdir())  # First only.
57 | filenames += list(seed.glob('**/*.jsonl'))
58 | keys = set()
59 | for filename in filenames:
60 | keys |= set(load_jsonl(filename).columns)
61 | print(f'Keys ({len(keys)}):', ', '.join(keys), flush=True)
62 |
63 |
64 | def load_runs(args):
65 | total, toload = [], []
66 | for indir in args.indir:
67 | filenames = list(indir.glob('**/*.jsonl'))
68 | total += filenames
69 | for filename in filenames:
70 | task, method, seed = filename.relative_to(indir).parts[:-1]
71 | if not any(p.search(task) for p in args.tasks):
72 | continue
73 | if not any(p.search(method) for p in args.methods):
74 | continue
75 | toload.append((filename, indir))
76 | print(f'Loading {len(toload)} of {len(total)} runs...')
77 | jobs = [functools.partial(load_run, f, i, args) for f, i in toload]
78 | # Disable async data loading:
79 | # runs = [j() for j in jobs]
80 | with mp.Pool(10) as pool:
81 | promises = [pool.apply_async(j) for j in jobs]
82 | runs = [p.get() for p in promises]
83 | runs = [r for r in runs if r is not None]
84 | return runs
85 |
86 |
87 | def load_run(filename, indir, args):
88 | task, method, seed = filename.relative_to(indir).parts[:-1]
89 | prefix = f'indir{args.indir.index(indir)+1}_'
90 | if task == 'atari_jamesbond':
91 | task = 'atari_james_bond'
92 | seed = prefix + seed
93 | if args.prefix:
94 | method = prefix + method
95 | df = load_jsonl(filename)
96 | if df is None:
97 | print('Skipping empty run')
98 | return
99 | try:
100 | df = df[[args.xaxis, args.yaxis]].dropna()
101 | if args.maxval:
102 | df = df.replace([+np.inf], +args.maxval)
103 | df = df.replace([-np.inf], -args.maxval)
104 | df[args.yaxis] = df[args.yaxis].clip(-args.maxval, +args.maxval)
105 | except KeyError:
106 | return
107 | xs = df[args.xaxis].to_numpy()
108 | if args.xmult != 1:
109 | xs = xs.astype(np.float32) * args.xmult
110 | ys = df[args.yaxis].to_numpy()
111 | bins = {
112 | 'atari': 1e6,
113 | 'dmc': 1e4,
114 | 'crafter': 1e4,
115 | }.get(task.split('_')[0], 1e5) if args.bins == -1 else args.bins
116 | if bins:
117 | borders = np.arange(0, xs.max() + 1e-8, bins)
118 | xs, ys = bin_scores(xs, ys, borders)
119 | if not len(xs):
120 | print('Skipping empty run', task, method, seed)
121 | return
122 | return Run(task, method, seed, xs, ys)
123 |
124 |
125 | def load_baselines(patterns, prefix=False):
126 | runs = []
127 | directory = pathlib.Path(__file__).parent.parent / 'scores'
128 | for filename in directory.glob('**/*_baselines.json'):
129 | for task, methods in json.loads(filename.read_text()).items():
130 | for method, score in methods.items():
131 | if prefix:
132 | method = f'baseline_{method}'
133 | if not any(p.search(method) for p in patterns):
134 | continue
135 | runs.append(Run(task, method, None, None, score))
136 | return runs
137 |
138 |
139 | def stats(runs, baselines):
140 | tasks = sorted(set(r.task for r in runs))
141 | methods = sorted(set(r.method for r in runs))
142 | seeds = sorted(set(r.seed for r in runs))
143 | baseline = sorted(set(r.method for r in baselines))
144 | print('Loaded', len(runs), 'runs.')
145 | print(f'Tasks ({len(tasks)}):', ', '.join(tasks))
146 | print(f'Methods ({len(methods)}):', ', '.join(methods))
147 | print(f'Seeds ({len(seeds)}):', ', '.join(seeds))
148 | print(f'Baselines ({len(baseline)}):', ', '.join(baseline))
149 |
150 |
151 | def order_methods(runs, baselines, args):
152 | methods = []
153 | for pattern in args.methods:
154 | for method in sorted(set(r.method for r in runs)):
155 | if pattern.search(method):
156 | if method not in methods:
157 | methods.append(method)
158 | if method not in args.colors:
159 | index = len(args.colors) % len(args.palette)
160 | args.colors[method] = args.palette[index]
161 | non_baseline_colors = len(args.colors)
162 | for pattern in args.baselines:
163 | for method in sorted(set(r.method for r in baselines)):
164 | if pattern.search(method):
165 | if method not in methods:
166 | methods.append(method)
167 | if method not in args.colors:
168 | index = len(args.colors) - non_baseline_colors
169 | index = index % len(PALETTES['baselines'])
170 | args.colors[method] = PALETTES['baselines'][index]
171 | return methods
172 |
173 |
174 | def figure(runs, methods, args):
175 | tasks = sorted(set(r.task for r in runs if r.xs is not None))
176 | rows = int(np.ceil((len(tasks) + len(args.add)) / args.cols))
177 | figsize = args.size[0] * args.cols, args.size[1] * rows
178 | fig, axes = plt.subplots(rows, args.cols, figsize=figsize, squeeze=False)
179 | for task, ax in zip(tasks, axes.flatten()):
180 | relevant = [r for r in runs if r.task == task]
181 | plot(task, ax, relevant, methods, args)
182 | for name, ax in zip(args.add, axes.flatten()[len(tasks):]):
183 | ax.set_facecolor((0.9, 0.9, 0.9))
184 | if name == 'median':
185 | plot_combined(
186 | 'combined_median', ax, runs, methods, args,
187 | agg=lambda x: np.nanmedian(x, -1))
188 | elif name == 'mean':
189 | plot_combined(
190 | 'combined_mean', ax, runs, methods, args,
191 | agg=lambda x: np.nanmean(x, -1))
192 | elif name == 'gamer_median':
193 | plot_combined(
194 | 'combined_gamer_median', ax, runs, methods, args,
195 | lo='random', hi='human_gamer',
196 | agg=lambda x: np.nanmedian(x, -1))
197 | elif name == 'gamer_mean':
198 | plot_combined(
199 | 'combined_gamer_mean', ax, runs, methods, args,
200 | lo='random', hi='human_gamer',
201 | agg=lambda x: np.nanmean(x, -1))
202 | elif name == 'record_mean':
203 | plot_combined(
204 | 'combined_record_mean', ax, runs, methods, args,
205 | lo='random', hi='record',
206 | agg=lambda x: np.nanmean(x, -1))
207 | elif name == 'clip_record_mean':
208 | plot_combined(
209 | 'combined_clipped_record_mean', ax, runs, methods, args,
210 | lo='random', hi='record', clip=True,
211 | agg=lambda x: np.nanmean(x, -1))
212 | elif name == 'seeds':
213 | plot_combined(
214 | 'combined_seeds', ax, runs, methods, args,
215 | agg=lambda x: np.isfinite(x).sum(-1))
216 | elif name == 'human_above':
217 | plot_combined(
218 | 'combined_above_human_gamer', ax, runs, methods, args,
219 | agg=lambda y: (y >= 1.0).astype(float).sum(-1))
220 | elif name == 'human_below':
221 | plot_combined(
222 | 'combined_below_human_gamer', ax, runs, methods, args,
223 | agg=lambda y: (y <= 1.0).astype(float).sum(-1))
224 | else:
225 | raise NotImplementedError(name)
226 | if args.xlim:
227 | for ax in axes[:-1].flatten():
228 | ax.xaxis.get_offset_text().set_visible(False)
229 | if args.xlabel:
230 | for ax in axes[-1]:
231 | ax.set_xlabel(args.xlabel)
232 | if args.ylabel:
233 | for ax in axes[:, 0]:
234 | ax.set_ylabel(args.ylabel)
235 | for ax in axes.flatten()[len(tasks) + len(args.add):]:
236 | ax.axis('off')
237 | legend(fig, args.labels, ncol=args.legendcols, **LEGEND)
238 | return fig
239 |
240 |
241 | def plot(task, ax, runs, methods, args):
242 | assert runs
243 | try:
244 | title = task.split('_', 1)[1].replace('_', ' ').title()
245 | except IndexError:
246 | title = task.title()
247 | ax.set_title(title)
248 | xlim = [+np.inf, -np.inf]
249 | for index, method in enumerate(methods):
250 | relevant = [r for r in runs if r.method == method]
251 | if not relevant:
252 | continue
253 | if any(r.xs is None for r in relevant):
254 | baseline(index, method, ax, relevant, args)
255 | else:
256 | if args.agg == 'none':
257 | xs, ys = curve_lines(index, task, method, ax, relevant, args)
258 | else:
259 | xs, ys = curve_area(index, task, method, ax, relevant, args)
260 | if len(xs) == len(ys) == 0:
261 | print(f'Skipping empty: {task} {method}')
262 | continue
263 | xlim = [min(xlim[0], np.nanmin(xs)), max(xlim[1], np.nanmax(xs))]
264 | ax.ticklabel_format(axis='x', style='sci', scilimits=(0, 0))
265 | steps = [1, 2, 2.5, 5, 10]
266 | ax.xaxis.set_major_locator(ticker.MaxNLocator(args.xticks, steps=steps))
267 | ax.yaxis.set_major_locator(ticker.MaxNLocator(args.yticks, steps=steps))
268 | if np.isfinite(xlim).all():
269 | ax.set_xlim(args.xlim or xlim)
270 | if args.xlim:
271 | ticks = sorted({*ax.get_xticks(), *args.xlim})
272 | ticks = [x for x in ticks if args.xlim[0] <= x <= args.xlim[1]]
273 | ax.set_xticks(ticks)
274 | if args.ylim:
275 | ax.set_ylim(args.ylim)
276 | if args.ylimticks:
277 | ticks = sorted({*ax.get_yticks(), *args.ylim})
278 | ticks = [x for x in ticks if args.ylim[0] <= x <= args.ylim[1]]
279 | ax.set_yticks(ticks)
280 |
281 |
282 | def plot_combined(
283 | name, ax, runs, methods, args, agg, lo=None, hi=None, clip=False):
284 | tasks = sorted(set(run.task for run in runs if run.xs is not None))
285 | seeds = list(set(run.seed for run in runs))
286 | runs = [r for r in runs if r.task in tasks] # Discard unused baselines.
287 | # Bin all runs onto the same X steps.
288 | borders = sorted(
289 | [r.xs for r in runs if r.xs is not None],
290 | key=lambda x: np.nanmax(x))[-1]
291 | for index, run in enumerate(runs):
292 | if run.xs is None:
293 | continue
294 | xs, ys = bin_scores(run.xs, run.ys, borders, fill='last')
295 | runs[index] = run._replace(xs=xs, ys=ys)
296 | # Per-task normalization by low and high baseline.
297 | if lo or hi:
298 | mins = collections.defaultdict(list)
299 | maxs = collections.defaultdict(list)
300 | [mins[r.task].append(r.ys) for r in load_baselines([re.compile(lo)])]
301 | [maxs[r.task].append(r.ys) for r in load_baselines([re.compile(hi)])]
302 | mins = {task: min(ys) for task, ys in mins.items() if task in tasks}
303 | maxs = {task: max(ys) for task, ys in maxs.items() if task in tasks}
304 | missing_baselines = []
305 | for task in tasks:
306 | if task not in mins or task not in maxs:
307 | missing_baselines.append(task)
308 | if set(missing_baselines) == set(tasks):
309 | print(f'No baselines found to normalize any tasks in {name} plot.')
310 | else:
311 | for task in missing_baselines:
312 | print(f'No baselines found to normalize {task} in {name} plot.')
313 | for index, run in enumerate(runs):
314 | if run.task not in mins or run.task not in maxs:
315 | continue
316 | ys = (run.ys - mins[run.task]) / (maxs[run.task] - mins[run.task])
317 | if clip:
318 | ys = np.minimum(ys, 1.0)
319 | runs[index] = run._replace(ys=ys)
320 | # Aggregate across tasks but not methods or seeds.
321 | combined = []
322 | for method, seed in itertools.product(methods, seeds):
323 | relevant = [r for r in runs if r.method == method and r.seed == seed]
324 | if not relevant:
325 | continue
326 | if relevant[0].xs is None:
327 | xs, ys = None, np.array([r.ys for r in relevant])
328 | else:
329 | xs, ys = stack_scores(*zip(*[(r.xs, r.ys) for r in relevant]))
330 | with warnings.catch_warnings(): # Ignore empty slice warnings.
331 | warnings.simplefilter('ignore', category=RuntimeWarning)
332 | combined.append(Run('combined', method, seed, xs, agg(ys)))
333 | plot(name, ax, combined, methods, args)
334 |
335 |
336 | def curve_lines(index, task, method, ax, runs, args):
337 | zorder = 10000 - 10 * index - 1
338 | for run in runs:
339 | color = args.colors[method]
340 | ax.plot(run.xs, run.ys, label=method, color=color, zorder=zorder)
341 | xs, ys = stack_scores(*zip(*[(r.xs, r.ys) for r in runs]))
342 | return xs, ys
343 |
344 |
345 | def curve_area(index, task, method, ax, runs, args):
346 | xs, ys = stack_scores(*zip(*[(r.xs, r.ys) for r in runs]))
347 | with warnings.catch_warnings(): # NaN buckets remain NaN.
348 | warnings.simplefilter('ignore', category=RuntimeWarning)
349 | if args.agg == 'std1':
350 | mean, std = np.nanmean(ys, -1), np.nanstd(ys, -1)
351 | lo, mi, hi = mean - std, mean, mean + std
352 | elif args.agg == 'per0':
353 | lo, mi, hi = [np.nanpercentile(ys, k, -1) for k in (0, 50, 100)]
354 | elif args.agg == 'per5':
355 | lo, mi, hi = [np.nanpercentile(ys, k, -1) for k in (5, 50, 95)]
356 | elif args.agg == 'per25':
357 | lo, mi, hi = [np.nanpercentile(ys, k, -1) for k in (25, 50, 75)]
358 | else:
359 | raise NotImplementedError(args.agg)
360 | color = args.colors[method]
361 | kw = dict(color=color, zorder=1000 - 10 * index, alpha=0.1, linewidths=0)
362 | mask = ~np.isnan(mi)
363 | xs, lo, mi, hi = xs[mask], lo[mask], mi[mask], hi[mask]
364 | ax.fill_between(xs, lo, hi, **kw)
365 | ax.plot(xs, mi, label=method, color=color, zorder=10000 - 10 * index - 1)
366 | return xs, mi
367 |
368 |
369 | def baseline(index, method, ax, runs, args):
370 | assert all(run.xs is None for run in runs)
371 | ys = np.array([run.ys for run in runs])
372 | mean, std = ys.mean(), ys.std()
373 | color = args.colors[method]
374 | kw = dict(color=color, zorder=500 - 20 * index - 1, alpha=0.1, linewidths=0)
375 | ax.fill_between([-np.inf, np.inf], [mean - std] * 2, [mean + std] * 2, **kw)
376 | kw = dict(ls='--', color=color, zorder=5000 - 10 * index - 1)
377 | ax.axhline(mean, label=method, **kw)
378 |
379 |
380 | def legend(fig, mapping=None, **kwargs):
381 | entries = {}
382 | for ax in fig.axes:
383 | for handle, label in zip(*ax.get_legend_handles_labels()):
384 | if mapping and label in mapping:
385 | label = mapping[label]
386 | entries[label] = handle
387 | leg = fig.legend(entries.values(), entries.keys(), **kwargs)
388 | leg.get_frame().set_edgecolor('white')
389 | extent = leg.get_window_extent(fig.canvas.get_renderer())
390 | extent = extent.transformed(fig.transFigure.inverted())
391 | yloc, xloc = kwargs['loc'].split()
392 | y0 = dict(lower=extent.y1, center=0, upper=0)[yloc]
393 | y1 = dict(lower=1, center=1, upper=extent.y0)[yloc]
394 | x0 = dict(left=extent.x1, center=0, right=0)[xloc]
395 | x1 = dict(left=1, center=1, right=extent.x0)[xloc]
396 | fig.tight_layout(rect=[x0, y0, x1, y1], h_pad=0.5, w_pad=0.5)
397 |
398 |
399 | def save(fig, args):
400 | args.outdir.mkdir(parents=True, exist_ok=True)
401 | filename = args.outdir / 'curves.png'
402 | fig.savefig(filename, dpi=args.dpi)
403 | print('Saved to', filename)
404 | filename = args.outdir / 'curves.pdf'
405 | fig.savefig(filename)
406 | try:
407 | subprocess.call(['pdfcrop', str(filename), str(filename)])
408 | except FileNotFoundError:
409 | print('Install texlive-extra-utils to crop PDF outputs.')
410 |
411 |
412 | def bin_scores(xs, ys, borders, reducer=np.nanmean, fill='nan'):
413 | order = np.argsort(xs)
414 | xs, ys = xs[order], ys[order]
415 | binned = []
416 | with warnings.catch_warnings(): # Empty buckets become NaN.
417 | warnings.simplefilter('ignore', category=RuntimeWarning)
418 | for start, stop in zip(borders[:-1], borders[1:]):
419 | left = (xs <= start).sum()
420 | right = (xs <= stop).sum()
421 | if left < right:
422 | value = reducer(ys[left:right])
423 | elif binned:
424 | value = {'nan': np.nan, 'last': binned[-1]}[fill]
425 | else:
426 | value = np.nan
427 | binned.append(value)
428 | return borders[1:], np.array(binned)
429 |
430 |
431 | def stack_scores(multiple_xs, multiple_ys, fill='last'):
432 | longest_xs = sorted(multiple_xs, key=lambda x: len(x))[-1]
433 | multiple_padded_ys = []
434 | for xs, ys in zip(multiple_xs, multiple_ys):
435 | assert (longest_xs[:len(xs)] == xs).all(), (list(xs), list(longest_xs))
436 | value = {'nan': np.nan, 'last': ys[-1]}[fill]
437 | padding = [value] * (len(longest_xs) - len(xs))
438 | padded_ys = np.concatenate([ys, padding])
439 | multiple_padded_ys.append(padded_ys)
440 | stacked_ys = np.stack(multiple_padded_ys, -1)
441 | return longest_xs, stacked_ys
442 |
443 |
444 | def load_jsonl(filename):
445 | try:
446 | with filename.open() as f:
447 | lines = list(f.readlines())
448 | records = []
449 | for index, line in enumerate(lines):
450 | try:
451 | records.append(json.loads(line))
452 | except Exception:
453 | if index == len(lines) - 1:
454 | continue # Silently skip last line if it is incomplete.
455 | raise ValueError(
456 |             f'Skipping invalid JSON line ({index+1}/{len(lines)}) in '
457 |             f'{filename}: {line}')
458 | return pd.DataFrame(records)
459 | except ValueError as e:
460 | print('Invalid', filename, e)
461 | return None
462 |
463 |
464 | def save_runs(runs, filename):
465 | filename.parent.mkdir(parents=True, exist_ok=True)
466 | records = []
467 | for run in runs:
468 | if run.xs is None:
469 | continue
470 | records.append(dict(
471 | task=run.task, method=run.method, seed=run.seed,
472 | xs=run.xs.tolist(), ys=run.ys.tolist()))
473 | runs = json.dumps(records)
474 | filename.write_text(runs)
475 | print('Saved', filename)
476 |
477 |
478 | def main(args):
479 | find_keys(args)
480 | runs = load_runs(args)
481 | save_runs(runs, args.outdir / 'runs.json')
482 | baselines = load_baselines(args.baselines, args.prefix)
483 | stats(runs, baselines)
484 | methods = order_methods(runs, baselines, args)
485 | if not runs:
486 |     print('Nothing to plot.')
487 | return
488 | # Adjust options based on loaded runs.
489 | tasks = set(r.task for r in runs)
490 | if 'auto' in args.add:
491 | index = args.add.index('auto')
492 | del args.add[index]
493 | atari = any(run.task.startswith('atari_') for run in runs)
494 | if len(tasks) < 2:
495 | pass
496 | elif atari:
497 | args.add[index:index] = [
498 | 'gamer_median', 'gamer_mean', 'record_mean', 'clip_record_mean',
499 | ]
500 | else:
501 | args.add[index:index] = ['mean', 'median']
502 | args.cols = min(args.cols, len(tasks) + len(args.add))
503 | args.legendcols = min(args.legendcols, args.cols)
504 | print('Plotting...')
505 | fig = figure(runs + baselines, methods, args)
506 | save(fig, args)
507 |
508 |
509 | def parse_args():
510 | boolean = lambda x: bool(['False', 'True'].index(x))
511 | parser = argparse.ArgumentParser()
512 | parser.add_argument('--indir', nargs='+', type=pathlib.Path, required=True)
513 | parser.add_argument('--indir-prefix', type=pathlib.Path)
514 | parser.add_argument('--outdir', type=pathlib.Path, required=True)
515 | parser.add_argument('--subdir', type=boolean, default=True)
516 | parser.add_argument('--xaxis', type=str, default='step')
517 | parser.add_argument('--yaxis', type=str, default='eval_return')
518 | parser.add_argument('--tasks', nargs='+', default=[r'.*'])
519 | parser.add_argument('--methods', nargs='+', default=[r'.*'])
520 | parser.add_argument('--baselines', nargs='+', default=DEFAULT_BASELINES)
521 | parser.add_argument('--prefix', type=boolean, default=False)
522 | parser.add_argument('--bins', type=float, default=-1)
523 | parser.add_argument('--agg', type=str, default='std1')
524 | parser.add_argument('--size', nargs=2, type=float, default=[2.5, 2.3])
525 | parser.add_argument('--dpi', type=int, default=80)
526 | parser.add_argument('--cols', type=int, default=6)
527 | parser.add_argument('--xlim', nargs=2, type=float, default=None)
528 | parser.add_argument('--ylim', nargs=2, type=float, default=None)
529 | parser.add_argument('--ylimticks', type=boolean, default=True)
530 | parser.add_argument('--xlabel', type=str, default=None)
531 | parser.add_argument('--ylabel', type=str, default=None)
532 | parser.add_argument('--xticks', type=int, default=6)
533 | parser.add_argument('--yticks', type=int, default=5)
534 | parser.add_argument('--xmult', type=float, default=1)
535 | parser.add_argument('--labels', nargs='+', default=None)
536 | parser.add_argument('--palette', nargs='+', default=['contrast'])
537 | parser.add_argument('--legendcols', type=int, default=4)
538 | parser.add_argument('--colors', nargs='+', default={})
539 | parser.add_argument('--maxval', type=float, default=0)
540 | parser.add_argument('--add', nargs='+', type=str, default=['auto', 'seeds'])
541 | args = parser.parse_args()
542 | if args.subdir:
543 | args.outdir /= args.indir[0].stem
544 | if args.indir_prefix:
545 | args.indir = [args.indir_prefix / indir for indir in args.indir]
546 | args.indir = [d.expanduser() for d in args.indir]
547 | args.outdir = args.outdir.expanduser()
548 | if args.labels:
549 | assert len(args.labels) % 2 == 0
550 | args.labels = {k: v for k, v in zip(args.labels[::2], args.labels[1::2])}
551 | if args.colors:
552 | assert len(args.colors) % 2 == 0
553 | args.colors = {k: v for k, v in zip(args.colors[::2], args.colors[1::2])}
554 | args.tasks = [re.compile(p) for p in args.tasks]
555 | args.methods = [re.compile(p) for p in args.methods]
556 | args.baselines = [re.compile(p) for p in args.baselines]
557 | if 'return' not in args.yaxis:
558 | args.baselines = []
559 | if args.prefix is None:
560 | args.prefix = len(args.indir) > 1
561 | if len(args.palette) == 1 and args.palette[0] in PALETTES:
562 | args.palette = 10 * PALETTES[args.palette[0]]
563 | if len(args.add) == 1 and args.add[0] == 'none':
564 | args.add = []
565 | return args
566 |
567 |
568 | if __name__ == '__main__':
569 | main(parse_args())
570 |
--------------------------------------------------------------------------------
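A minimal sketch (not part of the repository) of what stack_scores above does with runs of unequal length: shorter runs are padded onto the longest x-axis, by default repeating their final value. The arrays are made-up examples.

import numpy as np

# Two hypothetical runs logged at different numbers of evaluation points.
multiple_xs = [np.array([1, 2, 3, 4]), np.array([1, 2])]
multiple_ys = [np.array([0.1, 0.2, 0.3, 0.4]), np.array([0.5, 0.6])]

# Same logic as stack_scores(..., fill='last'): pad short runs with their
# final value so every run is defined on the longest x-axis.
longest_xs = max(multiple_xs, key=len)
padded = [np.concatenate([ys, [ys[-1]] * (len(longest_xs) - len(ys))])
          for ys in multiple_ys]
stacked_ys = np.stack(padded, -1)
print(longest_xs)  # [1 2 3 4]
print(stacked_ys)  # one column per run; the second run repeats its last value 0.6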
/dreamerv2/common/replay.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import datetime
3 | import io
4 | import pathlib
5 | import uuid
6 |
7 | import numpy as np
8 | import tensorflow as tf
9 |
10 |
11 | class Replay:
12 |
13 | def __init__(
14 | self, directory, capacity=0, ongoing=False, minlen=1, maxlen=0,
15 | prioritize_ends=False):
16 | self._directory = pathlib.Path(directory).expanduser()
17 | self._directory.mkdir(parents=True, exist_ok=True)
18 | self._capacity = capacity
19 | self._ongoing = ongoing
20 | self._minlen = minlen
21 | self._maxlen = maxlen
22 | self._prioritize_ends = prioritize_ends
23 | self._random = np.random.RandomState()
24 | # filename -> key -> value_sequence
25 | self._complete_eps = load_episodes(self._directory, capacity, minlen)
26 | # worker -> key -> value_sequence
27 | self._ongoing_eps = collections.defaultdict(
28 | lambda: collections.defaultdict(list))
29 | self._total_episodes, self._total_steps = count_episodes(directory)
30 | self._loaded_episodes = len(self._complete_eps)
31 | self._loaded_steps = sum(eplen(x) for x in self._complete_eps.values())
32 |
33 | @property
34 | def stats(self):
35 | return {
36 | 'total_steps': self._total_steps,
37 | 'total_episodes': self._total_episodes,
38 | 'loaded_steps': self._loaded_steps,
39 | 'loaded_episodes': self._loaded_episodes,
40 | }
41 |
42 | def add_step(self, transition, worker=0):
43 | episode = self._ongoing_eps[worker]
44 | for key, value in transition.items():
45 | episode[key].append(value)
46 | if transition['is_last']:
47 | self.add_episode(episode)
48 | episode.clear()
49 |
50 | def add_episode(self, episode):
51 | length = eplen(episode)
52 | if length < self._minlen:
53 | print(f'Skipping short episode of length {length}.')
54 | return
55 | self._total_steps += length
56 | self._loaded_steps += length
57 | self._total_episodes += 1
58 | self._loaded_episodes += 1
59 | episode = {key: convert(value) for key, value in episode.items()}
60 | filename = save_episode(self._directory, episode)
61 | self._complete_eps[str(filename)] = episode
62 | self._enforce_limit()
63 |
64 | def dataset(self, batch, length):
65 | example = next(iter(self._generate_chunks(length)))
66 | dataset = tf.data.Dataset.from_generator(
67 | lambda: self._generate_chunks(length),
68 | {k: v.dtype for k, v in example.items()},
69 | {k: v.shape for k, v in example.items()})
70 | dataset = dataset.batch(batch, drop_remainder=True)
71 | dataset = dataset.prefetch(5)
72 | return dataset
73 |
74 | def _generate_chunks(self, length):
75 | sequence = self._sample_sequence()
76 | while True:
77 | chunk = collections.defaultdict(list)
78 | added = 0
79 | while added < length:
80 | needed = length - added
81 | adding = {k: v[:needed] for k, v in sequence.items()}
82 | sequence = {k: v[needed:] for k, v in sequence.items()}
83 | for key, value in adding.items():
84 | chunk[key].append(value)
85 | added += len(adding['action'])
86 | if len(sequence['action']) < 1:
87 | sequence = self._sample_sequence()
88 | chunk = {k: np.concatenate(v) for k, v in chunk.items()}
89 | yield chunk
90 |
91 | def _sample_sequence(self):
92 | episodes = list(self._complete_eps.values())
93 | if self._ongoing:
94 | episodes += [
95 | x for x in self._ongoing_eps.values()
96 | if eplen(x) >= self._minlen]
97 | episode = self._random.choice(episodes)
98 | total = len(episode['action'])
99 | length = total
100 | if self._maxlen:
101 | length = min(length, self._maxlen)
102 | # Randomize length to avoid all chunks ending at the same time in case the
103 | # episodes are all of the same length.
104 | length -= self._random.randint(self._minlen)
105 | length = max(self._minlen, length)
106 | upper = total - length + 1
107 | if self._prioritize_ends:
108 | upper += self._minlen
109 | index = min(self._random.randint(upper), total - length)
110 | sequence = {
111 | k: convert(v[index: index + length])
112 | for k, v in episode.items() if not k.startswith('log_')}
113 | sequence['is_first'] = np.zeros(len(sequence['action']), bool)
114 | sequence['is_first'][0] = True
115 | if self._maxlen:
116 | assert self._minlen <= len(sequence['action']) <= self._maxlen
117 | return sequence
118 |
119 | def _enforce_limit(self):
120 | if not self._capacity:
121 | return
122 | while self._loaded_episodes > 1 and self._loaded_steps > self._capacity:
123 | # Relying on Python preserving the insertion order of dicts.
124 | oldest, episode = next(iter(self._complete_eps.items()))
125 | self._loaded_steps -= eplen(episode)
126 | self._loaded_episodes -= 1
127 | del self._complete_eps[oldest]
128 |
129 |
130 | def count_episodes(directory):
131 | filenames = list(directory.glob('*.npz'))
132 | num_episodes = len(filenames)
133 | num_steps = sum(int(str(n).split('-')[-1][:-4]) - 1 for n in filenames)
134 | return num_episodes, num_steps
135 |
136 |
137 | def save_episode(directory, episode):
138 | timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
139 | identifier = str(uuid.uuid4().hex)
140 | length = eplen(episode)
141 | filename = directory / f'{timestamp}-{identifier}-{length}.npz'
142 | with io.BytesIO() as f1:
143 | np.savez_compressed(f1, **episode)
144 | f1.seek(0)
145 | with filename.open('wb') as f2:
146 | f2.write(f1.read())
147 | return filename
148 |
149 |
150 | def load_episodes(directory, capacity=None, minlen=1):
151 | # The returned dictionary from filenames to episodes is guaranteed to be in
152 | # temporally sorted order.
153 | filenames = sorted(directory.glob('*.npz'))
154 | if capacity:
155 | num_steps = 0
156 | num_episodes = 0
157 | for filename in reversed(filenames):
158 | length = int(str(filename).split('-')[-1][:-4])
159 | num_steps += length
160 | num_episodes += 1
161 | if num_steps >= capacity:
162 | break
163 | filenames = filenames[-num_episodes:]
164 | episodes = {}
165 | for filename in filenames:
166 | try:
167 | with filename.open('rb') as f:
168 | episode = np.load(f)
169 | episode = {k: episode[k] for k in episode.keys()}
170 | except Exception as e:
171 | print(f'Could not load episode {str(filename)}: {e}')
172 | continue
173 | episodes[str(filename)] = episode
174 | return episodes
175 |
176 |
177 | def convert(value):
178 | value = np.array(value)
179 | if np.issubdtype(value.dtype, np.floating):
180 | return value.astype(np.float32)
181 | elif np.issubdtype(value.dtype, np.signedinteger):
182 | return value.astype(np.int32)
183 | elif np.issubdtype(value.dtype, np.uint8):
184 | return value.astype(np.uint8)
185 | return value
186 |
187 |
188 | def eplen(episode):
189 | return len(episode['action']) - 1
190 |
--------------------------------------------------------------------------------
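A minimal usage sketch for the Replay buffer above, assuming the dreamerv2/ directory is on sys.path (train.py arranges this) and TensorFlow is installed; the directory, episode length, and transition keys are illustrative.

import numpy as np
import common  # assuming dreamerv2/ is on sys.path, as in train.py

replay = common.Replay('/tmp/example_episodes', capacity=1e5, minlen=10, maxlen=30)

# Feed one made-up 60-step episode; add_step finalizes it on 'is_last'.
for step in range(60):
  replay.add_step({
      'action': np.zeros(6, np.float32),
      'reward': np.float32(0.0),
      'is_first': step == 0,
      'is_last': step == 59,
  })

dataset = iter(replay.dataset(batch=4, length=20))
batch = next(dataset)  # dict of [4, 20, ...] tensors sampled from the stored episode
print(replay.stats)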
/dreamerv2/common/tfutils.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | import pickle
3 | import re
4 |
5 | import numpy as np
6 | import tensorflow as tf
7 | from tensorflow.keras import mixed_precision as prec
8 |
9 | try:
10 | from tensorflow.python.distribute import values
11 | except Exception:
12 | from google3.third_party.tensorflow.python.distribute import values
13 |
14 | tf.tensor = tf.convert_to_tensor
15 | for base in (tf.Tensor, tf.Variable, values.PerReplica):
16 | base.mean = tf.math.reduce_mean
17 | base.std = tf.math.reduce_std
18 | base.var = tf.math.reduce_variance
19 | base.sum = tf.math.reduce_sum
20 | base.any = tf.math.reduce_any
21 | base.all = tf.math.reduce_all
22 | base.min = tf.math.reduce_min
23 | base.max = tf.math.reduce_max
24 | base.abs = tf.math.abs
25 | base.logsumexp = tf.math.reduce_logsumexp
26 | base.transpose = tf.transpose
27 | base.reshape = tf.reshape
28 | base.astype = tf.cast
29 |
30 |
31 | # values.PerReplica.dtype = property(lambda self: self.values[0].dtype)
32 |
33 | # tf.TensorHandle.__repr__ = lambda x: ''
34 | # tf.TensorHandle.__str__ = lambda x: ''
35 | # np.set_printoptions(threshold=5, edgeitems=0)
36 |
37 |
38 | class Module(tf.Module):
39 |
40 | def save(self, filename):
41 | values = tf.nest.map_structure(lambda x: x.numpy(), self.variables)
42 | amount = len(tf.nest.flatten(values))
43 | count = int(sum(np.prod(x.shape) for x in tf.nest.flatten(values)))
44 | print(f'Save checkpoint with {amount} tensors and {count} parameters.')
45 | with pathlib.Path(filename).open('wb') as f:
46 | pickle.dump(values, f)
47 |
48 | def load(self, filename):
49 | with pathlib.Path(filename).open('rb') as f:
50 | values = pickle.load(f)
51 | amount = len(tf.nest.flatten(values))
52 | count = int(sum(np.prod(x.shape) for x in tf.nest.flatten(values)))
53 | print(f'Load checkpoint with {amount} tensors and {count} parameters.')
54 | tf.nest.map_structure(lambda x, y: x.assign(y), self.variables, values)
55 |
56 | def get(self, name, ctor, *args, **kwargs):
57 | # Create or get layer by name to avoid mentioning it in the constructor.
58 | if not hasattr(self, '_modules'):
59 | self._modules = {}
60 | if name not in self._modules:
61 | self._modules[name] = ctor(*args, **kwargs)
62 | return self._modules[name]
63 |
64 |
65 | class Optimizer(tf.Module):
66 |
67 | def __init__(
68 | self, name, lr, eps=1e-4, clip=None, wd=None,
69 | opt='adam', wd_pattern=r'.*'):
70 | assert wd is None or 0 <= wd < 1
71 | assert not clip or 1 <= clip
72 | self._name = name
73 | self._clip = clip
74 | self._wd = wd
75 | self._wd_pattern = wd_pattern
76 | self._opt = {
77 | 'adam': lambda: tf.optimizers.Adam(lr, epsilon=eps),
78 | 'nadam': lambda: tf.optimizers.Nadam(lr, epsilon=eps),
79 | 'adamax': lambda: tf.optimizers.Adamax(lr, epsilon=eps),
80 | 'sgd': lambda: tf.optimizers.SGD(lr),
81 | 'momentum': lambda: tf.optimizers.SGD(lr, 0.9),
82 | }[opt]()
83 | self._mixed = (prec.global_policy().compute_dtype == tf.float16)
84 | if self._mixed:
85 | self._opt = prec.LossScaleOptimizer(self._opt, dynamic=True)
86 | self._once = True
87 |
88 | @property
89 | def variables(self):
90 | return self._opt.variables()
91 |
92 | def __call__(self, tape, loss, modules):
93 | assert loss.dtype is tf.float32, (self._name, loss.dtype)
94 | assert len(loss.shape) == 0, (self._name, loss.shape)
95 | metrics = {}
96 |
97 | # Find variables.
98 | modules = modules if hasattr(modules, '__len__') else (modules,)
99 | varibs = tf.nest.flatten([module.variables for module in modules])
100 | count = sum(np.prod(x.shape) for x in varibs)
101 | if self._once:
102 | print(f'Found {count} {self._name} parameters.')
103 | self._once = False
104 |
105 | # Check loss.
106 | tf.debugging.check_numerics(loss, self._name + '_loss')
107 | metrics[f'{self._name}_loss'] = loss
108 |
109 | # Compute scaled gradient.
110 | if self._mixed:
111 | with tape:
112 | loss = self._opt.get_scaled_loss(loss)
113 | grads = tape.gradient(loss, varibs)
114 | if self._mixed:
115 | grads = self._opt.get_unscaled_gradients(grads)
116 | if self._mixed:
117 | metrics[f'{self._name}_loss_scale'] = self._opt.loss_scale
118 |
119 | # Distributed sync.
120 | context = tf.distribute.get_replica_context()
121 | if context:
122 | grads = context.all_reduce('mean', grads)
123 |
124 | # Gradient clipping.
125 | norm = tf.linalg.global_norm(grads)
126 | if not self._mixed:
127 | tf.debugging.check_numerics(norm, self._name + '_norm')
128 | if self._clip:
129 | grads, _ = tf.clip_by_global_norm(grads, self._clip, norm)
130 | metrics[f'{self._name}_grad_norm'] = norm
131 |
132 | # Weight decay.
133 | if self._wd:
134 | self._apply_weight_decay(varibs)
135 |
136 | # Apply gradients.
137 | self._opt.apply_gradients(
138 | zip(grads, varibs),
139 | experimental_aggregate_gradients=False)
140 |
141 | return metrics
142 |
143 | def _apply_weight_decay(self, varibs):
144 | nontrivial = (self._wd_pattern != r'.*')
145 | if nontrivial:
146 | print('Applied weight decay to variables:')
147 | for var in varibs:
148 | if re.search(self._wd_pattern, self._name + '/' + var.name):
149 | if nontrivial:
150 | print('- ' + self._name + '/' + var.name)
151 | var.assign((1 - self._wd) * var)
152 |
--------------------------------------------------------------------------------
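A small sketch of the Optimizer call convention above, using a throwaway Keras layer as the module; it assumes dreamerv2/ is on sys.path (as train.py arranges) and the default float32 policy. The gradient tape is passed in so the optimizer can rescale the loss when mixed precision is active.

import tensorflow as tf
import common  # assuming dreamerv2/ is on sys.path, as in train.py

model = tf.keras.layers.Dense(8)
model(tf.zeros([1, 4]))  # build the variables before the first update
opt = common.Optimizer('example', lr=1e-3, eps=1e-5, clip=100, wd=1e-6)

with tf.GradientTape() as tape:
  pred = model(tf.ones([16, 4]))
  loss = tf.reduce_mean(tf.square(pred))  # scalar float32 loss, as asserted
metrics = opt(tape, loss, model)
print(metrics)  # includes 'example_loss' and 'example_grad_norm'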
/dreamerv2/common/when.py:
--------------------------------------------------------------------------------
1 | class Every:
2 |
3 | def __init__(self, every):
4 | self._every = every
5 | self._last = None
6 |
7 | def __call__(self, step):
8 | step = int(step)
9 | if not self._every:
10 | return False
11 | if self._last is None:
12 | self._last = step
13 | return True
14 | if step >= self._last + self._every:
15 | self._last += self._every
16 | return True
17 | return False
18 |
19 |
20 | class Once:
21 |
22 | def __init__(self):
23 | self._once = True
24 |
25 | def __call__(self):
26 | if self._once:
27 | self._once = False
28 | return True
29 | return False
30 |
31 |
32 | class Until:
33 |
34 | def __init__(self, until):
35 | self._until = until
36 |
37 | def __call__(self, step):
38 | step = int(step)
39 | if not self._until:
40 | return True
41 | return step < self._until
42 |
--------------------------------------------------------------------------------
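For intuition, a quick illustration of how these schedules behave (the step values are made up); train.py wires Every to train_every/log_every and Until to expl_until.

import common  # assuming dreamerv2/ is on sys.path, as in train.py

should_log = common.Every(1000)
should_expl = common.Until(5000)

for step in [0, 500, 1000, 2500, 6000]:
  print(step, should_log(step), should_expl(step))
# 0    True  True   (the first call always triggers)
# 500  False True
# 1000 True  True
# 2500 True  True   (Every advances its internal marker by 1000 per trigger)
# 6000 True  False  (past the Until threshold)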
/dreamerv2/configs.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 |
3 | # Train Script
4 | logdir: /dev/null
5 | seed: 0
6 | task: dmc_walker_walk
7 | envs: 1
8 | envs_parallel: none
9 | render_size: [64, 64]
10 | dmc_camera: -1
11 | atari_grayscale: True
12 | time_limit: 0
13 | action_repeat: 1
14 | steps: 1e8
15 | log_every: 1e4
16 | eval_every: 1e5
17 | eval_eps: 1
18 | prefill: 10000
19 | pretrain: 1
20 | train_every: 5
21 | train_steps: 1
22 | expl_until: 0
23 | replay: {capacity: 2e6, ongoing: False, minlen: 50, maxlen: 50, prioritize_ends: True}
24 | dataset: {batch: 16, length: 50}
25 | log_keys_video: ['image']
26 | log_keys_sum: '^$'
27 | log_keys_mean: '^$'
28 | log_keys_max: '^$'
29 | precision: 16
30 | jit: True
31 |
32 | # Agent
33 | clip_rewards: tanh
34 | expl_behavior: greedy
35 | expl_noise: 0.0
36 | eval_noise: 0.0
37 | eval_state_mean: False
38 |
39 | # World Model
40 | grad_heads: [decoder, reward, discount]
41 | pred_discount: True
42 | rssm: {ensemble: 1, hidden: 1024, deter: 1024, stoch: 32, discrete: 32, act: elu, norm: none, std_act: sigmoid2, min_std: 0.1}
43 | encoder: {mlp_keys: '.*', cnn_keys: '.*', act: elu, norm: none, cnn_depth: 48, cnn_kernels: [4, 4, 4, 4], mlp_layers: [400, 400, 400, 400]}
44 | decoder: {mlp_keys: '.*', cnn_keys: '.*', act: elu, norm: none, cnn_depth: 48, cnn_kernels: [5, 5, 6, 6], mlp_layers: [400, 400, 400, 400]}
45 | reward_head: {layers: 4, units: 400, act: elu, norm: none, dist: mse}
46 | discount_head: {layers: 4, units: 400, act: elu, norm: none, dist: binary}
47 | loss_scales: {kl: 1.0, reward: 1.0, discount: 1.0, proprio: 1.0}
48 | kl: {free: 0.0, forward: False, balance: 0.8, free_avg: True}
49 | model_opt: {opt: adam, lr: 1e-4, eps: 1e-5, clip: 100, wd: 1e-6}
50 |
51 | # Actor Critic
52 | actor: {layers: 4, units: 400, act: elu, norm: none, dist: auto, min_std: 0.1}
53 | critic: {layers: 4, units: 400, act: elu, norm: none, dist: mse}
54 | actor_opt: {opt: adam, lr: 8e-5, eps: 1e-5, clip: 100, wd: 1e-6}
55 | critic_opt: {opt: adam, lr: 2e-4, eps: 1e-5, clip: 100, wd: 1e-6}
56 | discount: 0.99
57 | discount_lambda: 0.95
58 | imag_horizon: 15
59 | actor_grad: auto
60 | actor_grad_mix: 0.1
61 | actor_ent: 2e-3
62 | slow_target: True
63 | slow_target_update: 100
64 | slow_target_fraction: 1
65 | slow_baseline: True
66 | reward_norm: {momentum: 1.0, scale: 1.0, eps: 1e-8}
67 |
68 | # Exploration
69 | expl_intr_scale: 1.0
70 | expl_extr_scale: 0.0
71 | expl_opt: {opt: adam, lr: 3e-4, eps: 1e-5, clip: 100, wd: 1e-6}
72 | expl_head: {layers: 4, units: 400, act: elu, norm: none, dist: mse}
73 | expl_reward_norm: {momentum: 1.0, scale: 1.0, eps: 1e-8}
74 | disag_target: stoch
75 | disag_log: False
76 | disag_models: 10
77 | disag_offset: 1
78 | disag_action_cond: True
79 | expl_model_loss: kl
80 |
81 | atari:
82 |
83 | task: atari_pong
84 | encoder: {mlp_keys: '$^', cnn_keys: 'image'}
85 | decoder: {mlp_keys: '$^', cnn_keys: 'image'}
86 | time_limit: 27000
87 | action_repeat: 4
88 | steps: 5e7
89 | eval_every: 2.5e5
90 | log_every: 1e4
91 | prefill: 50000
92 | train_every: 16
93 | clip_rewards: tanh
94 | rssm: {hidden: 600, deter: 600}
95 | model_opt.lr: 2e-4
96 | actor_opt.lr: 4e-5
97 | critic_opt.lr: 1e-4
98 | actor_ent: 1e-3
99 | discount: 0.999
100 | loss_scales.kl: 0.1
101 | loss_scales.discount: 5.0
102 |
103 | crafter:
104 |
105 | task: crafter_reward
106 | encoder: {mlp_keys: '$^', cnn_keys: 'image'}
107 | decoder: {mlp_keys: '$^', cnn_keys: 'image'}
108 | log_keys_max: '^log_achievement_.*'
109 | log_keys_sum: '^log_reward$'
110 | rssm: {hidden: 1024, deter: 1024}
111 | discount: 0.999
112 | model_opt.lr: 1e-4
113 | actor_opt.lr: 1e-4
114 | critic_opt.lr: 1e-4
115 | actor_ent: 3e-3
116 | .*\.norm: layer
117 |
118 | dmc_vision:
119 |
120 | task: dmc_walker_walk
121 | encoder: {mlp_keys: '$^', cnn_keys: 'image'}
122 | decoder: {mlp_keys: '$^', cnn_keys: 'image'}
123 | action_repeat: 2
124 | eval_every: 1e4
125 | prefill: 1000
126 | pretrain: 100
127 | clip_rewards: identity
128 | pred_discount: False
129 | replay.prioritize_ends: False
130 | grad_heads: [decoder, reward]
131 | rssm: {hidden: 200, deter: 200}
132 | model_opt.lr: 3e-4
133 | actor_opt.lr: 8e-5
134 | critic_opt.lr: 8e-5
135 | actor_ent: 1e-4
136 | kl.free: 1.0
137 |
138 | dmc_proprio:
139 |
140 | task: dmc_walker_walk
141 | encoder: {mlp_keys: '.*', cnn_keys: '$^'}
142 | decoder: {mlp_keys: '.*', cnn_keys: '$^'}
143 | action_repeat: 2
144 | eval_every: 1e4
145 | prefill: 1000
146 | pretrain: 100
147 | clip_rewards: identity
148 | pred_discount: False
149 | replay.prioritize_ends: False
150 | grad_heads: [decoder, reward]
151 | rssm: {hidden: 200, deter: 200}
152 | model_opt.lr: 3e-4
153 | actor_opt.lr: 8e-5
154 | critic_opt.lr: 8e-5
155 | actor_ent: 1e-4
156 | kl.free: 1.0
157 |
158 | debug:
159 |
160 | jit: False
161 | time_limit: 100
162 | eval_every: 300
163 | log_every: 300
164 | prefill: 100
165 | pretrain: 1
166 | train_steps: 1
167 | replay: {minlen: 10, maxlen: 30}
168 | dataset: {batch: 10, length: 10}
169 |
--------------------------------------------------------------------------------
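The named sections above are overrides applied on top of defaults: a dotted key such as model_opt.lr patches a single nested field, and a partial dict such as rssm: {hidden: 600, deter: 600} only changes the listed fields. Below is a minimal sketch of that merge using plain dicts; it stands in for common.Config, which train.py uses, and assumes it is run from the repository root.

import pathlib
import ruamel.yaml as yaml

configs = yaml.safe_load(pathlib.Path('dreamerv2/configs.yaml').read_text())

def apply(config, overrides, prefix=''):
  # Flatten overrides into dotted paths and patch one leaf field at a time,
  # so partial dicts keep the unlisted fields from 'defaults'. (The real
  # common.Config also resolves regex keys like the crafter section's
  # '.*\.norm'; this sketch does not.)
  for key, value in overrides.items():
    if isinstance(value, dict):
      apply(config, value, prefix + key + '.')
    else:
      names = (prefix + key).split('.')
      node = config
      for name in names[:-1]:
        node = node[name]
      node[names[-1]] = value

config = configs['defaults']
apply(config, configs['atari'])
print(config['action_repeat'])     # 4 (overridden by the atari section)
print(config['model_opt']['lr'])   # 0.0002, patched via the dotted 'model_opt.lr' key
print(config['rssm']['stoch'])     # 32, kept from defaults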
/dreamerv2/expl.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow_probability import distributions as tfd
3 |
4 | import agent
5 | import common
6 |
7 |
8 | class Random(common.Module):
9 |
10 | def __init__(self, config, act_space, wm, tfstep, reward):
11 | self.config = config
12 | self.act_space = act_space
13 |
14 | def actor(self, feat):
15 | shape = feat.shape[:-1] + self.act_space.shape
16 | if self.config.actor.dist == 'onehot':
17 | return common.OneHotDist(tf.zeros(shape))
18 | else:
19 | dist = tfd.Uniform(-tf.ones(shape), tf.ones(shape))
20 | return tfd.Independent(dist, 1)
21 |
22 | def train(self, start, context, data):
23 | return None, {}
24 |
25 |
26 | class Plan2Explore(common.Module):
27 |
28 | def __init__(self, config, act_space, wm, tfstep, reward):
29 | self.config = config
30 | self.reward = reward
31 | self.wm = wm
32 | self.ac = agent.ActorCritic(config, act_space, tfstep)
33 | self.actor = self.ac.actor
34 | stoch_size = config.rssm.stoch
35 | if config.rssm.discrete:
36 | stoch_size *= config.rssm.discrete
37 | size = {
38 | 'embed': 32 * config.encoder.cnn_depth,
39 | 'stoch': stoch_size,
40 | 'deter': config.rssm.deter,
41 | 'feat': config.rssm.stoch + config.rssm.deter,
42 | }[self.config.disag_target]
43 | self._networks = [
44 | common.MLP(size, **config.expl_head)
45 | for _ in range(config.disag_models)]
46 | self.opt = common.Optimizer('expl', **config.expl_opt)
47 | self.extr_rewnorm = common.StreamNorm(**self.config.expl_reward_norm)
48 | self.intr_rewnorm = common.StreamNorm(**self.config.expl_reward_norm)
49 |
50 | def train(self, start, context, data):
51 | metrics = {}
52 | stoch = start['stoch']
53 | if self.config.rssm.discrete:
54 | stoch = tf.reshape(
55 | stoch, stoch.shape[:-2] + (stoch.shape[-2] * stoch.shape[-1]))
56 | target = {
57 | 'embed': context['embed'],
58 | 'stoch': stoch,
59 | 'deter': start['deter'],
60 | 'feat': context['feat'],
61 | }[self.config.disag_target]
62 | inputs = context['feat']
63 | if self.config.disag_action_cond:
64 | action = tf.cast(data['action'], inputs.dtype)
65 | inputs = tf.concat([inputs, action], -1)
66 | metrics.update(self._train_ensemble(inputs, target))
67 | metrics.update(self.ac.train(
68 | self.wm, start, data['is_terminal'], self._intr_reward))
69 | return None, metrics
70 |
71 | def _intr_reward(self, seq):
72 | inputs = seq['feat']
73 | if self.config.disag_action_cond:
74 | action = tf.cast(seq['action'], inputs.dtype)
75 | inputs = tf.concat([inputs, action], -1)
76 | preds = [head(inputs).mode() for head in self._networks]
77 | disag = tf.tensor(preds).std(0).mean(-1)
78 | if self.config.disag_log:
79 | disag = tf.math.log(disag)
80 | reward = self.config.expl_intr_scale * self.intr_rewnorm(disag)[0]
81 | if self.config.expl_extr_scale:
82 | reward += self.config.expl_extr_scale * self.extr_rewnorm(
83 | self.reward(seq))[0]
84 | return reward
85 |
86 | def _train_ensemble(self, inputs, targets):
87 | if self.config.disag_offset:
88 | targets = targets[:, self.config.disag_offset:]
89 | inputs = inputs[:, :-self.config.disag_offset]
90 | targets = tf.stop_gradient(targets)
91 | inputs = tf.stop_gradient(inputs)
92 | with tf.GradientTape() as tape:
93 | preds = [head(inputs) for head in self._networks]
94 | loss = -sum([pred.log_prob(targets).mean() for pred in preds])
95 | metrics = self.opt(tape, loss, self._networks)
96 | return metrics
97 |
98 |
99 | class ModelLoss(common.Module):
100 |
101 | def __init__(self, config, act_space, wm, tfstep, reward):
102 | self.config = config
103 | self.reward = reward
104 | self.wm = wm
105 | self.ac = agent.ActorCritic(config, act_space, tfstep)
106 | self.actor = self.ac.actor
107 | self.head = common.MLP([], **self.config.expl_head)
108 | self.opt = common.Optimizer('expl', **self.config.expl_opt)
109 |
110 | def train(self, start, context, data):
111 | metrics = {}
112 | target = tf.cast(context[self.config.expl_model_loss], tf.float32)
113 | with tf.GradientTape() as tape:
114 | loss = -self.head(context['feat']).log_prob(target).mean()
115 | metrics.update(self.opt(tape, loss, self.head))
116 | metrics.update(self.ac.train(
117 | self.wm, start, data['is_terminal'], self._intr_reward))
118 | return None, metrics
119 |
120 | def _intr_reward(self, seq):
121 | reward = self.config.expl_intr_scale * self.head(seq['feat']).mode()
122 | if self.config.expl_extr_scale:
123 | reward += self.config.expl_extr_scale * self.reward(seq)
124 | return reward
125 |
--------------------------------------------------------------------------------
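For intuition, a small numpy sketch of the ensemble-disagreement signal that Plan2Explore._intr_reward computes: the standard deviation of the ensemble predictions across models, averaged over the predicted feature dimension. Shapes and values are made up.

import numpy as np

# Hypothetical predictions from disag_models=10 ensemble heads for a batch of
# 16 imagined states, each predicting a 32-dimensional target (e.g. 'stoch').
preds = np.random.randn(10, 16, 32).astype(np.float32)

# Mirrors `tf.tensor(preds).std(0).mean(-1)` in Plan2Explore._intr_reward:
# per-dimension std across the ensemble, averaged over the target dimensions.
disag = preds.std(0).mean(-1)  # shape [16], one intrinsic reward per state
print(disag.shape, float(disag.mean()))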
/dreamerv2/train.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import functools
3 | import logging
4 | import os
5 | import pathlib
6 | import re
7 | import sys
8 | import warnings
9 |
10 | try:
11 | import rich.traceback
12 | rich.traceback.install()
13 | except ImportError:
14 | pass
15 |
16 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
17 | logging.getLogger().setLevel('ERROR')
18 | warnings.filterwarnings('ignore', '.*box bound precision lowered.*')
19 |
20 | sys.path.append(str(pathlib.Path(__file__).parent))
21 | sys.path.append(str(pathlib.Path(__file__).parent.parent))
22 |
23 | import numpy as np
24 | import ruamel.yaml as yaml
25 |
26 | import agent
27 | import common
28 |
29 |
30 | def main():
31 |
32 | configs = yaml.safe_load((
33 | pathlib.Path(sys.argv[0]).parent / 'configs.yaml').read_text())
34 | parsed, remaining = common.Flags(configs=['defaults']).parse(known_only=True)
35 | config = common.Config(configs['defaults'])
36 | for name in parsed.configs:
37 | config = config.update(configs[name])
38 | config = common.Flags(config).parse(remaining)
39 |
40 | logdir = pathlib.Path(config.logdir).expanduser()
41 | logdir.mkdir(parents=True, exist_ok=True)
42 | config.save(logdir / 'config.yaml')
43 | print(config, '\n')
44 | print('Logdir', logdir)
45 |
46 | import tensorflow as tf
47 | tf.config.experimental_run_functions_eagerly(not config.jit)
48 | message = 'No GPU found. To actually train on CPU remove this assert.'
49 | assert tf.config.experimental.list_physical_devices('GPU'), message
50 | for gpu in tf.config.experimental.list_physical_devices('GPU'):
51 | tf.config.experimental.set_memory_growth(gpu, True)
52 | assert config.precision in (16, 32), config.precision
53 | if config.precision == 16:
54 | from tensorflow.keras.mixed_precision import experimental as prec
55 | prec.set_policy(prec.Policy('mixed_float16'))
56 |
57 | train_replay = common.Replay(logdir / 'train_episodes', **config.replay)
58 | eval_replay = common.Replay(logdir / 'eval_episodes', **dict(
59 | capacity=config.replay.capacity // 10,
60 | minlen=config.dataset.length,
61 | maxlen=config.dataset.length))
62 | step = common.Counter(train_replay.stats['total_steps'])
63 | outputs = [
64 | common.TerminalOutput(),
65 | common.JSONLOutput(logdir),
66 | common.TensorBoardOutput(logdir),
67 | ]
68 | logger = common.Logger(step, outputs, multiplier=config.action_repeat)
69 | metrics = collections.defaultdict(list)
70 |
71 | should_train = common.Every(config.train_every)
72 | should_log = common.Every(config.log_every)
73 | should_video_train = common.Every(config.eval_every)
74 | should_video_eval = common.Every(config.eval_every)
75 | should_expl = common.Until(config.expl_until // config.action_repeat)
76 |
77 | def make_env(mode):
78 | suite, task = config.task.split('_', 1)
79 | if suite == 'dmc':
80 | env = common.DMC(
81 | task, config.action_repeat, config.render_size, config.dmc_camera)
82 | env = common.NormalizeAction(env)
83 | elif suite == 'atari':
84 | env = common.Atari(
85 | task, config.action_repeat, config.render_size,
86 | config.atari_grayscale)
87 | env = common.OneHotAction(env)
88 | elif suite == 'crafter':
89 | assert config.action_repeat == 1
90 | outdir = logdir / 'crafter' if mode == 'train' else None
91 | reward = bool(['noreward', 'reward'].index(task)) or mode == 'eval'
92 | env = common.Crafter(outdir, reward)
93 | env = common.OneHotAction(env)
94 | else:
95 | raise NotImplementedError(suite)
96 | env = common.TimeLimit(env, config.time_limit)
97 | return env
98 |
99 | def per_episode(ep, mode):
100 | length = len(ep['reward']) - 1
101 | score = float(ep['reward'].astype(np.float64).sum())
102 | print(f'{mode.title()} episode has {length} steps and return {score:.1f}.')
103 | logger.scalar(f'{mode}_return', score)
104 | logger.scalar(f'{mode}_length', length)
105 | for key, value in ep.items():
106 | if re.match(config.log_keys_sum, key):
107 | logger.scalar(f'sum_{mode}_{key}', ep[key].sum())
108 | if re.match(config.log_keys_mean, key):
109 | logger.scalar(f'mean_{mode}_{key}', ep[key].mean())
110 | if re.match(config.log_keys_max, key):
111 | logger.scalar(f'max_{mode}_{key}', ep[key].max(0).mean())
112 | should = {'train': should_video_train, 'eval': should_video_eval}[mode]
113 | if should(step):
114 | for key in config.log_keys_video:
115 | logger.video(f'{mode}_policy_{key}', ep[key])
116 | replay = dict(train=train_replay, eval=eval_replay)[mode]
117 | logger.add(replay.stats, prefix=mode)
118 | logger.write()
119 |
120 | print('Create envs.')
121 | num_eval_envs = min(config.envs, config.eval_eps)
122 | if config.envs_parallel == 'none':
123 | train_envs = [make_env('train') for _ in range(config.envs)]
124 | eval_envs = [make_env('eval') for _ in range(num_eval_envs)]
125 | else:
126 | make_async_env = lambda mode: common.Async(
127 | functools.partial(make_env, mode), config.envs_parallel)
128 | train_envs = [make_async_env('train') for _ in range(config.envs)]
129 | eval_envs = [make_async_env('eval') for _ in range(num_eval_envs)]
130 | act_space = train_envs[0].act_space
131 | obs_space = train_envs[0].obs_space
132 | train_driver = common.Driver(train_envs)
133 | train_driver.on_episode(lambda ep: per_episode(ep, mode='train'))
134 | train_driver.on_step(lambda tran, worker: step.increment())
135 | train_driver.on_step(train_replay.add_step)
136 | train_driver.on_reset(train_replay.add_step)
137 | eval_driver = common.Driver(eval_envs)
138 | eval_driver.on_episode(lambda ep: per_episode(ep, mode='eval'))
139 | eval_driver.on_episode(eval_replay.add_episode)
140 |
141 | prefill = max(0, config.prefill - train_replay.stats['total_steps'])
142 | if prefill:
143 | print(f'Prefill dataset ({prefill} steps).')
144 | random_agent = common.RandomAgent(act_space)
145 | train_driver(random_agent, steps=prefill, episodes=1)
146 | eval_driver(random_agent, episodes=1)
147 | train_driver.reset()
148 | eval_driver.reset()
149 |
150 | print('Create agent.')
151 | train_dataset = iter(train_replay.dataset(**config.dataset))
152 | report_dataset = iter(train_replay.dataset(**config.dataset))
153 | eval_dataset = iter(eval_replay.dataset(**config.dataset))
154 | agnt = agent.Agent(config, obs_space, act_space, step)
155 | train_agent = common.CarryOverState(agnt.train)
156 | train_agent(next(train_dataset))
157 | if (logdir / 'variables.pkl').exists():
158 | agnt.load(logdir / 'variables.pkl')
159 | else:
160 | print('Pretrain agent.')
161 | for _ in range(config.pretrain):
162 | train_agent(next(train_dataset))
163 | train_policy = lambda *args: agnt.policy(
164 | *args, mode='explore' if should_expl(step) else 'train')
165 | eval_policy = lambda *args: agnt.policy(*args, mode='eval')
166 |
167 | def train_step(tran, worker):
168 | if should_train(step):
169 | for _ in range(config.train_steps):
170 | mets = train_agent(next(train_dataset))
171 | [metrics[key].append(value) for key, value in mets.items()]
172 | if should_log(step):
173 | for name, values in metrics.items():
174 | logger.scalar(name, np.array(values, np.float64).mean())
175 | metrics[name].clear()
176 | logger.add(agnt.report(next(report_dataset)), prefix='train')
177 | logger.write(fps=True)
178 | train_driver.on_step(train_step)
179 |
180 | while step < config.steps:
181 | logger.write()
182 | print('Start evaluation.')
183 | logger.add(agnt.report(next(eval_dataset)), prefix='eval')
184 | eval_driver(eval_policy, episodes=config.eval_eps)
185 | print('Start training.')
186 | train_driver(train_policy, steps=config.eval_every)
187 | agnt.save(logdir / 'variables.pkl')
188 | for env in train_envs + eval_envs:
189 | try:
190 | env.close()
191 | except Exception:
192 | pass
193 |
194 |
195 | if __name__ == '__main__':
196 | main()
197 |
--------------------------------------------------------------------------------
/examples/minigrid.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import gym_minigrid
3 | import dreamerv2.api as dv2
4 |
5 | config = dv2.defaults.update({
6 | 'logdir': '~/logdir/minigrid',
7 | 'log_every': 1e3,
8 | 'train_every': 10,
9 | 'prefill': 1e5,
10 | 'actor_ent': 3e-3,
11 | 'loss_scales.kl': 1.0,
12 | 'discount': 0.99,
13 | }).parse_flags()
14 |
15 | env = gym.make('MiniGrid-DoorKey-6x6-v0')
16 | env = gym_minigrid.wrappers.RGBImgPartialObsWrapper(env)
17 | dv2.train(env, config)
18 |
--------------------------------------------------------------------------------
/scores/baselines.json:
--------------------------------------------------------------------------------
1 | {"atari_alien": {"c51_sticky_2e8": 2474.88, "ddqn_determ_2e8": 3747.7, "dqn_determ_2e8": 1620.0, "dqn_sticky_2e8": 2484.49, "dueling_determ_2e8": 4461.4, "human_gamer": 7127.7, "human_record": 251916.0, "iqn_sticky_2e8": 4961.42, "prioritized_determ_2e8": 4203.8, "rainbow_determ_2e8": 9491.7, "rainbow_determ_4e5": 739.9, "rainbow_sticky_2e8": 3456.95, "random": 228.8, "reactor_determ_2e8": 4199.4, "simple_determ_4e5": 405.2}, "atari_amidar": {"c51_sticky_2e8": 1363.18, "ddqn_determ_2e8": 1793.3, "dqn_determ_2e8": 978.0, "dqn_sticky_2e8": 1207.74, "dueling_determ_2e8": 2354.5, "human_gamer": 1719.5, "human_record": 104159.0, "iqn_sticky_2e8": 2393.43, "prioritized_determ_2e8": 1838.9, "rainbow_determ_2e8": 5131.2, "rainbow_determ_4e5": 188.6, "rainbow_sticky_2e8": 2529.09, "random": 5.8, "reactor_determ_2e8": 1546.8, "simple_determ_4e5": 88.0}, "atari_assault": {"c51_sticky_2e8": 1800.12, "ddqn_determ_2e8": 5393.2, "dqn_determ_2e8": 4280.4, "dqn_sticky_2e8": 1525.24, "dueling_determ_2e8": 4621.0, "human_gamer": 742.0, "human_record": 8647.0, "iqn_sticky_2e8": 4884.78, "prioritized_determ_2e8": 7672.1, "rainbow_determ_2e8": 14198.5, "rainbow_determ_4e5": 431.2, "rainbow_sticky_2e8": 3228.82, "random": 222.4, "reactor_determ_2e8": 17543.8, "simple_determ_4e5": 369.3}, "atari_asterix": {"c51_sticky_2e8": 13453.37, "ddqn_determ_2e8": 17356.5, "dqn_determ_2e8": 4359.0, "dqn_sticky_2e8": 2711.41, "dueling_determ_2e8": 28188.0, "human_gamer": 8503.3, "human_record": 1000000.0, "iqn_sticky_2e8": 10374.44, "prioritized_determ_2e8": 31527.0, "rainbow_determ_2e8": 428200.3, "rainbow_determ_4e5": 470.8, "rainbow_sticky_2e8": 18366.58, "random": 210.0, "reactor_determ_2e8": 16121.0, "simple_determ_4e5": 1089.5}, "atari_asteroids": {"c51_sticky_2e8": 1079.58, "ddqn_determ_2e8": 734.7, "dqn_determ_2e8": 1364.5, "dqn_sticky_2e8": 698.37, "dueling_determ_2e8": 2837.7, "human_gamer": 47388.7, "human_record": 10506650.0, "iqn_sticky_2e8": 1584.51, "prioritized_determ_2e8": 2654.3, "rainbow_determ_2e8": 2712.8, "rainbow_sticky_2e8": 1483.54, "random": 719.1, "reactor_determ_2e8": 4467.4}, "atari_atlantis": {"c51_sticky_2e8": 824994.31, "ddqn_determ_2e8": 106056.0, "dqn_determ_2e8": 279987.0, "dqn_sticky_2e8": 853640.0, "dueling_determ_2e8": 382572.0, "human_gamer": 29028.1, "human_record": 10604840.0, "iqn_sticky_2e8": 890214.0, "prioritized_determ_2e8": 357324.0, "rainbow_determ_2e8": 826659.5, "rainbow_sticky_2e8": 802548.0, "random": 12850.0, "reactor_determ_2e8": 968179.5}, "atari_bank_heist": {"c51_sticky_2e8": 757.65, "ddqn_determ_2e8": 1030.6, "dqn_determ_2e8": 455.0, "dqn_sticky_2e8": 601.79, "dueling_determ_2e8": 1611.9, "human_gamer": 753.1, "human_record": 82058.0, "iqn_sticky_2e8": 1052.27, "prioritized_determ_2e8": 1054.6, "rainbow_determ_2e8": 1358.0, "rainbow_determ_4e5": 51.0, "rainbow_sticky_2e8": 1074.97, "random": 14.2, "reactor_determ_2e8": 1236.8, "simple_determ_4e5": 8.2}, "atari_battle_zone": {"c51_sticky_2e8": 25501.76, "ddqn_determ_2e8": 31700.0, "dqn_determ_2e8": 29900.0, "dqn_sticky_2e8": 17784.84, "dueling_determ_2e8": 37150.0, "human_gamer": 37187.5, "human_record": 801000.0, "iqn_sticky_2e8": 40953.44, "prioritized_determ_2e8": 31530.0, "rainbow_determ_2e8": 62010.0, "rainbow_determ_4e5": 10124.6, "rainbow_sticky_2e8": 40060.59, "random": 2360.0, "reactor_determ_2e8": 98235.0, "simple_determ_4e5": 5184.4}, "atari_beam_rider": {"c51_sticky_2e8": 5331.59, "ddqn_determ_2e8": 13772.8, "dqn_determ_2e8": 8627.5, "dqn_sticky_2e8": 5852.42, "dueling_determ_2e8": 12164.0, 
"human_gamer": 16926.5, "human_record": 999999.0, "iqn_sticky_2e8": 7130.45, "prioritized_determ_2e8": 23384.2, "rainbow_determ_2e8": 16850.2, "rainbow_sticky_2e8": 6290.48, "random": 363.9, "reactor_determ_2e8": 8811.8}, "atari_berzerk": {"c51_sticky_2e8": 516.69, "ddqn_determ_2e8": 1225.4, "dqn_determ_2e8": 585.6, "dqn_sticky_2e8": 487.48, "dueling_determ_2e8": 1472.6, "human_gamer": 2630.4, "human_record": 1057940.0, "iqn_sticky_2e8": 648.4, "prioritized_determ_2e8": 1305.6, "rainbow_determ_2e8": 2545.6, "rainbow_sticky_2e8": 833.41, "random": 123.7, "reactor_determ_2e8": 1515.7}, "atari_bowling": {"c51_sticky_2e8": 32.1, "ddqn_determ_2e8": 68.1, "dqn_determ_2e8": 50.4, "dqn_sticky_2e8": 30.12, "dueling_determ_2e8": 65.5, "human_gamer": 160.7, "human_record": 300.0, "iqn_sticky_2e8": 39.44, "prioritized_determ_2e8": 47.9, "rainbow_determ_2e8": 30.0, "rainbow_sticky_2e8": 42.86, "random": 23.1, "reactor_determ_2e8": 59.3}, "atari_boxing": {"c51_sticky_2e8": 81.4, "ddqn_determ_2e8": 91.6, "dqn_determ_2e8": 88.0, "dqn_sticky_2e8": 77.99, "dueling_determ_2e8": 99.4, "human_gamer": 12.1, "human_record": 100.0, "iqn_sticky_2e8": 97.91, "prioritized_determ_2e8": 95.6, "rainbow_determ_2e8": 99.6, "rainbow_determ_4e5": 0.2, "rainbow_sticky_2e8": 98.56, "random": 0.1, "reactor_determ_2e8": 99.7, "simple_determ_4e5": 9.1}, "atari_breakout": {"c51_sticky_2e8": 202.39, "ddqn_determ_2e8": 418.5, "dqn_determ_2e8": 385.5, "dqn_sticky_2e8": 96.23, "dueling_determ_2e8": 345.3, "human_gamer": 30.5, "human_record": 864.0, "iqn_sticky_2e8": 79.48, "prioritized_determ_2e8": 373.9, "rainbow_determ_2e8": 417.5, "rainbow_determ_4e5": 1.9, "rainbow_sticky_2e8": 120.07, "random": 1.7, "reactor_determ_2e8": 509.5, "simple_determ_4e5": 12.7}, "atari_centipede": {"c51_sticky_2e8": 7594.53, "ddqn_determ_2e8": 5409.4, "dqn_determ_2e8": 4657.7, "dqn_sticky_2e8": 2583.03, "dueling_determ_2e8": 7561.4, "human_gamer": 12017.0, "human_record": 1301709.0, "iqn_sticky_2e8": 3727.64, "prioritized_determ_2e8": 4463.2, "rainbow_determ_2e8": 8167.3, "rainbow_sticky_2e8": 6509.87, "random": 2090.9, "reactor_determ_2e8": 7267.2}, "atari_chopper_command": {"c51_sticky_2e8": 5718.16, "ddqn_determ_2e8": 5809.0, "dqn_determ_2e8": 6126.0, "dqn_sticky_2e8": 2690.61, "dueling_determ_2e8": 11215.0, "human_gamer": 7387.8, "human_record": 999999.0, "iqn_sticky_2e8": 9281.96, "prioritized_determ_2e8": 8600.0, "rainbow_determ_2e8": 16654.0, "rainbow_determ_4e5": 861.8, "rainbow_sticky_2e8": 12337.51, "random": 811.0, "reactor_determ_2e8": 19901.5, "simple_determ_4e5": 1246.9}, "atari_crazy_climber": {"c51_sticky_2e8": 136041.14, "ddqn_determ_2e8": 117282.0, "dqn_determ_2e8": 110763.0, "dqn_sticky_2e8": 104568.76, "dueling_determ_2e8": 143570.0, "human_gamer": 35829.4, "human_record": 219900.0, "iqn_sticky_2e8": 132738.13, "prioritized_determ_2e8": 141161.0, "rainbow_determ_2e8": 168788.5, "rainbow_determ_4e5": 16185.3, "rainbow_sticky_2e8": 145389.29, "random": 10780.5, "reactor_determ_2e8": 173274.0, "simple_determ_4e5": 39827.8}, "atari_defender": {"ddqn_determ_2e8": 35338.5, "dqn_determ_2e8": 23633.0, "dueling_determ_2e8": 42214.0, "human_gamer": 18688.9, "human_record": 6010500.0, "prioritized_determ_2e8": 31286.5, "rainbow_determ_2e8": 55105.0, "random": 2874.5, "reactor_determ_2e8": 181074.3}, "atari_demon_attack": {"c51_sticky_2e8": 7098.0, "ddqn_determ_2e8": 58044.2, "dqn_determ_2e8": 12149.4, "dqn_sticky_2e8": 6361.58, "dueling_determ_2e8": 60813.3, "human_gamer": 1971.0, "human_record": 1556345.0, "iqn_sticky_2e8": 15350.43, 
"prioritized_determ_2e8": 71846.4, "rainbow_determ_2e8": 111185.2, "rainbow_determ_4e5": 508.0, "rainbow_sticky_2e8": 17071.26, "random": 152.1, "reactor_determ_2e8": 122782.5, "simple_determ_4e5": 169.5}, "atari_double_dunk": {"c51_sticky_2e8": 1.28, "ddqn_determ_2e8": -5.5, "dqn_determ_2e8": -6.6, "dqn_sticky_2e8": -6.54, "dueling_determ_2e8": 0.1, "human_gamer": -16.4, "human_record": 22.0, "iqn_sticky_2e8": 21.12, "prioritized_determ_2e8": 18.5, "rainbow_determ_2e8": -0.3, "rainbow_sticky_2e8": 22.12, "random": -18.6, "reactor_determ_2e8": 23.0}, "atari_enduro": {"c51_sticky_2e8": 1283.8, "ddqn_determ_2e8": 1211.8, "dqn_determ_2e8": 729.0, "dqn_sticky_2e8": 628.91, "dueling_determ_2e8": 2258.2, "human_gamer": 860.5, "human_record": 9500.0, "iqn_sticky_2e8": 2202.98, "prioritized_determ_2e8": 2093.0, "rainbow_determ_2e8": 2125.9, "rainbow_sticky_2e8": 2200.16, "random": 0.0, "reactor_determ_2e8": 2211.3}, "atari_fishing_derby": {"c51_sticky_2e8": 10.4, "ddqn_determ_2e8": 15.5, "dqn_determ_2e8": -4.9, "dqn_sticky_2e8": 0.58, "dueling_determ_2e8": 46.4, "human_gamer": -38.7, "human_record": 71.0, "iqn_sticky_2e8": 45.33, "prioritized_determ_2e8": 39.5, "rainbow_determ_2e8": 31.3, "rainbow_sticky_2e8": 41.8, "random": -91.7, "reactor_determ_2e8": 33.1}, "atari_freeway": {"c51_sticky_2e8": 33.68, "ddqn_determ_2e8": 33.3, "dqn_determ_2e8": 30.8, "dqn_sticky_2e8": 26.27, "dueling_determ_2e8": 0.0, "human_gamer": 29.6, "human_record": 38.0, "iqn_sticky_2e8": 33.67, "prioritized_determ_2e8": 33.7, "rainbow_determ_2e8": 34.0, "rainbow_determ_4e5": 27.9, "rainbow_sticky_2e8": 33.69, "random": 0.0, "reactor_determ_2e8": 22.3, "simple_determ_4e5": 20.3}, "atari_frostbite": {"c51_sticky_2e8": 4284.99, "ddqn_determ_2e8": 1683.3, "dqn_determ_2e8": 797.4, "dqn_sticky_2e8": 367.07, "dueling_determ_2e8": 4672.8, "human_gamer": 4334.7, "human_record": 454830.0, "iqn_sticky_2e8": 7811.57, "prioritized_determ_2e8": 4380.1, "rainbow_determ_2e8": 9590.5, "rainbow_determ_4e5": 866.8, "rainbow_sticky_2e8": 8207.72, "random": 65.2, "reactor_determ_2e8": 7136.7, "simple_determ_4e5": 254.7}, "atari_gopher": {"c51_sticky_2e8": 7172.88, "ddqn_determ_2e8": 14840.8, "dqn_determ_2e8": 8777.4, "dqn_sticky_2e8": 5479.9, "dueling_determ_2e8": 15718.4, "human_gamer": 2412.5, "human_record": 355040.0, "iqn_sticky_2e8": 12107.61, "prioritized_determ_2e8": 32487.2, "rainbow_determ_2e8": 70354.6, "rainbow_determ_4e5": 349.5, "rainbow_sticky_2e8": 10641.07, "random": 257.6, "reactor_determ_2e8": 36279.1, "simple_determ_4e5": 771.0}, "atari_gravitar": {"c51_sticky_2e8": 686.0, "ddqn_determ_2e8": 412.0, "dqn_determ_2e8": 473.0, "dqn_sticky_2e8": 330.07, "dueling_determ_2e8": 588.0, "human_gamer": 3351.4, "human_record": 162850.0, "iqn_sticky_2e8": 1346.62, "prioritized_determ_2e8": 548.5, "rainbow_determ_2e8": 1419.3, "rainbow_sticky_2e8": 1271.8, "random": 173.0, "reactor_determ_2e8": 1804.8}, "atari_hero": {"c51_sticky_2e8": 33753.23, "ddqn_determ_2e8": 20130.2, "dqn_determ_2e8": 20437.8, "dqn_sticky_2e8": 17325.44, "dueling_determ_2e8": 20818.2, "human_gamer": 30826.4, "human_record": 1000000.0, "iqn_sticky_2e8": 36058.41, "prioritized_determ_2e8": 23037.7, "rainbow_determ_2e8": 55887.4, "rainbow_determ_4e5": 6857.0, "rainbow_sticky_2e8": 46675.16, "random": 1027.0, "reactor_determ_2e8": 27833.0, "simple_determ_4e5": 1295.1}, "atari_ice_hockey": {"c51_sticky_2e8": -4.36, "ddqn_determ_2e8": -2.7, "dqn_determ_2e8": -1.9, "dqn_sticky_2e8": -5.84, "dueling_determ_2e8": 0.5, "human_gamer": 0.9, "human_record": 36.0, 
"iqn_sticky_2e8": -5.11, "prioritized_determ_2e8": 1.3, "rainbow_determ_2e8": 1.1, "rainbow_sticky_2e8": -0.18, "random": -11.2, "reactor_determ_2e8": 15.7}, "atari_james_bond": {"c51_sticky_2e8": 725.56, "ddqn_determ_2e8": 768.5, "dqn_determ_2e8": 302.8, "dqn_sticky_2e8": 573.31, "dueling_determ_2e8": 1358.0, "human_gamer": 29.0, "human_record": 45550.0, "iqn_sticky_2e8": 3165.7, "prioritized_determ_2e8": 1312.5, "rainbow_determ_2e8": 812.0, "rainbow_determ_4e5": 301.6, "rainbow_sticky_2e8": 1097.12, "random": 7.0, "reactor_determ_2e8": 19809.0, "simple_determ_4e5": 125.3}, "atari_kangaroo": {"c51_sticky_2e8": 7687.54, "ddqn_determ_2e8": 12992.0, "dqn_determ_2e8": 7259.0, "dqn_sticky_2e8": 11485.98, "dueling_determ_2e8": 14854.0, "human_gamer": 3035.0, "human_record": 1424600.0, "iqn_sticky_2e8": 12602.27, "prioritized_determ_2e8": 16200.0, "rainbow_determ_2e8": 14637.5, "rainbow_determ_4e5": 779.3, "rainbow_sticky_2e8": 12748.3, "random": 52.0, "reactor_determ_2e8": 13349.0, "simple_determ_4e5": 323.1}, "atari_krull": {"c51_sticky_2e8": 6770.45, "ddqn_determ_2e8": 7920.5, "dqn_determ_2e8": 8422.3, "dqn_sticky_2e8": 6097.63, "dueling_determ_2e8": 11451.9, "human_gamer": 2665.5, "human_record": 104100.0, "iqn_sticky_2e8": 8844.17, "prioritized_determ_2e8": 9728.0, "rainbow_determ_2e8": 8741.5, "rainbow_determ_4e5": 2851.5, "rainbow_sticky_2e8": 4065.97, "random": 1598.0, "reactor_determ_2e8": 10237.8, "simple_determ_4e5": 4539.9}, "atari_kung_fu_master": {"c51_sticky_2e8": 24811.6, "ddqn_determ_2e8": 29710.0, "dqn_determ_2e8": 26059.0, "dqn_sticky_2e8": 23435.38, "dueling_determ_2e8": 34294.0, "human_gamer": 22736.3, "human_record": 1000000.0, "iqn_sticky_2e8": 31653.37, "prioritized_determ_2e8": 39581.0, "rainbow_determ_2e8": 52181.0, "rainbow_determ_4e5": 14346.1, "rainbow_sticky_2e8": 26475.07, "random": 258.5, "reactor_determ_2e8": 61621.5, "simple_determ_4e5": 17257.2}, "atari_montezuma_revenge": {"c51_sticky_2e8": 1044.67, "ddqn_determ_2e8": 0.0, "dqn_determ_2e8": 0.0, "dqn_sticky_2e8": 0.0, "dueling_determ_2e8": 0.0, "human_gamer": 4753.3, "human_record": 1219200.0, "iqn_sticky_2e8": 500.0, "prioritized_determ_2e8": 0.0, "rainbow_determ_2e8": 384.0, "rainbow_sticky_2e8": 500.0, "random": 0.0, "reactor_determ_2e8": 0.0}, "atari_ms_pacman": {"c51_sticky_2e8": 3693.53, "ddqn_determ_2e8": 2711.4, "dqn_determ_2e8": 3085.6, "dqn_sticky_2e8": 3402.4, "dueling_determ_2e8": 6283.5, "human_gamer": 6951.6, "human_record": 290090.0, "iqn_sticky_2e8": 5217.99, "prioritized_determ_2e8": 6518.7, "rainbow_determ_2e8": 5380.4, "rainbow_determ_4e5": 1204.1, "rainbow_sticky_2e8": 3861.01, "random": 307.3, "reactor_determ_2e8": 4416.9, "simple_determ_4e5": 762.8}, "atari_name_this_game": {"c51_sticky_2e8": 12594.19, "ddqn_determ_2e8": 10616.0, "dqn_determ_2e8": 8207.8, "dqn_sticky_2e8": 7278.65, "dueling_determ_2e8": 11971.1, "human_gamer": 8049.0, "human_record": 25220.0, "iqn_sticky_2e8": 6638.61, "prioritized_determ_2e8": 12270.5, "rainbow_determ_2e8": 13136.0, "rainbow_sticky_2e8": 9025.75, "random": 2292.3, "reactor_determ_2e8": 12636.5}, "atari_phoenix": {"c51_sticky_2e8": 5849.25, "ddqn_determ_2e8": 12252.5, "dqn_determ_2e8": 8485.2, "dqn_sticky_2e8": 4996.58, "dueling_determ_2e8": 23092.2, "human_gamer": 7242.6, "human_record": 4014440.0, "iqn_sticky_2e8": 5102.03, "prioritized_determ_2e8": 18992.7, "rainbow_determ_2e8": 108528.6, "rainbow_sticky_2e8": 8545.35, "random": 761.4, "reactor_determ_2e8": 10261.4}, "atari_pitfall": {"c51_sticky_2e8": -18.39, "ddqn_determ_2e8": -29.9, 
"dqn_determ_2e8": -286.1, "dqn_sticky_2e8": -73.81, "dueling_determ_2e8": 0.0, "human_gamer": 6463.7, "human_record": 114000.0, "iqn_sticky_2e8": -13.16, "prioritized_determ_2e8": -356.5, "rainbow_determ_2e8": 0.0, "rainbow_sticky_2e8": -19.76, "random": -229.4, "reactor_determ_2e8": -3.7}, "atari_pong": {"c51_sticky_2e8": 19.34, "ddqn_determ_2e8": 20.9, "dqn_determ_2e8": 19.5, "dqn_sticky_2e8": 16.61, "dueling_determ_2e8": 21.0, "human_gamer": 14.6, "human_record": 21.0, "iqn_sticky_2e8": 20.13, "prioritized_determ_2e8": 20.6, "rainbow_determ_2e8": 20.9, "rainbow_determ_4e5": -19.3, "rainbow_sticky_2e8": 20.18, "random": -20.7, "reactor_determ_2e8": 20.7, "simple_determ_4e5": 5.2}, "atari_private_eye": {"c51_sticky_2e8": 4207.11, "ddqn_determ_2e8": 129.7, "dqn_determ_2e8": 146.7, "dqn_sticky_2e8": -16.04, "dueling_determ_2e8": 103.0, "human_gamer": 69571.3, "human_record": 101800.0, "iqn_sticky_2e8": 4180.76, "prioritized_determ_2e8": 200.0, "rainbow_determ_2e8": 4234.0, "rainbow_determ_4e5": 97.8, "rainbow_sticky_2e8": 21333.56, "random": 24.9, "reactor_determ_2e8": 15198.0, "simple_determ_4e5": 58.3}, "atari_qbert": {"c51_sticky_2e8": 9761.91, "ddqn_determ_2e8": 15088.5, "dqn_determ_2e8": 13117.3, "dqn_sticky_2e8": 10117.5, "dueling_determ_2e8": 19220.3, "human_gamer": 13455.0, "human_record": 2400000.0, "iqn_sticky_2e8": 16729.89, "prioritized_determ_2e8": 16256.5, "rainbow_determ_2e8": 33817.5, "rainbow_determ_4e5": 1152.9, "rainbow_sticky_2e8": 17382.94, "random": 163.9, "reactor_determ_2e8": 21222.5, "simple_determ_4e5": 559.8}, "atari_riverraid": {"c51_sticky_2e8": 13514.28, "ddqn_determ_2e8": 14884.5, "dqn_determ_2e8": 7377.6, "dqn_sticky_2e8": 11638.93, "dueling_determ_2e8": 21162.6, "human_gamer": 17118.0, "human_record": 1000000.0, "iqn_sticky_2e8": 15183.38, "prioritized_determ_2e8": 14522.3, "rainbow_determ_2e8": 22920.8, "rainbow_sticky_2e8": 20755.91, "random": 1338.5, "reactor_determ_2e8": 16957.3}, "atari_road_runner": {"c51_sticky_2e8": 48213.46, "ddqn_determ_2e8": 44127.0, "dqn_determ_2e8": 39544.0, "dqn_sticky_2e8": 36925.47, "dueling_determ_2e8": 69524.0, "human_gamer": 7845.0, "human_record": 2038100.0, "iqn_sticky_2e8": 58965.95, "prioritized_determ_2e8": 57608.0, "rainbow_determ_2e8": 62041.0, "rainbow_determ_4e5": 9600.0, "rainbow_sticky_2e8": 54662.07, "random": 11.5, "reactor_determ_2e8": 66790.5, "simple_determ_4e5": 5169.4}, "atari_robotank": {"c51_sticky_2e8": 60.53, "ddqn_determ_2e8": 65.1, "dqn_determ_2e8": 63.9, "dqn_sticky_2e8": 59.77, "dueling_determ_2e8": 65.3, "human_gamer": 11.9, "human_record": 76.0, "iqn_sticky_2e8": 65.72, "prioritized_determ_2e8": 62.6, "rainbow_determ_2e8": 61.4, "rainbow_sticky_2e8": 65.52, "random": 2.2, "reactor_determ_2e8": 71.8}, "atari_seaquest": {"c51_sticky_2e8": 31025.85, "ddqn_determ_2e8": 16452.7, "dqn_determ_2e8": 5860.6, "dqn_sticky_2e8": 1600.66, "dueling_determ_2e8": 50254.2, "human_gamer": 42054.7, "human_record": 999999.0, "iqn_sticky_2e8": 17039.16, "prioritized_determ_2e8": 26357.8, "rainbow_determ_2e8": 15898.9, "rainbow_determ_4e5": 354.1, "rainbow_sticky_2e8": 9903.38, "random": 68.4, "reactor_determ_2e8": 5071.6, "simple_determ_4e5": 370.9}, "atari_skiing": {"c51_sticky_2e8": -22231.09, "ddqn_determ_2e8": -9021.8, "dqn_determ_2e8": -13062.3, "dqn_sticky_2e8": -15824.61, "dueling_determ_2e8": -8857.4, "human_gamer": -4336.9, "human_record": -3272.0, "iqn_sticky_2e8": -11161.85, "prioritized_determ_2e8": -9996.9, "rainbow_determ_2e8": -12957.8, "rainbow_sticky_2e8": -28707.56, "random": -17098.1, 
"reactor_determ_2e8": -10632.9}, "atari_solaris": {"c51_sticky_2e8": 1961.66, "ddqn_determ_2e8": 3067.8, "dqn_determ_2e8": 3482.8, "dqn_sticky_2e8": 1436.4, "dueling_determ_2e8": 2250.8, "human_gamer": 12326.7, "human_record": 111420.0, "iqn_sticky_2e8": 1684.4, "prioritized_determ_2e8": 4309.0, "rainbow_determ_2e8": 3560.3, "rainbow_sticky_2e8": 1582.69, "random": 1236.3, "reactor_determ_2e8": 2236.0}, "atari_space_invaders": {"c51_sticky_2e8": 3983.6, "ddqn_determ_2e8": 2525.5, "dqn_determ_2e8": 1692.3, "dqn_sticky_2e8": 1794.24, "dueling_determ_2e8": 6427.3, "human_gamer": 1668.7, "human_record": 621535.0, "iqn_sticky_2e8": 4530.03, "prioritized_determ_2e8": 2865.8, "rainbow_determ_2e8": 18789.0, "rainbow_sticky_2e8": 4130.88, "random": 148.0, "reactor_determ_2e8": 2387.1}, "atari_star_gunner": {"c51_sticky_2e8": 34804.97, "ddqn_determ_2e8": 60142.0, "dqn_determ_2e8": 54282.0, "dqn_sticky_2e8": 42165.22, "dueling_determ_2e8": 89238.0, "human_gamer": 10250.0, "human_record": 77400.0, "iqn_sticky_2e8": 80002.58, "prioritized_determ_2e8": 63302.0, "rainbow_determ_2e8": 127029.0, "rainbow_sticky_2e8": 57908.72, "random": 664.0, "reactor_determ_2e8": 48942.0}, "atari_surround": {"ddqn_determ_2e8": -2.9, "dqn_determ_2e8": -5.6, "dueling_determ_2e8": 4.4, "human_gamer": 6.5, "prioritized_determ_2e8": 8.9, "rainbow_determ_2e8": 9.7, "random": -10.0, "reactor_determ_2e8": 0.9}, "atari_tennis": {"c51_sticky_2e8": 21.7, "ddqn_determ_2e8": -22.8, "dqn_determ_2e8": 12.2, "dqn_sticky_2e8": -1.5, "dueling_determ_2e8": 5.1, "human_gamer": -8.3, "human_record": 21.0, "iqn_sticky_2e8": 22.54, "prioritized_determ_2e8": 0.0, "rainbow_determ_2e8": 0.0, "rainbow_sticky_2e8": -0.22, "random": -23.8, "reactor_determ_2e8": 23.4}, "atari_time_pilot": {"c51_sticky_2e8": 8117.05, "ddqn_determ_2e8": 8339.0, "dqn_determ_2e8": 4870.0, "dqn_sticky_2e8": 3654.37, "dueling_determ_2e8": 11666.0, "human_gamer": 5229.2, "human_record": 65300.0, "iqn_sticky_2e8": 11666.23, "prioritized_determ_2e8": 9197.0, "rainbow_determ_2e8": 12926.0, "rainbow_sticky_2e8": 12050.53, "random": 3568.0, "reactor_determ_2e8": 18871.5}, "atari_tutankham": {"c51_sticky_2e8": 240.72, "ddqn_determ_2e8": 218.4, "dqn_determ_2e8": 68.1, "dqn_sticky_2e8": 103.84, "dueling_determ_2e8": 211.4, "human_gamer": 167.6, "human_record": 5384.0, "iqn_sticky_2e8": 251.41, "prioritized_determ_2e8": 204.6, "rainbow_determ_2e8": 241.0, "rainbow_sticky_2e8": 239.07, "random": 11.4, "reactor_determ_2e8": 263.2}, "atari_up_n_down": {"c51_sticky_2e8": 8279.26, "ddqn_determ_2e8": 22972.2, "dqn_determ_2e8": 9989.9, "dqn_sticky_2e8": 8488.31, "dueling_determ_2e8": 44939.6, "human_gamer": 11693.2, "human_record": 82840.0, "iqn_sticky_2e8": 59943.87, "prioritized_determ_2e8": 16154.1, "rainbow_determ_2e8": 125754.6, "rainbow_determ_4e5": 2877.4, "rainbow_sticky_2e8": 34887.76, "random": 533.4, "reactor_determ_2e8": 194989.5, "simple_determ_4e5": 2152.6}, "atari_venture": {"c51_sticky_2e8": 1339.41, "ddqn_determ_2e8": 98.0, "dqn_determ_2e8": 163.0, "dqn_sticky_2e8": 39.13, "dueling_determ_2e8": 497.0, "human_gamer": 1187.5, "human_record": 38900.0, "iqn_sticky_2e8": 1312.84, "prioritized_determ_2e8": 54.0, "rainbow_determ_2e8": 5.5, "rainbow_sticky_2e8": 1528.87, "random": 0.0, "reactor_determ_2e8": 0.0}, "atari_video_pinball": {"c51_sticky_2e8": 405173.28, "ddqn_determ_2e8": 309941.9, "dqn_determ_2e8": 196760.4, "dqn_sticky_2e8": 63406.11, "dueling_determ_2e8": 98209.5, "human_gamer": 17667.9, "human_record": 89218328.0, "iqn_sticky_2e8": 415833.17, 
"prioritized_determ_2e8": 282007.3, "rainbow_determ_2e8": 533936.5, "rainbow_sticky_2e8": 466894.97, "random": 16256.9, "reactor_determ_2e8": 261720.2}, "atari_wizard_of_wor": {"c51_sticky_2e8": 3229.23, "ddqn_determ_2e8": 7492.0, "dqn_determ_2e8": 2704.0, "dqn_sticky_2e8": 2065.8, "dueling_determ_2e8": 7855.0, "human_gamer": 4756.5, "human_record": 395300.0, "iqn_sticky_2e8": 5670.54, "prioritized_determ_2e8": 4802.0, "rainbow_determ_2e8": 17862.5, "rainbow_sticky_2e8": 7878.62, "random": 563.5, "reactor_determ_2e8": 18484.0}, "atari_yars_revenge": {"c51_sticky_2e8": 11746.61, "ddqn_determ_2e8": 11712.6, "dqn_determ_2e8": 18098.9, "dqn_sticky_2e8": 23909.38, "dueling_determ_2e8": 49622.1, "human_gamer": 54576.9, "human_record": 15000105.0, "iqn_sticky_2e8": 84144.01, "prioritized_determ_2e8": 11357.0, "rainbow_determ_2e8": 102557.0, "rainbow_sticky_2e8": 45542.03, "random": 3092.9, "reactor_determ_2e8": 109607.5}, "atari_zaxxon": {"c51_sticky_2e8": 6502.33, "ddqn_determ_2e8": 10163.0, "dqn_determ_2e8": 5363.0, "dqn_sticky_2e8": 4538.57, "dueling_determ_2e8": 12944.0, "human_gamer": 9173.3, "human_record": 83700.0, "iqn_sticky_2e8": 11022.93, "prioritized_determ_2e8": 10469.0, "rainbow_determ_2e8": 22209.5, "rainbow_sticky_2e8": 14603.02, "random": 32.5, "reactor_determ_2e8": 16525.0}, "atari_air_raid": {"c51_sticky_2e8": 8288.24, "dqn_sticky_2e8": 7479.5, "human_record": 23050.0, "iqn_sticky_2e8": 10658.59, "rainbow_sticky_2e8": 11668.75, "random": 579.25}, "atari_carnival": {"c51_sticky_2e8": 4809.81, "dqn_sticky_2e8": 4784.84, "human_record": 2541440.0, "iqn_sticky_2e8": 5616.85, "rainbow_sticky_2e8": 4533.1, "random": 700.8}, "atari_elevator_action": {"c51_sticky_2e8": 68332.0, "dqn_sticky_2e8": 439.77, "human_record": 156550.0, "iqn_sticky_2e8": 65484.0, "rainbow_sticky_2e8": 76652.0, "random": 4387.0}, "atari_journey_escape": {"c51_sticky_2e8": -2707.38, "dqn_sticky_2e8": -3671.09, "human_record": -4317804.0, "iqn_sticky_2e8": -1490.44, "rainbow_sticky_2e8": -1106.49, "random": -19977.0}, "atari_pooyan": {"c51_sticky_2e8": 2931.13, "dqn_sticky_2e8": 3211.96, "human_record": 13025.0, "iqn_sticky_2e8": 4595.22, "rainbow_sticky_2e8": 4494.43, "random": 371.2}}
2 |
--------------------------------------------------------------------------------
/scores/montezuma-dreamerv2.json:
--------------------------------------------------------------------------------
1 | [{"task": "atari_montezuma_revenge", "method": "dreamerv2", "seed": "0", "xs": [1000000.0, 2000000.0, 3000000.0, 4000000.0, 5000000.0, 6000000.0, 7000000.0, 8000000.0, 9000000.0, 10000000.0, 11000000.0, 12000000.0, 13000000.0, 14000000.0, 15000000.0, 16000000.0, 17000000.0, 18000000.0, 19000000.0, 20000000.0, 21000000.0, 22000000.0, 23000000.0, 24000000.0, 25000000.0, 26000000.0, 27000000.0, 28000000.0, 29000000.0, 30000000.0, 31000000.0, 32000000.0, 33000000.0, 34000000.0, 35000000.0, 36000000.0, 37000000.0, 38000000.0, 39000000.0, 40000000.0, 41000000.0, 42000000.0, 43000000.0, 44000000.0, 45000000.0, 46000000.0, 47000000.0, 48000000.0, 49000000.0, 50000000.0, 51000000.0, 52000000.0, 53000000.0, 54000000.0, 55000000.0, 56000000.0, 57000000.0, 58000000.0, 59000000.0, 60000000.0, 61000000.0, 62000000.0, 63000000.0, 64000000.0, 65000000.0, 66000000.0, 67000000.0, 68000000.0, 69000000.0, 70000000.0, 71000000.0, 72000000.0, 73000000.0, 74000000.0, 75000000.0, 76000000.0, 77000000.0, 78000000.0, 79000000.0, 80000000.0, 81000000.0, 82000000.0, 83000000.0, 84000000.0, 85000000.0, 86000000.0, 87000000.0, 88000000.0, 89000000.0, 90000000.0, 91000000.0, 92000000.0, 93000000.0, 94000000.0, 95000000.0, 96000000.0, 97000000.0, 98000000.0, 99000000.0, 100000000.0, 101000000.0, 102000000.0, 103000000.0, 104000000.0, 105000000.0, 106000000.0, 107000000.0, 108000000.0, 109000000.0, 110000000.0, 111000000.0, 112000000.0, 113000000.0, 114000000.0, 115000000.0, 116000000.0, 117000000.0, 118000000.0, 119000000.0, 120000000.0, 121000000.0, 122000000.0, 123000000.0, 124000000.0, 125000000.0, 126000000.0, 127000000.0, 128000000.0, 129000000.0, 130000000.0, 131000000.0, 132000000.0, 133000000.0, 134000000.0, 135000000.0, 136000000.0, 137000000.0, 138000000.0, 139000000.0, 140000000.0, 141000000.0, 142000000.0, 143000000.0, 144000000.0, 145000000.0, 146000000.0, 147000000.0, 148000000.0, 149000000.0, 150000000.0, 151000000.0, 152000000.0, 153000000.0, 154000000.0, 155000000.0, 156000000.0, 157000000.0, 158000000.0, 159000000.0, 160000000.0, 161000000.0, 162000000.0, 163000000.0, 164000000.0, 165000000.0, 166000000.0, 167000000.0, 168000000.0, 169000000.0, 170000000.0, 171000000.0, 172000000.0, 173000000.0, 174000000.0, 175000000.0, 176000000.0, 177000000.0, 178000000.0, 179000000.0, 180000000.0, 181000000.0, 182000000.0, 183000000.0, 184000000.0, 185000000.0, 186000000.0, 187000000.0, 188000000.0, 189000000.0, 190000000.0, 191000000.0, 192000000.0, 193000000.0, 194000000.0, 195000000.0, 196000000.0, 197000000.0, 198000000.0, 199000000.0], "ys": [0.0, 0.0, 20.0, 60.0, 160.0, 336.3636363636364, 400.0, 400.0, 400.0, 400.0, 400.0, 460.0, 500.0, 520.0, 400.0, 400.0, 420.0, 400.0, 430.0, 450.0, 490.0, 430.0, 560.0, 490.0, 460.0, 430.0, 410.0, 408.3333333333333, 520.0, 440.0, 440.0, 430.0, 430.0, 430.0, 460.0, 460.0, 460.0, 490.0, 450.0, 460.0, 430.0, 440.0, 390.0, 450.0, 400.0, 400.0, 400.0, 440.0, 490.0, 430.0, 500.0, 500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2250.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2300.0, 2500.0, 2520.0, 2300.0, 2520.0, 2520.0, 2580.0, 2560.0, 2580.0, 2540.0, 2540.0, 2550.0, 2560.0, 2520.0, 2560.0, 2520.0, 1660.0, 2270.0, 2580.0, 2320.0, 2540.0, 2540.0, 2520.0, 2250.0, 2280.0, 2250.0, 2570.0, 2540.0, 1930.0, 2520.0, 2520.0, 2520.0, 2530.0, 2550.0, 2550.0, 2540.0, 2500.0, 2500.0, 2540.0, 2530.0, 2570.0, 2540.0, 2350.0, 2550.0, 2570.0, 2540.0, 2530.0, 2580.0, 2530.0, 2590.0, 2580.0, 2590.0, 2590.0, 2600.0, 2540.0, 2560.0, 2590.0, 
2570.0, 2550.0, 2140.0, 1560.0, 2590.0, 2290.0, 1930.0, 2550.0, 2550.0, 2280.0, 2520.0, 2530.0, 2530.0, 2520.0, 2530.0, 2570.0, 2560.0, 2590.0, 2580.0, 2520.0, 2550.0, 2590.0, 2570.0, 2560.0, 2590.0, 2560.0, 2120.0, 2390.0, 2600.0, 2310.0, 2500.0, 2500.0, 1720.0, 2500.0, 2510.0, 2500.0, 2300.0, 2580.0, 2570.0, 2570.0, 2580.0, 2570.0, 2560.0, 2570.0, 2520.0, 2500.0, 2530.0, 2560.0, 2530.0, 2580.0, 2590.0, 2600.0, 2540.0, 2590.0, 2600.0, 2560.0, 2550.0, 2590.0, 2150.0, 2540.0, 2580.0, 2530.0, 2590.0, 2600.0, 2600.0, 2590.0, 2580.0, 2570.0, 2500.0, 2530.0, 2580.0, 2130.0]}, {"task": "atari_montezuma_revenge", "method": "dreamerv2", "seed": "1", "xs": [1000000.0, 2000000.0, 3000000.0, 4000000.0, 5000000.0, 6000000.0, 7000000.0, 8000000.0, 9000000.0, 10000000.0, 11000000.0, 12000000.0, 13000000.0, 14000000.0, 15000000.0, 16000000.0, 17000000.0, 18000000.0, 19000000.0, 20000000.0, 21000000.0, 22000000.0, 23000000.0, 24000000.0, 25000000.0, 26000000.0, 27000000.0, 28000000.0, 29000000.0, 30000000.0, 31000000.0, 32000000.0, 33000000.0, 34000000.0, 35000000.0, 36000000.0, 37000000.0, 38000000.0, 39000000.0, 40000000.0, 41000000.0, 42000000.0, 43000000.0, 44000000.0, 45000000.0, 46000000.0, 47000000.0, 48000000.0, 49000000.0, 50000000.0, 51000000.0, 52000000.0, 53000000.0, 54000000.0, 55000000.0, 56000000.0, 57000000.0, 58000000.0, 59000000.0, 60000000.0, 61000000.0, 62000000.0, 63000000.0, 64000000.0, 65000000.0, 66000000.0, 67000000.0, 68000000.0, 69000000.0, 70000000.0, 71000000.0, 72000000.0, 73000000.0, 74000000.0, 75000000.0, 76000000.0, 77000000.0, 78000000.0, 79000000.0, 80000000.0, 81000000.0, 82000000.0, 83000000.0, 84000000.0, 85000000.0, 86000000.0, 87000000.0, 88000000.0, 89000000.0, 90000000.0, 91000000.0, 92000000.0, 93000000.0, 94000000.0, 95000000.0, 96000000.0, 97000000.0, 98000000.0, 99000000.0, 100000000.0, 101000000.0, 102000000.0, 103000000.0, 104000000.0, 105000000.0, 106000000.0, 107000000.0, 108000000.0, 109000000.0, 110000000.0, 111000000.0, 112000000.0, 113000000.0, 114000000.0, 115000000.0, 116000000.0, 117000000.0, 118000000.0, 119000000.0, 120000000.0, 121000000.0, 122000000.0, 123000000.0, 124000000.0, 125000000.0, 126000000.0, 127000000.0, 128000000.0, 129000000.0, 130000000.0, 131000000.0, 132000000.0, 133000000.0, 134000000.0, 135000000.0, 136000000.0, 137000000.0, 138000000.0, 139000000.0, 140000000.0, 141000000.0, 142000000.0, 143000000.0, 144000000.0, 145000000.0, 146000000.0, 147000000.0, 148000000.0, 149000000.0, 150000000.0, 151000000.0, 152000000.0, 153000000.0, 154000000.0, 155000000.0, 156000000.0, 157000000.0, 158000000.0, 159000000.0, 160000000.0, 161000000.0, 162000000.0, 163000000.0, 164000000.0, 165000000.0, 166000000.0, 167000000.0, 168000000.0, 169000000.0, 170000000.0, 171000000.0, 172000000.0, 173000000.0, 174000000.0, 175000000.0, 176000000.0, 177000000.0, 178000000.0, 179000000.0, 180000000.0, 181000000.0, 182000000.0, 183000000.0, 184000000.0, 185000000.0, 186000000.0, 187000000.0, 188000000.0, 189000000.0, 190000000.0, 191000000.0, 192000000.0, 193000000.0, 194000000.0, 195000000.0, 196000000.0, 197000000.0, 198000000.0, 199000000.0], "ys": [0.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 40.0, 90.0, 160.0, 180.0, 130.0, 260.0, 400.0, 400.0, 440.0, 420.0, 430.0, 400.0, 400.0, 450.0, 470.0, 450.0, 470.0, 430.0, 400.0, 400.0, 400.0, 420.0, 410.0, 440.0, 420.0, 410.0, 430.0, 410.0, 460.0, 500.0, 500.0, 460.0, 450.0, 440.0, 470.0, 390.0, 500.0, 500.0, 500.0, 2100.0, 1870.0, 460.0, 
1080.0, 2300.0, 1941.6666666666667, 2080.0, 2500.0, 2500.0, 2290.0, 2560.0, 2380.0, 2590.0, 2560.0, 2530.0, 2600.0, 2580.0, 2560.0, 2540.0, 2500.0, 2530.0, 2560.0, 2590.0, 1970.0, 2580.0, 2560.0, 2560.0, 2600.0, 2560.0, 2600.0, 2570.0, 2570.0, 2600.0, 2580.0, 2530.0, 2560.0, 2540.0, 2560.0, 2560.0, 2560.0, 2550.0, 2560.0, 2560.0, 2580.0, 2570.0, 2580.0, 2590.0, 2580.0, 2600.0, 2590.0, 2600.0, 2600.0, 2580.0, 2600.0, 2570.0, 2520.0, 2550.0, 2600.0, 2580.0, 2600.0, 2600.0, 2590.0, 2600.0, 2600.0, 2600.0, 2600.0, 2570.0, 2140.0, 2500.0, 2580.0, 2590.0, 2600.0, 2590.0, 2540.0, 2500.0, 2560.0, 2600.0, 2590.0, 2580.0, 2520.0, 2510.0, 2520.0, 2590.0, 2290.0, 2510.0, 2500.0, 2520.0, 2510.0, 2500.0, 2540.0, 2570.0, 2600.0, 2600.0, 2600.0, 2500.0, 2600.0, 2600.0, 2590.0, 2600.0, 2600.0, 2100.0, 2530.0, 2570.0, 2530.0, 2310.0, 2590.0, 2600.0, 2590.0, 2540.0, 2600.0, 2580.0, 2520.0, 2580.0, 2510.0, 2550.0, 2500.0, 2510.0, 2510.0, 2360.0, 2590.0, 2600.0, 1810.0, 2380.0, 2600.0, 2530.0, 2540.0, 2600.0, 2590.0, 2540.0, 2590.0, 2600.0, 2560.0, 2530.0, 2510.0, 2340.0, 2500.0, 2520.0, 2520.0]}, {"task": "atari_montezuma_revenge", "method": "dreamerv2", "seed": "2", "xs": [1000000.0, 2000000.0, 3000000.0, 4000000.0, 5000000.0, 6000000.0, 7000000.0, 8000000.0, 9000000.0, 10000000.0, 11000000.0, 12000000.0, 13000000.0, 14000000.0, 15000000.0, 16000000.0, 17000000.0, 18000000.0, 19000000.0, 20000000.0, 21000000.0, 22000000.0, 23000000.0, 24000000.0, 25000000.0, 26000000.0, 27000000.0, 28000000.0, 29000000.0, 30000000.0, 31000000.0, 32000000.0, 33000000.0, 34000000.0, 35000000.0, 36000000.0, 37000000.0, 38000000.0, 39000000.0, 40000000.0, 41000000.0, 42000000.0, 43000000.0, 44000000.0, 45000000.0, 46000000.0, 47000000.0, 48000000.0, 49000000.0, 50000000.0, 51000000.0, 52000000.0, 53000000.0, 54000000.0, 55000000.0, 56000000.0, 57000000.0, 58000000.0, 59000000.0, 60000000.0, 61000000.0, 62000000.0, 63000000.0, 64000000.0, 65000000.0, 66000000.0, 67000000.0, 68000000.0, 69000000.0, 70000000.0, 71000000.0, 72000000.0, 73000000.0, 74000000.0, 75000000.0, 76000000.0, 77000000.0, 78000000.0, 79000000.0, 80000000.0, 81000000.0, 82000000.0, 83000000.0, 84000000.0, 85000000.0, 86000000.0, 87000000.0, 88000000.0, 89000000.0, 90000000.0, 91000000.0, 92000000.0, 93000000.0, 94000000.0, 95000000.0, 96000000.0, 97000000.0, 98000000.0, 99000000.0, 100000000.0, 101000000.0, 102000000.0, 103000000.0, 104000000.0, 105000000.0, 106000000.0, 107000000.0, 108000000.0, 109000000.0, 110000000.0, 111000000.0, 112000000.0, 113000000.0, 114000000.0, 115000000.0, 116000000.0, 117000000.0, 118000000.0, 119000000.0, 120000000.0, 121000000.0, 122000000.0, 123000000.0, 124000000.0, 125000000.0, 126000000.0, 127000000.0, 128000000.0, 129000000.0, 130000000.0, 131000000.0, 132000000.0, 133000000.0, 134000000.0, 135000000.0, 136000000.0, 137000000.0, 138000000.0, 139000000.0, 140000000.0, 141000000.0, 142000000.0, 143000000.0, 144000000.0, 145000000.0, 146000000.0, 147000000.0, 148000000.0, 149000000.0, 150000000.0, 151000000.0, 152000000.0, 153000000.0, 154000000.0, 155000000.0, 156000000.0, 157000000.0, 158000000.0, 159000000.0, 160000000.0, 161000000.0, 162000000.0, 163000000.0, 164000000.0, 165000000.0, 166000000.0, 167000000.0, 168000000.0, 169000000.0, 170000000.0, 171000000.0, 172000000.0, 173000000.0, 174000000.0, 175000000.0, 176000000.0, 177000000.0, 178000000.0, 179000000.0, 180000000.0, 181000000.0, 182000000.0, 183000000.0, 184000000.0, 185000000.0, 186000000.0, 187000000.0, 188000000.0, 189000000.0, 190000000.0, 191000000.0, 
192000000.0, 193000000.0, 194000000.0, 195000000.0, 196000000.0, 197000000.0, 198000000.0, 199000000.0], "ys": [0.0, 0.0, 0.0, 20.0, 120.0, 310.0, 400.0, 400.0, 400.0, 400.0, 400.0, 400.0, 400.0, 430.0, 400.0, 490.0, 460.0, 470.0, 460.0, 430.0, 430.0, 490.0, 420.0, 470.0, 470.0, 440.0, 440.0, 430.0, 360.0, 400.0, 400.0, 450.0, 420.0, 430.0, 480.0, 480.0, 500.0, 480.0, 490.0, 450.0, 480.0, 500.0, 490.0, 460.0, 470.0, 440.0, 450.0, 410.0, 410.0, 410.0, 400.0, 420.0, 410.0, 400.0, 430.0, 410.0, 460.0, 430.0, 460.0, 450.0, 410.0, 410.0, 400.0, 410.0, 420.0, 430.0, 400.0, 400.0, 400.0, 366.6666666666667, 410.0, 400.0, 410.0, 430.0, 410.0, 440.0, 420.0, 400.0, 440.0, 430.0, 430.0, 440.0, 450.0, 500.0, 500.0, 500.0, 500.0, 490.0, 490.0, 500.0, 650.0, 1280.0, 2300.0, 1700.0, 2300.0, 2320.0, 1920.0, 2360.0, 2310.0, 2310.0, 2510.0, 2500.0, 2500.0, 2520.0, 2500.0, 2500.0, 2530.0, 2500.0, 2500.0, 2530.0, 2550.0, 2540.0, 2540.0, 2530.0, 2500.0, 2550.0, 2550.0, 2550.0, 2530.0, 2540.0, 2530.0, 2530.0, 2560.0, 2560.0, 2500.0, 2550.0, 2550.0, 2570.0, 2570.0, 2500.0, 2550.0, 2320.0, 2570.0, 2580.0, 2530.0, 2550.0, 2530.0, 2540.0, 2500.0, 2130.0, 2510.0, 2510.0, 2500.0, 2500.0, 2520.0, 2500.0, 2510.0, 2500.0, 2510.0, 2570.0, 2520.0, 2530.0, 2520.0, 2580.0, 2580.0, 2590.0, 2540.0, 2500.0, 2520.0, 2560.0, 2530.0, 2570.0, 2530.0, 1940.0, 2590.0, 2560.0, 2520.0, 2510.0, 2510.0, 2540.0, 2510.0, 2500.0, 2520.0, 2120.0, 2120.0, 2560.0, 2540.0, 2550.0, 2500.0, 2500.0, 2540.0, 2520.0, 2520.0, 2510.0, 2100.0, 2500.0, 2500.0, 2510.0, 2510.0, 2520.0, 2520.0, 2500.0, 2500.0, 2500.0, 2530.0, 2520.0, 2500.0, 2510.0, 1900.0]}, {"task": "atari_montezuma_revenge", "method": "dreamerv2", "seed": "3", "xs": [1000000.0, 2000000.0, 3000000.0, 4000000.0, 5000000.0, 6000000.0, 7000000.0, 8000000.0, 9000000.0, 10000000.0, 11000000.0, 12000000.0, 13000000.0, 14000000.0, 15000000.0, 16000000.0, 17000000.0, 18000000.0, 19000000.0, 20000000.0, 21000000.0, 22000000.0, 23000000.0, 24000000.0, 25000000.0, 26000000.0, 27000000.0, 28000000.0, 29000000.0, 30000000.0, 31000000.0, 32000000.0, 33000000.0, 34000000.0, 35000000.0, 36000000.0, 37000000.0, 38000000.0, 39000000.0, 40000000.0, 41000000.0, 42000000.0, 43000000.0, 44000000.0, 45000000.0, 46000000.0, 47000000.0, 48000000.0, 49000000.0, 50000000.0, 51000000.0, 52000000.0, 53000000.0, 54000000.0, 55000000.0, 56000000.0, 57000000.0, 58000000.0, 59000000.0, 60000000.0, 61000000.0, 62000000.0, 63000000.0, 64000000.0, 65000000.0, 66000000.0, 67000000.0, 68000000.0, 69000000.0, 70000000.0, 71000000.0, 72000000.0, 73000000.0, 74000000.0, 75000000.0, 76000000.0, 77000000.0, 78000000.0, 79000000.0, 80000000.0, 81000000.0, 82000000.0, 83000000.0, 84000000.0, 85000000.0, 86000000.0, 87000000.0, 88000000.0, 89000000.0, 90000000.0, 91000000.0, 92000000.0, 93000000.0, 94000000.0, 95000000.0, 96000000.0, 97000000.0, 98000000.0, 99000000.0, 100000000.0, 101000000.0, 102000000.0, 103000000.0, 104000000.0, 105000000.0, 106000000.0, 107000000.0, 108000000.0, 109000000.0, 110000000.0, 111000000.0, 112000000.0, 113000000.0, 114000000.0, 115000000.0, 116000000.0, 117000000.0, 118000000.0, 119000000.0, 120000000.0, 121000000.0, 122000000.0, 123000000.0, 124000000.0, 125000000.0, 126000000.0, 127000000.0, 128000000.0, 129000000.0, 130000000.0, 131000000.0, 132000000.0, 133000000.0, 134000000.0, 135000000.0, 136000000.0, 137000000.0, 138000000.0, 139000000.0, 140000000.0, 141000000.0, 142000000.0, 143000000.0, 144000000.0, 145000000.0, 146000000.0, 147000000.0, 148000000.0, 149000000.0, 150000000.0, 151000000.0, 
152000000.0, 153000000.0, 154000000.0, 155000000.0, 156000000.0, 157000000.0, 158000000.0, 159000000.0, 160000000.0, 161000000.0, 162000000.0, 163000000.0, 164000000.0, 165000000.0, 166000000.0, 167000000.0, 168000000.0, 169000000.0, 170000000.0, 171000000.0, 172000000.0, 173000000.0, 174000000.0, 175000000.0, 176000000.0, 177000000.0, 178000000.0, 179000000.0, 180000000.0, 181000000.0, 182000000.0, 183000000.0, 184000000.0, 185000000.0, 186000000.0, 187000000.0, 188000000.0, 189000000.0, 190000000.0, 191000000.0, 192000000.0, 193000000.0, 194000000.0, 195000000.0, 196000000.0, 197000000.0, 198000000.0, 199000000.0], "ys": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 40.0, 90.0, 100.0, 100.0, 92.3076923076923, 91.66666666666667, 150.0, 50.0, 0.0, 0.0, 0.0, 330.0, 410.0, 400.0, 400.0, 230.0, 470.0, 440.0, 490.0, 500.0, 460.0, 460.0, 410.0, 480.0, 490.0, 500.0, 480.0, 480.0, 480.0, 480.0, 480.0, 450.0, 480.0, 490.0, 500.0, 500.0, 500.0, 450.0, 500.0, 500.0, 500.0, 500.0, 500.0, 490.0, 470.0, 490.0, 470.0, 510.0, 420.0, 420.0, 490.0, 1680.0, 1030.0, 690.0, 680.0, 1230.0, 1660.0, 2110.0, 1290.0, 1680.0, 2560.0, 2600.0, 2590.0, 2530.0, 2500.0, 2500.0, 2320.0, 2540.0, 2580.0, 2540.0, 2550.0, 2550.0, 2500.0, 2510.0, 2520.0, 2560.0, 2530.0, 2550.0, 2570.0, 2600.0, 2590.0, 2580.0, 2590.0, 2590.0, 2590.0, 2510.0, 2600.0, 2590.0, 2530.0, 2520.0, 2590.0, 2580.0, 2590.0, 2590.0, 2300.0, 2500.0, 2550.0, 2560.0, 2540.0, 2600.0, 2580.0, 2590.0, 2600.0, 2600.0, 2600.0, 2510.0, 2570.0, 2170.0, 2600.0, 2570.0, 2590.0, 2550.0, 2520.0, 2570.0, 2600.0, 2560.0, 2530.0, 2530.0, 2550.0, 2600.0, 2570.0, 2560.0, 2550.0, 2590.0, 2600.0, 2580.0, 2380.0, 2600.0, 2530.0, 2530.0, 2520.0]}, {"task": "atari_montezuma_revenge", "method": "dreamerv2", "seed": "4", "xs": [1000000.0, 2000000.0, 3000000.0, 4000000.0, 5000000.0, 6000000.0, 7000000.0, 8000000.0, 9000000.0, 10000000.0, 11000000.0, 12000000.0, 13000000.0, 14000000.0, 15000000.0, 16000000.0, 17000000.0, 18000000.0, 19000000.0, 20000000.0, 21000000.0, 22000000.0, 23000000.0, 24000000.0, 25000000.0, 26000000.0, 27000000.0, 28000000.0, 29000000.0, 30000000.0, 31000000.0, 32000000.0, 33000000.0, 34000000.0, 35000000.0, 36000000.0, 37000000.0, 38000000.0, 39000000.0, 40000000.0, 41000000.0, 42000000.0, 43000000.0, 44000000.0, 45000000.0, 46000000.0, 47000000.0, 48000000.0, 49000000.0, 50000000.0, 51000000.0, 52000000.0, 53000000.0, 54000000.0, 55000000.0, 56000000.0, 57000000.0, 58000000.0, 59000000.0, 60000000.0, 61000000.0, 62000000.0, 63000000.0, 64000000.0, 65000000.0, 66000000.0, 67000000.0, 68000000.0, 69000000.0, 70000000.0, 71000000.0, 72000000.0, 73000000.0, 74000000.0, 75000000.0, 76000000.0, 77000000.0, 78000000.0, 79000000.0, 80000000.0, 81000000.0, 82000000.0, 83000000.0, 84000000.0, 85000000.0, 86000000.0, 87000000.0, 88000000.0, 89000000.0, 90000000.0, 91000000.0, 92000000.0, 93000000.0, 94000000.0, 95000000.0, 96000000.0, 97000000.0, 98000000.0, 99000000.0, 100000000.0, 101000000.0, 102000000.0, 103000000.0, 104000000.0, 105000000.0, 106000000.0, 107000000.0, 108000000.0, 109000000.0, 110000000.0, 111000000.0, 112000000.0, 113000000.0, 114000000.0, 115000000.0, 116000000.0, 117000000.0, 118000000.0, 119000000.0, 120000000.0, 121000000.0, 122000000.0, 123000000.0, 
124000000.0, 125000000.0, 126000000.0, 127000000.0, 128000000.0, 129000000.0, 130000000.0, 131000000.0, 132000000.0, 133000000.0, 134000000.0, 135000000.0, 136000000.0, 137000000.0, 138000000.0, 139000000.0, 140000000.0, 141000000.0, 142000000.0, 143000000.0, 144000000.0, 145000000.0, 146000000.0, 147000000.0, 148000000.0, 149000000.0, 150000000.0, 151000000.0, 152000000.0, 153000000.0, 154000000.0, 155000000.0, 156000000.0, 157000000.0, 158000000.0, 159000000.0, 160000000.0, 161000000.0, 162000000.0, 163000000.0, 164000000.0, 165000000.0, 166000000.0, 167000000.0, 168000000.0, 169000000.0, 170000000.0, 171000000.0, 172000000.0, 173000000.0, 174000000.0, 175000000.0, 176000000.0, 177000000.0, 178000000.0, 179000000.0, 180000000.0, 181000000.0, 182000000.0, 183000000.0, 184000000.0, 185000000.0, 186000000.0, 187000000.0, 188000000.0, 189000000.0, 190000000.0, 191000000.0, 192000000.0, 193000000.0, 194000000.0, 195000000.0, 196000000.0, 197000000.0, 198000000.0, 199000000.0], "ys": [0.0, 0.0, 0.0, 100.0, 360.0, 440.0, 430.0, 400.0, 470.0, 450.0, 500.0, 500.0, 350.0, 500.0, 500.0, 490.0, 500.0, 470.0, 470.0, 500.0, 450.0, 450.0, 480.0, 480.0, 500.0, 500.0, 500.0, 500.0, 500.0, 500.0, 500.0, 900.0, 1100.0, 1050.0, 700.0, 1690.0, 2320.0, 2510.0, 2500.0, 2520.0, 2540.0, 2510.0, 2530.0, 2540.0, 2550.0, 2550.0, 2550.0, 2550.0, 2510.0, 2500.0, 2510.0, 2500.0, 2510.0, 2550.0, 2510.0, 2550.0, 2510.0, 1320.0, 2510.0, 2580.0, 2540.0, 2520.0, 2570.0, 2570.0, 2530.0, 2540.0, 2540.0, 2272.7272727272725, 2580.0, 2310.0, 2341.6666666666665, 2500.0, 2510.0, 2100.0, 2540.0, 2500.0, 2510.0, 2500.0, 2520.0, 1900.0, 2530.0, 2540.0, 2500.0, 2270.0, 2310.0, 2500.0, 2500.0, 2500.0, 2500.0, 2510.0, 2510.0, 2500.0, 2510.0, 2500.0, 2540.0, 2550.0, 1710.0, 2550.0, 2540.0, 2520.0, 2510.0, 2520.0, 2550.0, 2570.0, 1720.0, 2300.0, 2560.0, 2560.0, 2550.0, 2550.0, 2570.0, 2560.0, 2530.0, 2510.0, 2570.0, 2590.0, 2570.0, 2310.0, 2530.0, 2550.0, 2320.0, 2510.0, 2550.0, 2530.0, 2250.0, 2540.0, 2530.0, 2580.0, 2550.0, 2540.0, 2570.0, 2540.0, 2580.0, 2300.0, 2260.0, 2540.0, 2530.0, 2530.0, 2580.0, 2590.0, 2540.0, 2520.0, 2520.0, 1510.0, 2340.0, 2530.0, 2370.0, 2560.0, 2360.0, 2560.0, 2020.0, 1560.0, 2590.0, 2600.0, 2580.0, 2560.0, 2590.0, 2590.0, 2600.0, 2600.0, 2510.0, 2500.0, 2580.0, 2560.0, 2580.0, 2080.0, 2580.0, 2600.0, 2180.0, 2600.0, 2600.0, 2600.0, 2540.0, 2600.0, 2600.0, 2380.0, 2160.0, 2600.0, 2390.0, 2570.0, 2590.0, 2590.0, 2580.0, 2570.0, 2590.0, 2550.0, 2570.0, 2590.0, 2590.0, 2600.0, 2550.0, 2550.0, 2600.0, 2600.0, 2600.0, 2570.0, 2030.0, 2160.0, 2310.0]}, {"task": "atari_montezuma_revenge", "method": "dreamerv2", "seed": "5", "xs": [1000000.0, 2000000.0, 3000000.0, 4000000.0, 5000000.0, 6000000.0, 7000000.0, 8000000.0, 9000000.0, 10000000.0, 11000000.0, 12000000.0, 13000000.0, 14000000.0, 15000000.0, 16000000.0, 17000000.0, 18000000.0, 19000000.0, 20000000.0, 21000000.0, 22000000.0, 23000000.0, 24000000.0, 25000000.0, 26000000.0, 27000000.0, 28000000.0, 29000000.0, 30000000.0, 31000000.0, 32000000.0, 33000000.0, 34000000.0, 35000000.0, 36000000.0, 37000000.0, 38000000.0, 39000000.0, 40000000.0, 41000000.0, 42000000.0, 43000000.0, 44000000.0, 45000000.0, 46000000.0, 47000000.0, 48000000.0, 49000000.0, 50000000.0, 51000000.0, 52000000.0, 53000000.0, 54000000.0, 55000000.0, 56000000.0, 57000000.0, 58000000.0, 59000000.0, 60000000.0, 61000000.0, 62000000.0, 63000000.0, 64000000.0, 65000000.0, 66000000.0, 67000000.0, 68000000.0, 69000000.0, 70000000.0, 71000000.0, 72000000.0, 73000000.0, 74000000.0, 75000000.0, 
76000000.0, 77000000.0, 78000000.0, 79000000.0, 80000000.0, 81000000.0, 82000000.0, 83000000.0, 84000000.0, 85000000.0, 86000000.0, 87000000.0, 88000000.0, 89000000.0, 90000000.0, 91000000.0, 92000000.0, 93000000.0, 94000000.0, 95000000.0, 96000000.0, 97000000.0, 98000000.0, 99000000.0, 100000000.0, 101000000.0, 102000000.0, 103000000.0, 104000000.0, 105000000.0, 106000000.0, 107000000.0, 108000000.0, 109000000.0, 110000000.0, 111000000.0, 112000000.0, 113000000.0, 114000000.0, 115000000.0, 116000000.0, 117000000.0, 118000000.0, 119000000.0, 120000000.0, 121000000.0, 122000000.0, 123000000.0, 124000000.0, 125000000.0, 126000000.0, 127000000.0, 128000000.0, 129000000.0, 130000000.0, 131000000.0, 132000000.0, 133000000.0, 134000000.0, 135000000.0, 136000000.0, 137000000.0, 138000000.0, 139000000.0, 140000000.0, 141000000.0, 142000000.0, 143000000.0, 144000000.0, 145000000.0, 146000000.0, 147000000.0, 148000000.0, 149000000.0, 150000000.0, 151000000.0, 152000000.0, 153000000.0, 154000000.0, 155000000.0, 156000000.0, 157000000.0, 158000000.0, 159000000.0, 160000000.0, 161000000.0, 162000000.0, 163000000.0, 164000000.0, 165000000.0, 166000000.0, 167000000.0, 168000000.0, 169000000.0, 170000000.0, 171000000.0, 172000000.0, 173000000.0, 174000000.0, 175000000.0, 176000000.0, 177000000.0, 178000000.0, 179000000.0, 180000000.0, 181000000.0, 182000000.0, 183000000.0, 184000000.0, 185000000.0, 186000000.0, 187000000.0, 188000000.0, 189000000.0, 190000000.0, 191000000.0, 192000000.0, 193000000.0, 194000000.0, 195000000.0, 196000000.0, 197000000.0, 198000000.0, 199000000.0], "ys": [0.0, 0.0, 80.0, 370.0, 430.0, 410.0, 410.0, 420.0, 450.0, 460.0, 500.0, 900.0, 900.0, 1900.0, 900.0, 700.0, 1300.0, 2260.0, 2020.0, 2500.0, 2300.0, 2500.0, 2300.0, 2500.0, 2090.0, 2500.0, 2500.0, 2080.0, 2500.0, 2500.0, 2500.0, 2500.0, 2250.0, 2500.0, 2300.0, 2500.0, 2300.0, 1700.0, 2500.0, 2272.7272727272725, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2300.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2510.0, 2500.0, 2290.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2100.0, 2500.0, 2080.0, 2080.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2250.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2520.0, 2500.0, 2500.0, 2500.0, 2250.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 1900.0, 2500.0, 2500.0, 2500.0, 2500.0, 2510.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2510.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2100.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2090.0, 2500.0, 2500.0, 2500.0, 2500.0, 1900.0, 1900.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0, 2500.0]}]
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 | import pathlib
3 |
4 |
5 | setuptools.setup(
6 |     name='dreamerv2',
7 |     version='2.2.0',
8 |     description='Mastering Atari with Discrete World Models',
9 |     url='http://github.com/danijar/dreamerv2',
10 |     long_description=pathlib.Path('README.md').read_text(),
11 |     long_description_content_type='text/markdown',
12 |     packages=['dreamerv2', 'dreamerv2.common'],
13 |     package_data={'dreamerv2': ['configs.yaml']},
14 |     entry_points={'console_scripts': ['dreamerv2=dreamerv2.train:main']},
15 |     install_requires=[
16 |         'gym[atari]', 'atari_py', 'crafter', 'dm_control', 'ruamel.yaml',
17 |         'tensorflow', 'tensorflow_probability'],
18 |     classifiers=[
19 |         'Intended Audience :: Science/Research',
20 |         'License :: OSI Approved :: MIT License',
21 |         'Programming Language :: Python :: 3',
22 |         'Topic :: Games/Entertainment',
23 |         'Topic :: Scientific/Engineering :: Artificial Intelligence',
24 |     ],
25 | )
26 |
--------------------------------------------------------------------------------