├── .gitignore
├── AUTHORS
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── dreamer
├── __init__.py
├── control
│ ├── __init__.py
│ ├── batch_env.py
│ ├── dummy_env.py
│ ├── mpc_agent.py
│ ├── planning.py
│ ├── random_episodes.py
│ ├── simulate.py
│ ├── temporal_difference.py
│ └── wrappers.py
├── distributions
│ ├── __init__.py
│ ├── one_hot.py
│ └── tanh_normal.py
├── models
│ ├── __init__.py
│ ├── base.py
│ └── rssm.py
├── networks
│ ├── __init__.py
│ ├── basic.py
│ ├── conv.py
│ └── proprio.py
├── scripts
│ ├── __init__.py
│ ├── configs.py
│ ├── fileutil_fetch.py
│ ├── objectives.py
│ ├── sync.py
│ ├── tasks.py
│ └── train.py
├── tools
│ ├── __init__.py
│ ├── attr_dict.py
│ ├── bind.py
│ ├── chunk_sequence.py
│ ├── copy_weights.py
│ ├── count_dataset.py
│ ├── count_weights.py
│ ├── custom_optimizer.py
│ ├── filter_variables_lib.py
│ ├── gif_summary.py
│ ├── image_strip_summary.py
│ ├── mask.py
│ ├── metrics.py
│ ├── nested.py
│ ├── numpy_episodes.py
│ ├── overshooting.py
│ ├── preprocess.py
│ ├── reshape_as.py
│ ├── schedule.py
│ ├── shape.py
│ ├── streaming_mean.py
│ ├── summary.py
│ ├── target_network.py
│ └── unroll.py
└── training
│ ├── __init__.py
│ ├── define_model.py
│ ├── define_summaries.py
│ ├── running.py
│ ├── trainer.py
│ └── utility.py
├── setup.py
└── tests
├── test_dreamer.py
├── test_isolate_envs.py
├── test_nested.py
├── test_overshooting.py
└── test_running.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.py[cod]
3 | *.egg-info
4 | /dist
5 | /logdir
6 | MUJOCO_LOG.TXT
7 |
--------------------------------------------------------------------------------
/AUTHORS:
--------------------------------------------------------------------------------
1 | # This is the list of PlaNet authors for copyright purposes.
2 | #
3 | # This does not necessarily list everyone who has contributed code, since in
4 | # some cases, their employer may be the copyright holder. To see the full list
5 | # of contributors, see the revision history in source control.
6 | Google LLC
7 | Danijar Hafner
8 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows
28 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
29 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dream to Control
2 |
3 | Danijar Hafner, Timothy Lillicrap, Jimmy Ba, Mohammad Norouzi
4 |
5 | **Note:** This is the original implementation. To build upon Dreamer, we
6 | recommend the newer implementation of [Dreamer in TensorFlow
7 | 2](https://github.com/danijar/dreamer). It is substantially simpler
8 | and faster while replicating the results.
9 |
10 |
11 |
12 | Implementation of Dreamer, the reinforcement learning agent introduced in
13 | [Dream to Control: Learning Behaviors by Latent Imagination][paper]. Dreamer
14 | learns long-horizon behaviors from images purely by latent imagination. For
15 | this, it backpropagates value estimates through trajectories imagined in the
16 | compact latent space of a learned world model. Dreamer solves visual control
17 | tasks using substantially fewer episodes than strong model-free agents.
18 |
19 | If you find this open source release useful, please reference in your paper:
20 |
21 | ```
22 | @article{hafner2019dreamer,
23 | title={Dream to Control: Learning Behaviors by Latent Imagination},
24 | author={Hafner, Danijar and Lillicrap, Timothy and Ba, Jimmy and Norouzi, Mohammad},
25 | journal={arXiv preprint arXiv:1912.01603},
26 | year={2019}
27 | }
28 | ```
29 |
30 | ## Method
31 |
32 | 
33 |
34 | Dreamer learns a world model from past experience that can predict into the
35 | future. It then learns action and value models in its compact latent space. The
36 | value model optimizes Bellman consistency of imagined trajectories. The action
37 | model maximizes value estimates by propagating their analytic gradients back
38 | through imagined trajectories. When interacting with the environment, it simply
39 | executes the action model.
40 |
41 | Find out more:
42 |
43 | - [Project website][website]
44 | - [PDF paper][paper]
45 |
46 | [website]: https://danijar.com/dreamer
47 | [paper]: https://arxiv.org/pdf/1912.01603.pdf
48 |
49 | ## Instructions
50 |
51 | To train an agent, install the dependencies and then run one of these commands:
52 |
53 | ```sh
54 | python3 -m dreamer.scripts.train --logdir ./logdir/debug \
55 | --params '{defaults: [dreamer, debug], tasks: [dummy]}' \
56 | --num_runs 1000 --resume_runs False
57 | ```
58 |
59 | ```sh
60 | python3 -m dreamer.scripts.train --logdir ./logdir/control \
61 | --params '{defaults: [dreamer], tasks: [walker_run]}'
62 | ```
63 |
64 | ```sh
65 | python3 -m dreamer.scripts.train --logdir ./logdir/atari \
66 | --params '{defaults: [dreamer, pcont, discrete, atari], tasks: [atari_boxing]}'
67 | ```
68 |
69 | ```sh
70 | python3 -m dreamer.scripts.train --logdir ./logdir/dmlab \
71 | --params '{defaults: [dreamer, discrete], tasks: [dmlab_collect]}'
72 | ```
73 |
74 | The available tasks are listed in `scripts/tasks.py`. The hyper parameters can
75 | be found in `scripts/configs.py`.
76 |
77 | Tips:
78 |
79 | - Add `debug` to the list of defaults to use a smaller config and reach
80 | the code you're developing more quickly.
81 | - Add the flags `--resume_runs False` and `--num_runs 1000`
82 | to automatically create unique logdirs.
83 | - To train the baseline without value function, add `value_head: False` to the
84 | params.
85 | - To train PlaNet, add `train_planner: cem, test_planner: cem,
86 | planner_objective: reward, action_head: False, value_head: False,
87 | imagination_horizon: 0` to the params.
88 |
89 | ## Dependencies
90 |
91 | The code was tested under Ubuntu 18 and uses these packages:
92 | tensorflow-gpu==1.13.1, tensorflow_probability==0.6.0, dm_control (`egl`
93 | [rendering option][rendering] recommended), gym, imageio, matplotlib,
94 | ruamel.yaml, scikit-image, scipy.
95 |
96 | [rendering]: https://github.com/deepmind/dm_control#rendering
97 |
98 | Disclaimer: This is not an official Google product.
99 |
--------------------------------------------------------------------------------
/dreamer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | from . import control
20 | from . import models
21 | from . import networks
22 | from . import scripts
23 | from . import tools
24 | from . import training
25 |
--------------------------------------------------------------------------------
/dreamer/control/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | from . import planning
20 | from .batch_env import PyBatchEnv
21 | from .batch_env import TFBatchEnv
22 | from .dummy_env import DummyEnv
23 | from .mpc_agent import MPCAgent
24 | from .random_episodes import random_episodes
25 | from .simulate import create_batch_env
26 | from .simulate import create_env
27 | from .simulate import simulate
28 | from .temporal_difference import discounted_return
29 | from .temporal_difference import fixed_step_return
30 | from .temporal_difference import lambda_return
31 |
--------------------------------------------------------------------------------
/dreamer/control/batch_env.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import gym
20 | import numpy as np
21 | import tensorflow as tf
22 |
23 |
class TFBatchEnv(object):
  """Exposes a batch of Python environments as in-graph TensorFlow ops.

  Wraps a `PyBatchEnv` (in flattened mode) and lifts its `step()` and
  `reset()` into the TensorFlow graph via `tf.py_func`, restoring static
  shape information and the dict structure on the returned tensors.
  """

  def __init__(self, envs, blocking):
    self._batch_env = PyBatchEnv(envs, blocking, flatten=True)
    spaces = self._batch_env.observation_space.spaces
    observ_keys = self._keys[:-2]  # All keys except reward and done.
    self._dtypes = [self._parse_dtype(spaces[key]) for key in observ_keys]
    self._dtypes += [tf.float32, tf.bool]  # Reward and done flag.
    self._shapes = [self._parse_shape(spaces[key]) for key in observ_keys]
    self._shapes += [(), ()]  # Reward and done flag.

  def __getattr__(self, name):
    # Forward unknown attributes to the wrapped batch environment.
    return getattr(self._batch_env, name)

  def __len__(self):
    return len(self._batch_env)

  def __getitem__(self, index):
    return self._batch_env[index]

  def step(self, action):
    """Steps the underlying environments inside the graph."""
    tensors = tf.py_func(
        self._batch_env.step, [action], self._dtypes, name='step')
    return self._process_output(tensors, len(self._batch_env))

  def reset(self, indices=None):
    """Resets the selected (by default all) environments inside the graph."""
    if indices is None:
      indices = tf.range(len(self._batch_env))
    tensors = tf.py_func(
        self._batch_env.reset, [indices], self._dtypes, name='reset')
    # Batch size is unknown statically, since any subset may be reset.
    return self._process_output(tensors, None)

  def _process_output(self, tensors, batch_size):
    """Attaches static shapes and rebuilds the keyed dict structure."""
    for tensor, shape in zip(tensors, self._shapes):
      tensor.set_shape((batch_size,) + shape)
    return dict(zip(self._keys, tensors))

  def _parse_dtype(self, space):
    """Maps a gym space to the TensorFlow dtype used for its values."""
    if isinstance(space, gym.spaces.Discrete):
      return tf.int32
    if isinstance(space, gym.spaces.Box):
      return tf.uint8 if space.low.dtype == np.uint8 else tf.float32
    raise NotImplementedError()

  def _parse_shape(self, space):
    """Maps a gym space to the per-element tensor shape."""
    if isinstance(space, gym.spaces.Discrete):
      return ()
    if isinstance(space, gym.spaces.Box):
      return space.shape
    raise NotImplementedError("Unsupported space '{}.'".format(space))
76 |
77 |
class PyBatchEnv(object):
  """Combines multiple gym-style environments into one batched environment.

  Dict observations, rewards, and done flags of the individual environments
  are stacked along a leading batch dimension. The internal key order is the
  sorted observation keys followed by 'reward' and 'done'.
  """

  def __init__(self, envs, blocking, flatten=False):
    """Validates the environments and stores the batch configuration.

    Args:
      envs: List of environments with dict observation spaces.
      blocking: If True, environments are stepped sequentially. If False,
        each environment's `step()`/`reset()` must accept `blocking=False`
        and return a callable promise that yields the result.
      flatten: If True, `step()` and `reset()` return a tuple ordered by the
        internal key order instead of a dict.

    Raises:
      ValueError: If the environments differ in observation or action space.
    """
    observ_space = envs[0].observation_space
    if not all(env.observation_space == observ_space for env in envs):
      raise ValueError('All environments must use the same observation space.')
    action_space = envs[0].action_space
    if not all(env.action_space == action_space for env in envs):
      # Bug fix: this message previously said 'observation space'.
      raise ValueError('All environments must use the same action space.')
    self._envs = envs
    self._blocking = blocking
    self._flatten = flatten
    self._keys = list(sorted(observ_space.spaces.keys())) + ['reward', 'done']

  def __len__(self):
    """Number of environments in the batch."""
    return len(self._envs)

  def __getitem__(self, index):
    """Accesses an individual environment of the batch."""
    return self._envs[index]

  def __getattr__(self, name):
    # Forward everything else to the first environment; the constructor
    # checked that all environments share the same spaces.
    return getattr(self._envs[0], name)

  def step(self, actions):
    """Steps all environments with one action each and stacks the results.

    Args:
      actions: Sequence with one action per environment.

    Returns:
      Dict (or tuple if flatten was set) of stacked observations plus
      'reward' (float32) and 'done' (bool) arrays.

    Raises:
      ValueError: If an action is invalid for its environment.
    """
    for index, (env, action) in enumerate(zip(self._envs, actions)):
      if not env.action_space.contains(action):
        message = 'Invalid action for batch index {}: {}'
        raise ValueError(message.format(index, action))
    if self._blocking:
      transitions = [
          env.step(action)
          for env, action in zip(self._envs, actions)]
    else:
      # Launch all steps first, then resolve the promises, so the
      # environments run concurrently.
      transitions = [
          env.step(action, blocking=False)
          for env, action in zip(self._envs, actions)]
      transitions = [transition() for transition in transitions]
    outputs = {key: [] for key in self._keys}
    for observ, reward, done, _ in transitions:
      for key, value in observ.items():
        outputs[key].append(np.array(value))
      outputs['reward'].append(np.array(reward, np.float32))
      # Bug fix: `np.bool` is a deprecated alias removed in NumPy 1.24;
      # `np.bool_` is the equivalent scalar type.
      outputs['done'].append(np.array(done, np.bool_))
    outputs = {key: np.stack(value) for key, value in outputs.items()}
    if self._flatten:
      outputs = tuple(outputs[key] for key in self._keys)
    return outputs

  def reset(self, indices=None):
    """Resets the selected environments; defaults to all of them.

    The returned batch reports a reward of 0.0 and done of False for every
    reset environment, so the output structure matches `step()`.
    """
    if indices is None:
      indices = range(len(self._envs))
    if self._blocking:
      observs = [self._envs[index].reset() for index in indices]
    else:
      observs = [self._envs[index].reset(blocking=False) for index in indices]
      observs = [observ() for observ in observs]
    outputs = {key: [] for key in self._keys}
    for observ in observs:
      for key, value in observ.items():
        outputs[key].append(np.array(value))
      outputs['reward'].append(np.array(0.0, np.float32))
      outputs['done'].append(np.array(False, np.bool_))
    outputs = {key: np.stack(value) for key, value in outputs.items()}
    if self._flatten:
      outputs = tuple(outputs[key] for key in self._keys)
    return outputs

  def close(self):
    """Closes all environments that support closing."""
    for env in self._envs:
      if hasattr(env, 'close'):
        env.close()
--------------------------------------------------------------------------------
/dreamer/control/dummy_env.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import gym
20 | import numpy as np
21 |
22 |
class DummyEnv(object):
  """Toy environment with random image observations and random rewards.

  Episodes last exactly 1000 steps. Useful for smoke-testing the pipeline
  without a real simulator.
  """

  def __init__(self):
    self._random = np.random.RandomState(seed=0)
    self._step = None  # Set by reset().

  @property
  def observation_space(self):
    image = gym.spaces.Box(
        np.zeros([64, 64, 3], dtype=np.float32),
        np.ones([64, 64, 3], dtype=np.float32))
    return gym.spaces.Dict({'image': image})

  @property
  def action_space(self):
    ones = np.ones([5], dtype=np.float32)
    return gym.spaces.Box(-ones, ones)

  def reset(self):
    self._step = 0
    return self.observation_space.sample()

  def step(self, action):
    obs = self.observation_space.sample()
    reward = self._random.uniform(0, 1)
    self._step += 1
    return obs, reward, self._step >= 1000, {}
54 |
--------------------------------------------------------------------------------
/dreamer/control/mpc_agent.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import numpy as np
20 | import tensorflow as tf
21 | from tensorflow_probability import distributions as tfd
22 |
23 | from dreamer.tools import nested
24 |
25 |
class MPCAgent(object):
  """Agent that selects actions by planning in the latent space of a model.

  Keeps a recurrent latent state (from `config.cell`) and the previous
  action in TensorFlow local variables so they persist across session
  calls. Each `step()` encodes the observation, updates the latent state,
  asks `config.planner` for actions, and optionally adds exploration noise.
  """

  def __init__(self, batch_env, step, is_training, should_log, config):
    """Creates persistent state variables for a batch of environments.

    Args:
      batch_env: Batched environment; its length defines the batch size and
        its action space the shape of the stored previous action.
      step: Trainer step tensor (not the environment step).
      is_training: Tensor/flag stored for use by the config's functions.
      should_log: Tensor/flag stored for use by the config's functions.
      config: Attribute dict providing `cell`, `encoder`, `planner`,
        `objective`, `preprocess_fn`, and `exploration` settings.
    """
    self._step = step  # Trainer step, not environment step.
    self._is_training = is_training
    self._should_log = should_log
    self._config = config
    self._cell = config.cell
    self._num_envs = len(batch_env)
    state = self._cell.zero_state(self._num_envs, tf.float32)
    # Mirror every tensor of the (nested) zero state as a resource-based
    # local variable so the latent state survives between session runs.
    var_like = lambda x: tf.get_local_variable(
        x.name.split(':')[0].replace('/', '_') + '_var',
        shape=x.shape,
        initializer=lambda *_, **__: tf.zeros_like(x), use_resource=True)
    self._state = nested.map(var_like, state)
    batch_action_shape = (self._num_envs,) + batch_env.action_space.shape
    self._prev_action = tf.get_local_variable(
        'prev_action_var', shape=batch_action_shape,
        initializer=lambda *_, **__: tf.zeros(batch_action_shape),
        use_resource=True)

  def reset(self, agent_indices):
    """Returns an op that zeros the latent state of the selected agents.

    NOTE(review): the previous action is reset for *all* agents via a full
    assign, not only for `agent_indices` — confirm this is intended.
    """
    state = nested.map(
        lambda tensor: tf.gather(tensor, agent_indices),
        self._state)
    # Scatter zeros (0 * gathered values) back into the selected rows.
    reset_state = nested.map(
        lambda var, val: tf.scatter_update(var, agent_indices, 0 * val),
        self._state, state, flatten=True)
    reset_prev_action = self._prev_action.assign(
        tf.zeros_like(self._prev_action))
    return tf.group(reset_prev_action, *reset_state)

  def step(self, agent_indices, observ):
    """Computes and returns the action tensor for the selected agents.

    Side effects (as graph ops the returned action depends on): updates the
    stored latent state rows for `agent_indices` and overwrites the stored
    previous action with the chosen action.
    """
    observ = self._config.preprocess_fn(observ)
    # Converts observ to a length-one sequence for the encoder.
    observ = nested.map(lambda x: x[:, None], observ)
    embedded = self._config.encoder(observ)[:, 0]
    state = nested.map(
        lambda tensor: tf.gather(tensor, agent_indices),
        self._state)
    # `+ 0` reads a copy of the variable's current value, so the assignment
    # at the end of this method cannot affect this read.
    prev_action = self._prev_action + 0
    with tf.control_dependencies([prev_action]):
      use_obs = tf.ones(tf.shape(agent_indices), tf.bool)[:, None]
      _, state = self._cell((embedded, prev_action, use_obs), state)
    action = self._config.planner(
        self._cell, self._config.objective, state,
        embedded.shape[1:].as_list(),
        prev_action.shape[1:].as_list())
    # Execute only the first action of the planned sequence.
    action = action[:, 0]
    if self._config.exploration:
      expl = self._config.exploration
      scale = tf.cast(expl.scale, tf.float32)[None]  # Batch dimension.
      if expl.schedule:
        # Anneal the exploration scale over trainer steps.
        scale *= expl.schedule(self._step)
      if expl.factors:
        # Presumably per-action-dimension scaling factors — confirm.
        scale *= np.array(expl.factors)
      if expl.type == 'additive_normal':
        # Continuous actions: Gaussian noise around the planned action.
        action = tfd.Normal(action, scale[:, None]).sample()
      elif expl.type == 'epsilon_greedy':
        # Discrete (one-hot) actions: with probability `scale`, replace the
        # planned action by a uniformly random one-hot action.
        random_action = tf.one_hot(
            tfd.Categorical(0 * action).sample(), action.shape[-1])
        switch = tf.cast(tf.less(
            tf.random.uniform((self._num_envs,)),
            scale), tf.float32)[:, None]
        action = switch * random_action + (1 - switch) * action
      else:
        raise NotImplementedError(expl.type)
    action = tf.clip_by_value(action, -1, 1)
    # Persist the action and the updated state rows before returning, so
    # the next step observes them.
    remember_action = self._prev_action.assign(action)
    remember_state = nested.map(
        lambda var, val: tf.scatter_update(var, agent_indices, val),
        self._state, state, flatten=True)
    with tf.control_dependencies(remember_state + (remember_action,)):
      return tf.identity(action)
100 |
--------------------------------------------------------------------------------
/dreamer/control/planning.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 |
21 | from dreamer import tools
22 |
23 |
def cross_entropy_method(
    cell, objective, state, obs_shape, action_shape, horizon, graph,
    beams=1000, topk=100, iterations=10, min_action=-1, max_action=1):
  """Plan an action sequence with the cross-entropy method (CEM).

  Iteratively samples `beams` candidate action sequences per batch element
  from a diagonal Gaussian belief, scores them by unrolling `cell` and
  applying `objective`, and re-fits the belief to the `topk` elites.

  Args:
    cell: Latent dynamics RNN cell called with (obs, action, use_obs).
    objective: Maps unrolled states to a scalar return per sequence.
    state: Nested initial latent state, batch-major.
    obs_shape: Shape of one observation embedding (no batch/time dims).
    action_shape: Shape of one action (no batch/time dims).
    horizon: Number of planned time steps.
    graph: Unused here; part of the shared planner signature.
    beams: Number of candidate action sequences per batch element.
    topk: Number of elite sequences the belief is re-fit to.
    iterations: Number of CEM refinement iterations.
    min_action: Lower bound for clipping sampled actions.
    max_action: Upper bound for clipping sampled actions.

  Returns:
    Mean action sequence of shape (batch, horizon) + action_shape.
  """
  obs_shape, action_shape = tuple(obs_shape), tuple(action_shape)
  batch = tools.shape(tools.nested.flatten(state)[0])[0]
  # Evaluate all beams in parallel by tiling the state along the batch axis.
  initial_state = tools.nested.map(lambda tensor: tf.tile(
      tensor, [beams] + [1] * (tensor.shape.ndims - 1)), state)
  extended_batch = tools.shape(tools.nested.flatten(initial_state)[0])[0]
  # Planning is open loop: the cell never observes, so observations are
  # zeros and use_obs is False for every step.
  use_obs = tf.zeros([extended_batch, horizon, 1], tf.bool)
  obs = tf.zeros((extended_batch, horizon) + obs_shape)

  def iteration(index, mean, stddev):
    # Sample action proposals from belief.
    normal = tf.random_normal((batch, beams, horizon) + action_shape)
    action = normal * stddev[:, None] + mean[:, None]
    action = tf.clip_by_value(action, min_action, max_action)
    # Evaluate proposal actions.
    action = tf.reshape(
        action, (extended_batch, horizon) + action_shape)
    (_, state), _ = tf.nn.dynamic_rnn(
        cell, (0 * obs, action, use_obs), initial_state=initial_state)
    return_ = objective(state)
    return_ = tf.reshape(return_, (batch, beams))
    # Re-fit belief to the best ones.
    _, indices = tf.nn.top_k(return_, topk, sorted=False)
    # Offset the per-batch indices into the flattened (batch * beams) axis.
    indices += tf.range(batch)[:, None] * beams
    best_actions = tf.gather(action, indices)
    mean, variance = tf.nn.moments(best_actions, 1)
    stddev = tf.sqrt(variance + 1e-6)  # Epsilon avoids a zero stddev.
    return index + 1, mean, stddev

  mean = tf.zeros((batch, horizon) + action_shape)
  stddev = tf.ones((batch, horizon) + action_shape)
  # Only the refined mean is used; the final stddev is discarded.
  _, mean, _ = tf.while_loop(
      lambda index, mean, stddev: index < iterations, iteration,
      (0, mean, stddev), back_prop=False)
  return mean
61 |
62 |
def action_head_policy(
    cell, objective, state, obs_shape, action_shape, graph, config, strategy):
  """Plan by querying the learned action head instead of searching.

  Builds a length-one plan of shape (batch, 1, action_size) by either
  sampling from or taking the mode of the policy distribution produced by
  the action head on the current latent features.
  """
  latent_features = cell.features_from_state(state)
  policy_dist = graph.heads.action(latent_features)
  if strategy not in ('sample', 'mode'):
    raise NotImplementedError(strategy)
  chosen = policy_dist.sample() if strategy == 'sample' else policy_dist.mode()
  # Insert a time axis so the result looks like a one-step plan.
  return chosen[:, None, :]
75 |
--------------------------------------------------------------------------------
/dreamer/control/random_episodes.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | from dreamer import control
20 |
21 |
def random_episodes(
    env_ctor, num_episodes, num_steps, outdir=None, isolate_envs='none'):
  """Collect episodes by acting with a uniformly random policy.

  Runs full episodes until at least `num_episodes` episodes and
  `num_steps` environment steps have been collected.

  Args:
    env_ctor: Callable constructing the environment.
    num_episodes: Minimum number of episodes to collect.
    num_steps: Minimum number of environment steps to collect.
    outdir: Directory the CollectDataset wrapper writes episodes to, or
      None to only keep them in memory.
    isolate_envs: One of 'none', 'thread', or 'process'.

  Returns:
    List of collected episodes (dicts of arrays).
  """
  # If using environment processes or threads, we should also use them here to
  # avoid loading their dependencies into the global name space. This way,
  # their imports will be isolated from the main process and later created envs
  # do not inherit them via global state but import their own copies.
  env, _ = control.create_env(env_ctor, isolate_envs)
  env = control.wrappers.CollectDataset(env, outdir)
  # Always accumulate episodes in memory. Previously this was None whenever
  # outdir was falsy, which crashed on the unconditional append below.
  episodes = []
  while num_episodes > 0 or num_steps > 0:
    policy = lambda env, obs: env.action_space.sample()
    done = False
    obs = env.reset()
    while not done:
      action = policy(env, obs)
      obs, _, done, _ = env.step(action)
    episode = env._get_episode()
    episodes.append(episode)
    num_episodes -= 1
    num_steps -= len(episode['reward'])
  try:
    env.close()
  except AttributeError:
    # Some environment wrappers do not implement close().
    pass
  return episodes
47 |
--------------------------------------------------------------------------------
/dreamer/control/simulate.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 |
21 | from dreamer.control import batch_env
22 | from dreamer.control import wrappers
23 |
24 |
def create_batch_env(env_ctor, num_envs, isolate_envs):
  """Construct `num_envs` environments and combine them into one TF env.

  All environments must share the same blocking mode, which is forwarded
  to the batch wrapper.
  """
  pairs = [create_env(env_ctor, isolate_envs) for _ in range(num_envs)]
  envs = [pair[0] for pair in pairs]
  blockings = [pair[1] for pair in pairs]
  assert all(blocking == blockings[0] for blocking in blockings)
  return batch_env.TFBatchEnv(envs, blockings[0])
31 |
32 |
def create_env(env_ctor, isolate_envs):
  """Instantiate one environment, optionally isolated from the process.

  Args:
    env_ctor: Callable constructing the environment.
    isolate_envs: One of 'none', 'thread', or 'process'.

  Returns:
    Tuple of (env, blocking) where `blocking` says whether calls into the
    environment are synchronous.
  """
  if isolate_envs == 'none':
    return env_ctor(), True
  if isolate_envs in ('thread', 'process'):
    # The async wrapper runs the environment in a worker, so calls to it
    # are non-blocking.
    return wrappers.Async(env_ctor, isolate_envs), False
  raise NotImplementedError(isolate_envs)
46 |
47 |
def simulate(agent, env, episodes=None, steps=None):
  """Unroll `agent` in the batched `env` inside the TF graph.

  Builds a while loop alternating agent and environment steps until the
  requested number of episodes and/or environment steps is reached.

  Args:
    agent: Object with `reset(indices)` and `step(indices, observ)` ops.
    env: Batched TF environment with `reset()` and `step(action)`.
    episodes: Minimum number of finished episodes, or None.
    steps: Minimum number of environment steps, or None.

  Returns:
    Tuple of (scores, lengths, transitions): per-episode returns and
    lengths, plus batch-major tensors of all collected transitions.
  """

  def pred(step, episode, *args, **kwargs):
    # Continue while any of the requested limits is still unmet.
    if episodes is None:
      return tf.less(step, steps)
    if steps is None:
      return tf.less(episode, episodes)
    return tf.logical_or(tf.less(episode, episodes), tf.less(step, steps))

  def reset(mask, scores, lengths, previous):
    # Reset masked agents and environments and zero their book keeping.
    indices = tf.where(mask)[:, 0]
    reset_agent = agent.reset(indices)
    values = env.reset(indices)
    # This would be shorter but gives an internal TensorFlow error.
    # previous = {
    #     key: tf.tensor_scatter_update(
    #         previous[key], indices[:, None], values[key])
    #     for key in previous}
    idx = tf.cast(indices[:, None], tf.int32)
    previous = {
        key: tf.where(
            mask,
            tf.scatter_nd(idx, values[key], tf.shape(previous[key])),
            previous[key])
        for key in previous}
    scores = tf.where(mask, tf.zeros_like(scores), scores)
    lengths = tf.where(mask, tf.zeros_like(lengths), lengths)
    with tf.control_dependencies([reset_agent]):
      return tf.identity(scores), tf.identity(lengths), previous

  def body(step, episode, scores, lengths, previous, ta_out, ta_sco, ta_len):
    # Reset episodes and agents if necessary.
    reset_mask = tf.cond(
        tf.equal(step, 0),
        lambda: tf.ones(len(env), tf.bool),
        lambda: previous['done'])
    reset_mask.set_shape([len(env)])
    scores, lengths, previous = tf.cond(
        tf.reduce_any(reset_mask),
        lambda: reset(reset_mask, scores, lengths, previous),
        lambda: (scores, lengths, previous))
    step_indices = tf.range(len(env))
    action = agent.step(step_indices, previous)
    values = env.step(action)
    # Update book keeping variables.
    done_indices = tf.cast(tf.where(values['done']), tf.int32)
    step += tf.shape(step_indices)[0]
    episode += tf.shape(done_indices)[0]
    scores += values['reward']
    lengths += tf.shape(step_indices)[0]
    # Write transitions, scores, and lengths to tensor arrays.
    ta_out = {
        key: array.write(array.size(), values[key])
        for key, array in ta_out.items()}
    # Scores and lengths are only recorded for environments that finished
    # an episode during this step.
    ta_sco = tf.cond(
        tf.greater(tf.shape(done_indices)[0], 0),
        lambda: ta_sco.write(
            ta_sco.size(), tf.gather(scores, done_indices)[:, 0]),
        lambda: ta_sco)
    ta_len = tf.cond(
        tf.greater(tf.shape(done_indices)[0], 0),
        lambda: ta_len.write(
            ta_len.size(), tf.gather(lengths, done_indices)[:, 0]),
        lambda: ta_len)
    return step, episode, scores, lengths, values, ta_out, ta_sco, ta_len

  initial = env.reset()
  ta_out = {
      key: tf.TensorArray(
          value.dtype, 0, True, element_shape=value.shape)
      for key, value in initial.items()}
  # A variable number of episodes can finish per step, so the score and
  # length arrays use unknown element shapes.
  ta_sco = tf.TensorArray(
      tf.float32, 0, True, element_shape=[None], infer_shape=False)
  ta_len = tf.TensorArray(
      tf.int32, 0, True, element_shape=[None], infer_shape=False)
  zero_scores = tf.zeros(len(env), tf.float32, name='scores')
  zero_lengths = tf.zeros(len(env), tf.int32, name='lengths')
  ta_out, ta_sco, ta_len = tf.while_loop(
      pred, body,
      (0, 0, zero_scores, zero_lengths, initial, ta_out, ta_sco, ta_len),
      parallel_iterations=1, back_prop=False)[-3:]
  transitions = {key: array.stack() for key, array in ta_out.items()}
  # Convert from time-major to batch-major layout.
  transitions = {
      key: tf.transpose(value, [1, 0] + list(range(2, value.shape.ndims)))
      for key, value in transitions.items()}
  scores = ta_sco.concat()
  lengths = ta_len.concat()
  return scores, lengths, transitions
136 |
--------------------------------------------------------------------------------
/dreamer/control/temporal_difference.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 |
21 |
def discounted_return(reward, pcont, bootstrap, axis, stop_gradient=True):
  """Discounted sum of future rewards along the given axis.

  Args:
    reward: Tensor of rewards.
    pcont: Discount / continuation probability; Python scalar or a tensor
      with the same shape as `reward`.
    bootstrap: Value estimate appended after the final step, or None.
    axis: Index of the time dimension to aggregate over.
    stop_gradient: Whether to block gradients through the result.

  Returns:
    Tensor of returns at every time step (same shape as `reward`).
    NOTE(review): the `pcont == 1` shortcuts below return a tensor with
    the time axis reduced instead -- confirm callers handle both shapes.
  """
  if isinstance(pcont, (float, int)):
    # Undiscounted shortcuts avoid the scan entirely.
    if pcont == 1 and bootstrap is None:
      return tf.reduce_sum(reward, axis)
    if pcont == 1:
      return tf.reduce_sum(reward, axis) + bootstrap
    pcont = pcont * tf.ones_like(reward)
  # Bring the aggregation dimension front.
  dims = list(range(reward.shape.ndims))
  dims = [axis] + dims[1:axis] + [0] + dims[axis + 1:]
  reward = tf.transpose(reward, dims)
  pcont = tf.transpose(pcont, dims)
  if bootstrap is None:
    bootstrap = tf.zeros_like(reward[-1])
  # Reverse scan accumulates return[t] = reward[t] + pcont[t] * return[t+1].
  return_ = tf.scan(
      fn=lambda agg, cur: cur[0] + cur[1] * agg,
      elems=(reward, pcont),
      initializer=bootstrap,
      back_prop=not stop_gradient,
      reverse=True)
  # The permutation swaps axes 0 and `axis`, so it is its own inverse.
  return_ = tf.transpose(return_, dims)
  if stop_gradient:
    return_ = tf.stop_gradient(return_)
  return return_
46 |
47 |
def lambda_return(
    reward, value, bootstrap, pcont, lambda_, axis, stop_gradient=True):
  """Compute the TD(lambda) return along the given axis.

  Mixes n-step returns with exponential weighting controlled by `lambda_`.

  Args:
    reward: Tensor of rewards.
    value: Tensor of state value estimates, same rank as `reward`.
    bootstrap: Value estimate after the final step, or None for zeros.
    pcont: Discount / continuation probability; scalar or tensor.
    lambda_: Mixing factor between 1-step (0) and Monte Carlo (1) returns.
    axis: Index of the time dimension to aggregate over.
    stop_gradient: Whether to block gradients through the result.

  Returns:
    Tensor of lambda-returns with the same shape as `reward`.
  """
  # Setting lambda=1 gives a discounted Monte Carlo return.
  # Setting lambda=0 gives a fixed 1-step return.
  assert reward.shape.ndims == value.shape.ndims, (reward.shape, value.shape)
  # Bring the aggregation dimension front.
  dims = list(range(reward.shape.ndims))
  dims = [axis] + dims[1:axis] + [0] + dims[axis + 1:]
  if isinstance(pcont, (int, float)):
    pcont = pcont * tf.ones_like(reward)
  reward = tf.transpose(reward, dims)
  value = tf.transpose(value, dims)
  pcont = tf.transpose(pcont, dims)
  if bootstrap is None:
    bootstrap = tf.zeros_like(value[-1])
  next_values = tf.concat([value[1:], bootstrap[None]], 0)
  # The (1 - lambda) weighted one-step target enters the recursion input.
  inputs = reward + pcont * next_values * (1 - lambda_)
  # Reverse scan: return[t] = inputs[t] + pcont[t] * lambda * return[t+1].
  return_ = tf.scan(
      fn=lambda agg, cur: cur[0] + cur[1] * lambda_ * agg,
      elems=(inputs, pcont),
      initializer=bootstrap,
      back_prop=not stop_gradient,
      reverse=True)
  # The permutation swaps axes 0 and `axis`, so it is its own inverse.
  return_ = tf.transpose(return_, dims)
  if stop_gradient:
    return_ = tf.stop_gradient(return_)
  return return_
75 |
76 |
def fixed_step_return(
    reward, value, discount, steps, axis, stop_gradient=True):
  """Compute n-step returns, bootstrapped with `value` after `steps` steps.

  Args:
    reward: Tensor of rewards.
    value: Tensor of state value estimates, or None to skip bootstrapping.
    discount: Scalar discount factor.
    steps: Number of reward steps to accumulate.
    axis: Index of the time dimension to aggregate over.
    stop_gradient: Whether to block gradients through the result.

  Returns:
    Tensor of returns, shortened by `steps` along the time dimension.
  """
  # Brings the aggregation dimension front.
  dims = list(range(reward.shape.ndims))
  dims = [axis] + dims[1:axis] + [0] + dims[axis + 1:]
  reward = tf.transpose(reward, dims)
  length = tf.shape(reward)[0]
  # Iteratively adds discounted, shifted reward slices to build the n-step
  # sum without materializing all windows at once.
  _, return_ = tf.while_loop(
      cond=lambda i, p: i < steps + 1,
      body=lambda i, p: (i + 1, reward[steps - i: length - i] + discount * p),
      loop_vars=[tf.constant(1), tf.zeros_like(reward[steps:])],
      back_prop=not stop_gradient)
  if value is not None:
    return_ += discount ** steps * tf.transpose(value, dims)[steps:]
  # The permutation swaps axes 0 and `axis`, so it is its own inverse.
  return_ = tf.transpose(return_, dims)
  if stop_gradient:
    return_ = tf.stop_gradient(return_)
  return return_
95 |
--------------------------------------------------------------------------------
/dreamer/distributions/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | from .one_hot import OneHot
20 | from .tanh_normal import TanhNormal
21 |
--------------------------------------------------------------------------------
/dreamer/distributions/one_hot.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 | from tensorflow_probability import distributions as tfd
21 |
22 |
class OneHot(object):
  """Categorical distribution over one-hot encoded vectors.

  Wraps `tfd.Categorical` so that events are one-hot vectors, and offers
  two gradient estimators for samples: 'score' (a DiCE-style score
  function estimator) and 'straight' (straight-through).
  """

  def __init__(self, logits=None, probs=None, gradient='score'):
    """Either `logits` or `probs` must be provided.

    Args:
      logits: Unnormalized log probabilities of the classes.
      probs: Probabilities of the classes.
      gradient: Sample gradient estimator, 'score' or 'straight'.
    """
    self._gradient = gradient
    self._dist = tfd.Categorical(logits=logits, probs=probs)
    self._num_classes = self.mean().shape[-1].value

  @property
  def name(self):
    return 'OneHotDistribution'

  def __getattr__(self, name):
    # Delegate everything else to the wrapped categorical distribution.
    return getattr(self._dist, name)

  def prob(self, events):
    # Convert one-hot events back to class indices.
    indices = tf.argmax(events, axis=-1)
    return self._dist.prob(indices)

  def log_prob(self, events):
    indices = tf.argmax(events, axis=-1)
    return self._dist.log_prob(indices)

  def mean(self):
    # The expected one-hot vector equals the class probabilities.
    return self._dist.probs_parameter()

  def stddev(self):
    # Euclidean distance of each one-hot vector from the mean vector.
    # NOTE(review): this yields a per-class distance, not a per-dimension
    # standard deviation -- confirm downstream usage expects this.
    values = tf.one_hot(tf.range(self._num_classes), self._num_classes)
    distances = tf.reduce_sum((values - self.mean()[..., None, :]) ** 2, -1)
    return tf.sqrt(distances)

  def mode(self):
    return tf.one_hot(self._dist.mode(), self._num_classes)

  def sample(self, amount=None):
    """Draw one-hot samples; `amount` adds an optional leading sample dim."""
    amount = [amount] if amount else []
    sample = tf.one_hot(self._dist.sample(*amount), self._num_classes)
    if self._gradient == 'score':  # Implemented as DiCE.
      # exp(logp - stop_grad(logp)) == 1 in the forward pass but carries
      # the score-function gradient in the backward pass.
      logp = self.log_prob(sample)[..., None]
      sample *= tf.exp(logp - tf.stop_gradient(logp))
    elif self._gradient == 'straight':  # Gradient for all classes.
      probs = self._dist.probs_parameter()
      sample += probs - tf.stop_gradient(probs)
    else:
      raise NotImplementedError(self._gradient)
    return sample
68 |
--------------------------------------------------------------------------------
/dreamer/distributions/tanh_normal.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 | import tensorflow_probability as tfp
21 | from tensorflow_probability import distributions as tfd
22 |
23 |
class TanhNormal(object):
  """Gaussian squashed through tanh, with sampled moment estimates.

  The exact mean, stddev, mode, and entropy of a tanh-transformed normal
  are intractable, so they are estimated from `samples` Monte Carlo draws.
  """

  def __init__(self, mean, std, samples=100):
    """Create the distribution.

    Args:
      mean: Mean of the underlying normal distribution.
      std: Standard deviation of the underlying normal distribution.
      samples: Number of Monte Carlo samples for the moment estimates.
    """
    dist = tfd.Normal(mean, std)
    dist = tfd.TransformedDistribution(dist, TanhBijector())
    # Treat the last axis as a single event dimension.
    dist = tfd.Independent(dist, 1)
    self._dist = dist
    self._samples = samples

  @property
  def name(self):
    return 'TanhNormalDistribution'

  def __getattr__(self, name):
    # Delegate everything else to the wrapped distribution.
    return getattr(self._dist, name)

  def mean(self):
    # Monte Carlo estimate of the mean.
    samples = self._dist.sample(self._samples)
    return tf.reduce_mean(samples, 0)

  def stddev(self):
    # NOTE(review): this returns the sampled *variance*, not its square
    # root -- confirm whether callers expect a standard deviation here.
    samples = self._dist.sample(self._samples)
    mean = tf.reduce_mean(samples, 0, keep_dims=True)
    return tf.reduce_mean(tf.pow(samples - mean, 2), 0)

  def mode(self):
    # Approximate the mode by the sample with the highest log probability.
    samples = self._dist.sample(self._samples)
    logprobs = self._dist.log_prob(samples)
    mask = tf.one_hot(tf.argmax(logprobs, axis=0), self._samples, axis=0)
    return tf.reduce_sum(samples * mask[..., None], 0)

  def entropy(self):
    # Monte Carlo entropy estimate: -E[log p(x)].
    sample = self._dist.sample(self._samples)
    logprob = self.log_prob(sample)
    return -tf.reduce_mean(logprob, 0)
59 |
60 |
class TanhBijector(tfp.bijectors.Bijector):
  """Bijector for the hyperbolic tangent, y = tanh(x)."""

  def __init__(self, validate_args=False, name='tanh'):
    super(TanhBijector, self).__init__(
        forward_min_event_ndims=0,
        validate_args=validate_args,
        name=name)

  def _forward(self, x):
    return tf.nn.tanh(x)

  def _inverse(self, y):
    # Clip values just inside (-1, 1) so atanh stays finite as |y| -> 1;
    # values already outside [-1, 1] are passed through unchanged.
    precision = 0.99999997
    clipped = tf.where(
        tf.less_equal(tf.abs(y), 1.),
        tf.clip_by_value(y, -precision, precision), y)
    # y = tf.stop_gradient(clipped) + y - tf.stop_gradient(y)
    return tf.atanh(clipped)

  def _forward_log_det_jacobian(self, x):
    # Numerically stable form of log(1 - tanh(x)^2)
    # = 2 * (log(2) - x - softplus(-2x)).
    log2 = tf.math.log(tf.constant(2.0, dtype=x.dtype))
    return 2.0 * (log2 - x - tf.nn.softplus(-2.0 * x))
83 |
--------------------------------------------------------------------------------
/dreamer/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | from .base import Base
20 | from .rssm import RSSM
21 |
--------------------------------------------------------------------------------
/dreamer/models/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 | from tensorflow_probability import distributions as tfd
21 |
22 | from dreamer import tools
23 |
24 |
class Base(tf.nn.rnn_cell.RNNCell):
  """Base RNN cell for latent dynamics models.

  Every step computes a prior state from the previous state and action,
  and a posterior state that additionally conditions on the observation.
  The cell output is the (prior, posterior) pair; the carried state is
  the posterior.
  """

  def __init__(self, transition_tpl, posterior_tpl, reuse=None):
    """Create the cell from two variable-sharing templates.

    Args:
      transition_tpl: Template computing the prior from state and action.
      posterior_tpl: Template computing the posterior from state, action,
        and observation.
      reuse: Whether to reuse the cell's variable scope.
    """
    super(Base, self).__init__(_reuse=reuse)
    self._posterior_tpl = posterior_tpl
    self._transition_tpl = transition_tpl
    self._debug = False

  @property
  def state_size(self):
    raise NotImplementedError

  @property
  def updates(self):
    # Additional update ops of the model; none by default.
    return []

  @property
  def losses(self):
    # Additional loss terms of the model; none by default.
    return []

  @property
  def output_size(self):
    # The output is the pair of prior and posterior states.
    return (self.state_size, self.state_size)

  def zero_state(self, batch_size, dtype):
    return tools.nested.map(
        lambda size: tf.zeros([batch_size, size], dtype),
        self.state_size)

  def features_from_state(self, state):
    """Extract feature tensors for downstream heads from a state."""
    raise NotImplementedError

  def dist_from_state(self, state, mask=None):
    """Build the state distribution, optionally masked."""
    raise NotImplementedError

  def divergence_from_states(self, lhs, rhs, mask=None):
    """KL divergence between two state distributions, optionally masked."""
    lhs = self.dist_from_state(lhs, mask)
    rhs = self.dist_from_state(rhs, mask)
    divergence = tfd.kl_divergence(lhs, rhs)
    if mask is not None:
      divergence = tools.mask(divergence, mask)
    return divergence

  def call(self, inputs, prev_state):
    """Perform one filtering step.

    Args:
      inputs: Tuple (obs, prev_action, use_obs); `use_obs` selects the
        posterior (True) or the prior (False) for the whole batch.
      prev_state: Nested previous state.

    Returns:
      ((prior, posterior), posterior) per the RNNCell contract.
    """
    obs, prev_action, use_obs = inputs
    if self._debug:
      # In debug mode, verify that use_obs is constant across the batch.
      with tf.control_dependencies([tf.assert_equal(use_obs, use_obs[0, 0])]):
        use_obs = tf.identity(use_obs)
    use_obs = use_obs[0, 0]
    zero_obs = tools.nested.map(tf.zeros_like, obs)
    prior = self._transition_tpl(prev_state, prev_action, zero_obs)
    posterior = tf.cond(
        use_obs,
        lambda: self._posterior_tpl(prev_state, prev_action, obs),
        lambda: prior)
    return (prior, posterior), posterior
81 |
--------------------------------------------------------------------------------
/dreamer/models/rssm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 | from tensorflow_probability import distributions as tfd
21 |
22 | from dreamer import tools
23 | from dreamer.models import base
24 |
25 |
class RSSM(base.Base):
  """Recurrent state-space model (RSSM).

  Combines a deterministic belief carried by a GRU with a stochastic
  latent state sampled from a learned diagonal Gaussian. The transition
  predicts the next state from state and action alone; the posterior
  additionally conditions on the embedded observation.
  """

  def __init__(
      self, state_size, belief_size, embed_size,
      future_rnn=True, mean_only=False, min_stddev=0.1, activation=tf.nn.elu,
      num_layers=1):
    """Create an RSSM cell.

    Args:
      state_size: Size of the stochastic latent state.
      belief_size: Size of the deterministic GRU belief.
      embed_size: Width of the hidden dense layers.
      future_rnn: Whether the state distribution is predicted from the
        updated belief rather than the pre-GRU hidden features.
      mean_only: Use the distribution mean instead of sampling.
      min_stddev: Lower bound added to predicted standard deviations.
      activation: Activation function of the hidden dense layers.
      num_layers: Number of hidden dense layers in each network.
    """
    self._state_size = state_size
    self._belief_size = belief_size
    self._embed_size = embed_size
    self._future_rnn = future_rnn
    self._cell = tf.contrib.rnn.GRUBlockCell(self._belief_size)
    self._kwargs = dict(units=self._embed_size, activation=activation)
    self._mean_only = mean_only
    self._min_stddev = min_stddev
    self._num_layers = num_layers
    # Templates share variables across all calls of the two networks.
    super(RSSM, self).__init__(
        tf.make_template('transition', self._transition),
        tf.make_template('posterior', self._posterior))

  @property
  def state_size(self):
    return {
        'mean': self._state_size,
        'stddev': self._state_size,
        'sample': self._state_size,
        'belief': self._belief_size,
        'rnn_state': self._belief_size,
    }

  @property
  def feature_size(self):
    return self._belief_size + self._state_size

  def dist_from_state(self, state, mask=None):
    """Gaussian over the latent state; masked steps get unit stddev."""
    if mask is not None:
      # Value 1 keeps masked-out entries from degenerating the Gaussian.
      stddev = tools.mask(state['stddev'], mask, value=1)
    else:
      stddev = state['stddev']
    dist = tfd.MultivariateNormalDiag(state['mean'], stddev)
    return dist

  def features_from_state(self, state):
    """Concatenate stochastic sample and deterministic belief."""
    return tf.concat([state['sample'], state['belief']], -1)

  def divergence_from_states(self, lhs, rhs, mask=None):
    """KL divergence between two state distributions, optionally masked."""
    lhs = self.dist_from_state(lhs, mask)
    rhs = self.dist_from_state(rhs, mask)
    divergence = tfd.kl_divergence(lhs, rhs)
    if mask is not None:
      divergence = tools.mask(divergence, mask)
    return divergence

  def _transition(self, prev_state, prev_action, zero_obs):
    """Prior: predict the next state from previous state and action only."""
    hidden = tf.concat([prev_state['sample'], prev_action], -1)
    for _ in range(self._num_layers):
      hidden = tf.layers.dense(hidden, **self._kwargs)
    belief, rnn_state = self._cell(hidden, prev_state['rnn_state'])
    if self._future_rnn:
      # Predict the state distribution from the updated belief.
      hidden = belief
    for _ in range(self._num_layers):
      hidden = tf.layers.dense(hidden, **self._kwargs)
    mean = tf.layers.dense(hidden, self._state_size, None)
    stddev = tf.layers.dense(hidden, self._state_size, tf.nn.softplus)
    stddev += self._min_stddev
    if self._mean_only:
      sample = mean
    else:
      sample = tfd.MultivariateNormalDiag(mean, stddev).sample()
    return {
        'mean': mean,
        'stddev': stddev,
        'sample': sample,
        'belief': belief,
        'rnn_state': rnn_state,
    }

  def _posterior(self, prev_state, prev_action, obs):
    """Posterior: refine the prior using the embedded observation."""
    # The prior call supplies the updated belief and RNN state.
    prior = self._transition_tpl(prev_state, prev_action, tf.zeros_like(obs))
    hidden = tf.concat([prior['belief'], obs], -1)
    for _ in range(self._num_layers):
      hidden = tf.layers.dense(hidden, **self._kwargs)
    mean = tf.layers.dense(hidden, self._state_size, None)
    stddev = tf.layers.dense(hidden, self._state_size, tf.nn.softplus)
    stddev += self._min_stddev
    if self._mean_only:
      sample = mean
    else:
      sample = tfd.MultivariateNormalDiag(mean, stddev).sample()
    return {
        'mean': mean,
        'stddev': stddev,
        'sample': sample,
        'belief': prior['belief'],
        'rnn_state': prior['rnn_state'],
    }
121 |
--------------------------------------------------------------------------------
/dreamer/networks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | from . import conv
20 | from . import proprio
21 | from .basic import feed_forward
22 |
--------------------------------------------------------------------------------
/dreamer/networks/basic.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import numpy as np
20 | import tensorflow as tf
21 | from tensorflow_probability import distributions as tfd
22 |
23 | from dreamer import distributions
24 | from dreamer import tools
25 |
26 |
def feed_forward(
    features, data_shape, num_layers=2, activation=tf.nn.relu,
    mean_activation=None, stop_gradient=False, trainable=True, units=100,
    std=1.0, low=-1.0, high=1.0, dist='normal', min_std=1e-2, init_std=1.0):
  """Dense network that outputs a distribution over tensors of data_shape.

  The `dist` string selects the output distribution family; `std` may be a
  fixed float or the string 'learned' to predict it from the features.
  """
  hidden = tf.stop_gradient(features) if stop_gradient else features
  for _ in range(num_layers):
    hidden = tf.layers.dense(hidden, units, activation, trainable=trainable)
  flat_size = int(np.prod(data_shape))
  batch_dims = tools.shape(features)[:-1]
  mean = tf.layers.dense(
      hidden, flat_size, mean_activation, trainable=trainable)
  mean = tf.reshape(mean, batch_dims + data_shape)
  if std == 'learned':
    raw_std = tf.layers.dense(hidden, flat_size, None, trainable=trainable)
    # Shift so that softplus(offset) == init_std at initialization.
    offset = np.log(np.exp(init_std) - 1)
    std = tf.nn.softplus(raw_std + offset) + min_std
    std = tf.reshape(std, batch_dims + data_shape)
  if dist == 'normal':
    return tfd.Independent(tfd.Normal(mean, std), len(data_shape))
  if dist == 'deterministic':
    return tfd.Independent(tfd.Deterministic(mean), len(data_shape))
  if dist == 'binary':
    return tfd.Independent(tfd.Bernoulli(mean), len(data_shape))
  if dist == 'trunc_normal':
    # https://www.desmos.com/calculator/rnksmhtgui
    return tfd.Independent(
        tfd.TruncatedNormal(mean, std, low, high), len(data_shape))
  if dist == 'tanh_normal':
    # https://www.desmos.com/calculator/794s8kf0es
    return distributions.TanhNormal(mean, std)
  if dist == 'tanh_normal_tanh':
    # https://www.desmos.com/calculator/794s8kf0es
    mean = 5.0 * tf.tanh(mean / 5.0)
    return distributions.TanhNormal(mean, std)
  if dist == 'onehot_score':
    return distributions.OneHot(mean, gradient='score')
  if dist == 'onehot_straight':
    return distributions.OneHot(mean, gradient='straight')
  raise NotImplementedError(dist)
72 |
--------------------------------------------------------------------------------
/dreamer/networks/conv.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import numpy as np
20 | import tensorflow as tf
21 | from tensorflow_probability import distributions as tfd
22 |
23 | from dreamer import tools
24 |
25 |
def encoder(obs):
  """Convolutional encoder from image observations to flat features."""
  images = obs['image']
  conv_kwargs = dict(strides=2, activation=tf.nn.relu)
  # Merge batch and time axes so the conv stack sees a plain image batch.
  hidden = tf.reshape(images, [-1] + images.shape[2:].as_list())
  for filters, kernel in [(32, 4), (64, 4), (128, 4), (256, 4)]:
    hidden = tf.layers.conv2d(hidden, filters, kernel, **conv_kwargs)
  hidden = tf.layers.flatten(hidden)
  assert hidden.shape[1:].as_list() == [1024], hidden.shape.as_list()
  # Restore the (batch, time) leading axes.
  hidden = tf.reshape(hidden, tools.shape(images)[:2] + [
      np.prod(hidden.shape[1:].as_list())])
  return hidden
38 |
39 |
def decoder(features, data_shape, std=1.0):
  """Transposed-convolutional decoder returning a Normal over images."""
  deconv_kwargs = dict(strides=2, activation=tf.nn.relu)
  hidden = tf.layers.dense(features, 1024, None)
  # Treat the feature vector as a 1x1 spatial map and upsample from there.
  hidden = tf.reshape(hidden, [-1, 1, 1, hidden.shape[-1].value])
  for filters, kernel in [(128, 5), (64, 5), (32, 6)]:
    hidden = tf.layers.conv2d_transpose(hidden, filters, kernel, **deconv_kwargs)
  mean = tf.layers.conv2d_transpose(hidden, data_shape[-1], 6, strides=2)
  assert mean.shape[1:].as_list() == data_shape, mean.shape
  mean = tf.reshape(mean, tools.shape(features)[:-1] + data_shape)
  return tfd.Independent(tfd.Normal(mean, std), len(data_shape))
51 |
--------------------------------------------------------------------------------
/dreamer/networks/proprio.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 |
21 | from dreamer import tools
22 |
23 |
def encoder(obs, keys=None, num_layers=3, units=300, activation=tf.nn.relu):
  """MLP encoder over the non-image observation entries."""
  if not keys:
    keys = [k for k in obs.keys() if k != 'image']
  inputs = tf.concat([obs[k] for k in keys], -1)
  # Fold batch and time together for the dense layers.
  hidden = tf.reshape(inputs, [-1] + inputs.shape[2:].as_list())
  for _ in range(num_layers):
    hidden = tf.layers.dense(hidden, units, activation)
  # Restore the (batch, time) leading axes.
  return tf.reshape(
      hidden, tools.shape(inputs)[:2] + [hidden.shape[1].value])
34 |
--------------------------------------------------------------------------------
/dreamer/scripts/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | from . import configs
20 | from . import tasks
21 | from . import train
22 |
--------------------------------------------------------------------------------
/dreamer/scripts/fileutil_fetch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import argparse
16 | import pathlib
17 | import re
18 | import threading
19 |
20 | import sh
21 |
22 |
def execute_commands(commands, parallel):
  """Run pre-baked sh commands concurrently, at most `parallel` at a time.

  Each command carries a `_foo` attribute set by the caller with a
  human-readable progress label; it is printed when the command finishes.
  Returns the decoded stdout of the successful commands and the number of
  failed ones.
  """
  # Counting semaphore bounds the number of in-flight subprocesses.
  semaphore = threading.Semaphore(parallel)
  def done_fn(cmd, success, exit_code):
    # Invoked by sh when a background command completes (presumably on sh's
    # worker thread — hence the thread-safe semaphore release).
    print(cmd._foo)
    semaphore.release()
  running = []
  for command in commands:
    semaphore.acquire()  # Blocks until one of the earlier commands finishes.
    running.append(command(_bg=True, _done=done_fn))
    # Copy the label onto the running process object so done_fn (which
    # receives the running command, not the baked one) can print it.
    running[-1]._foo = command._foo
  failures = 0
  outputs = []
  for command in running:
    try:
      command.wait()
      outputs.append(command.stdout.decode('utf-8'))
    except sh.ErrorReturnCode as e:
      print(e)
      failures += 1
      print('')
  return outputs, failures
44 |
45 |
def main(args):
  """Mirror files matching `args.pattern` from indir into outdir.

  Lists indir recursively, filters by the regex, and copies the matches in
  parallel. Existing destinations are skipped unless --overwrite is set.
  """
  pattern = re.compile(args.pattern)
  filenames = sh.fileutil.ls('-R', args.indir)
  filenames = [filename.strip() for filename in filenames]
  filenames = [filename for filename in filenames if pattern.search(filename)]
  print('Found', len(filenames), 'filenames.')
  commands = []
  for index, filename in enumerate(filenames):
    relative = pathlib.Path(filename).relative_to(args.indir)
    destination = args.outdir / relative
    if not args.overwrite and destination.exists():
      continue
    destination.parent.mkdir(parents=True, exist_ok=True)
    flags = [filename, destination]
    if args.overwrite:
      flags = ['-f'] + flags
    command = sh.fileutil.cp.bake(*flags)
    # Progress label printed by execute_commands when the copy finishes.
    command._foo = f'{index + 1}/{len(filenames)} {relative}'
    commands.append(command)
  print(f'Executing {args.parallel} in parallel.')
  # Bug fix: the failure count returned by execute_commands used to be
  # discarded, so partially-failed syncs looked successful.
  _, failures = execute_commands(commands, args.parallel)
  if failures:
    print(f'{failures} of {len(commands)} copies failed.')
67 |
68 |
if __name__ == '__main__':
  def _parse_bool(text):
    # Accepts exactly 'False' or 'True'; anything else raises ValueError.
    return bool(['False', 'True'].index(text))
  parser = argparse.ArgumentParser()
  parser.add_argument('--indir', type=pathlib.Path, required=True)
  parser.add_argument('--outdir', type=pathlib.Path, required=True)
  parser.add_argument('--parallel', type=int, default=100)
  parser.add_argument(
      '--pattern', type=str,
      default='.*/(return|length)/records.jsonl$')
  parser.add_argument('--subdir', type=_parse_bool, default=True)
  parser.add_argument('--overwrite', type=_parse_bool, default=False)
  _args = parser.parse_args()
  if _args.subdir:
    _args.outdir /= _args.indir.stem
  main(_args)
85 |
--------------------------------------------------------------------------------
/dreamer/scripts/objectives.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 |
21 | from dreamer import tools
22 | from dreamer.control import temporal_difference as td
23 |
24 |
def reward(state, graph, params):
  """Sum of predicted mean rewards over the planning horizon (axis 1)."""
  features = graph.cell.features_from_state(state)
  predicted = graph.heads.reward(features).mean()
  return tf.reduce_sum(predicted, 1)
29 |
30 |
def reward_value(state, graph, params):
  """Lambda-return objective that mixes predicted rewards and values."""
  features = graph.cell.features_from_state(state)
  reward = graph.heads.reward(features).mean()
  value = graph.heads.value(features).mean()
  # Optionally ramp the value head in over training steps.
  ramp = tools.schedule.linear(
      graph.step, params.get('objective_value_ramp', 0))
  value *= ramp
  discount = params.get('planner_discount', 0.99)
  lambda_ = params.get('planner_lambda', 0.95)
  return_ = td.lambda_return(
      reward[:, :-1], value[:, :-1], value[:, -1],
      discount, lambda_, axis=1)
  return return_[:, 0]
43 |
--------------------------------------------------------------------------------
/dreamer/scripts/sync.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 |
3 | # Copyright 2019 The Dreamer Authors. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import argparse
18 | import glob
19 | import os
20 | import shutil
21 |
22 |
def find_source_files(directory):
  """Return all Python source files below `directory`, at any depth.

  Bug fix: the previous implementation passed a '**' pattern to glob.glob
  without recursive=True, so '**' behaved like '*' and files more than one
  level deep were missed; it also concatenated `directory + '*.py'` without
  a path separator. A single recursive glob fixes both. Results are
  deduplicated and sorted for deterministic iteration order.
  """
  pattern = os.path.join(directory, '**', '*.py')
  return sorted(set(glob.glob(pattern, recursive=True)))
27 |
28 |
def copy_source_tree(source_dir, target_dir):
  """Mirror the source files of source_dir into target_dir, removing strays."""
  for source in find_source_files(source_dir):
    relative = os.path.relpath(source, source_dir)
    target = os.path.join(target_dir, relative)
    os.makedirs(os.path.dirname(target), exist_ok=True)
    label = 'Override' if os.path.exists(target) else 'Add'
    print(label, os.path.relpath(target, target_dir))
    shutil.copy(source, target)
  # Delete files in the target that no longer exist in the source.
  for target in find_source_files(target_dir):
    source = os.path.join(source_dir, os.path.relpath(target, target_dir))
    if not os.path.exists(source):
      print('Remove', os.path.relpath(target, target_dir))
      os.remove(target)
43 |
44 |
def infer_headers(directory):
  """Collect the leading comment block of the first source file found.

  Returns the header lines (including interior blank lines); the header ends
  at the first non-comment, non-blank line. A file whose very first line is
  not a comment has no header.
  """
  source_files = find_source_files(directory)
  if not source_files:
    raise RuntimeError('No code files found in {}.'.format(directory))
  header = []
  with open(source_files[0], 'r') as f:
    for index, line in enumerate(f):
      if index == 0 and not line.startswith('#'):
        break
      if not line.startswith('#') and line.strip(' \n'):
        break
      header.append(line)
  return header
59 |
60 |
def add_headers(directory, header):
  """Prepend the header lines to every source file in the directory."""
  for filename in find_source_files(directory):
    with open(filename, 'r') as f:
      lines = f.readlines()
    with open(filename, 'w') as f:
      f.writelines(header)
      f.writelines(lines)
67 |
68 |
def remove_headers(directory, header):
  """Strip the header from files whose leading lines match it exactly."""
  size = len(header)
  for filename in find_source_files(directory):
    with open(filename, 'r') as f:
      lines = f.readlines()
    if lines[:size] == header:
      lines = lines[size:]
    with open(filename, 'w') as f:
      f.writelines(lines)
77 |
78 |
def main(args):
  """Synchronize the source tree into the target and reconcile headers."""
  def show(title, header):
    # Print a header surrounded by separator rules.
    print('{} {} {}\n{}{}\n'.format(
        '-' * 32, title, '-' * 32, ''.join(header), '-' * 79))
  print('Inferring headers.\n')
  source_header = infer_headers(args.source)
  show('Source header', source_header)
  target_header = infer_headers(args.target)
  show('Target header', target_header)
  print('Synchronizing directories.')
  copy_source_tree(args.source, args.target)
  if target_header and not source_header:
    print('Adding headers.')
    add_headers(args.target, target_header)
  if source_header and not target_header:
    print('Removing headers.')
    remove_headers(args.target, source_header)
  print('Done.')
96 |
97 |
if __name__ == '__main__':
  _parser = argparse.ArgumentParser()
  # expanduser as `type` resolves '~' in the paths during parsing.
  _parser.add_argument('--source', type=os.path.expanduser, required=True)
  _parser.add_argument('--target', type=os.path.expanduser, required=True)
  main(_parser.parse_args())
103 |
--------------------------------------------------------------------------------
/dreamer/scripts/train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import argparse
20 | import functools
21 | import os
22 | import pathlib
23 | import sys
24 |
25 | sys.path.append(str(pathlib.Path(__file__).resolve().parent.parent.parent))
26 |
27 | import ruamel.yaml as yaml
28 | import tensorflow as tf
29 |
30 | from dreamer import tools
31 | from dreamer import training
32 | from dreamer.scripts import configs
33 |
34 |
def process(logdir, args):
  """Set up one training run in `logdir` and yield its evaluation scores."""
  with args.params.unlocked:
    args.params.logdir = logdir
  config = configs.make_config(args.params)
  logdir = pathlib.Path(logdir)
  run_metrics = tools.Metrics(logdir / 'metrics', workers=5)
  training.utility.collect_initial_episodes(run_metrics, config)
  tf.reset_default_graph()
  dataset = tools.numpy_episodes.numpy_episodes(
      config.train_dir, config.test_dir, config.batch_shape,
      reader=config.data_reader,
      loader=config.data_loader,
      num_chunks=config.num_chunks,
      preprocess_fn=config.preprocess_fn,
      gpu_prefetch=config.gpu_prefetch)
  graph_metrics = tools.InGraphMetrics(run_metrics)
  build_graph = tools.bind(training.define_model, logdir, graph_metrics)
  yield from training.utility.train(build_graph, dataset, logdir, config)
54 |
55 |
def main(args):
  """Run the experiment, draining every score from each run."""
  process_fn = functools.partial(process, args=args)
  experiment = training.Experiment(
      args.logdir,
      process_fn=process_fn,
      num_runs=args.num_runs,
      ping_every=args.ping_every,
      resume_runs=args.resume_runs)
  for run in experiment:
    for _ in run:
      pass
66 |
67 |
if __name__ == '__main__':
  def _boolean(text):
    # Accepts exactly 'False' or 'True'; anything else raises ValueError.
    return bool(['False', 'True'].index(text))
  parser = argparse.ArgumentParser()
  parser.add_argument('--logdir', type=pathlib.Path, required=True)
  parser.add_argument('--params', default='{}')
  parser.add_argument('--num_runs', type=int, default=1)
  parser.add_argument('--ping_every', type=int, default=0)
  parser.add_argument('--resume_runs', type=_boolean, default=True)
  parser.add_argument('--dmlab_runfiles_path', default=None)
  args_, remaining = parser.parse_known_args()
  # Undo the shell-escaping conventions used when passing params strings.
  params_ = args_.params.replace('#', ',').replace('\\', '')
  args_.params = tools.AttrDict(yaml.safe_load(params_))
  if args_.dmlab_runfiles_path:
    with args_.params.unlocked:
      args_.params.dmlab_runfiles_path = args_.dmlab_runfiles_path
    assert args_.params.dmlab_runfiles_path  # Mark as accessed.
  args_.logdir = args_.logdir and os.path.expanduser(args_.logdir)
  remaining.insert(0, sys.argv[0])
  tf.app.run(lambda _: main(args_), remaining)
87 |
--------------------------------------------------------------------------------
/dreamer/tools/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | from . import nested
20 | from . import numpy_episodes
21 | from . import preprocess
22 | from . import schedule
23 | from . import summary
24 | from . import unroll
25 | from .attr_dict import AttrDict
26 | from .bind import bind
27 | from .copy_weights import soft_copy_weights
28 | from .count_dataset import count_dataset
29 | from .count_weights import count_weights
30 | from .custom_optimizer import CustomOptimizer
31 | from .filter_variables_lib import filter_variables
32 | from .gif_summary import gif_summary
33 | from .image_strip_summary import image_strip_summary
34 | from .mask import mask
35 | from .metrics import InGraphMetrics
36 | from .metrics import Metrics
37 | from .overshooting import overshooting
38 | from .reshape_as import reshape_as
39 | from .shape import shape
40 | from .streaming_mean import StreamingMean
41 | from .target_network import track_network
42 |
--------------------------------------------------------------------------------
/dreamer/tools/attr_dict.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import collections
20 | import contextlib
21 | import os
22 |
23 | import numpy as np
24 | import ruamel.yaml as yaml
25 |
26 |
class AttrDict(dict):
  """Dictionary with attribute access, locking, defaults and touch tracking.

  - Attribute access (``d.key``) is an alias for item access (``d['key']``).
  - Instances created with initial content start out locked; assignments
    then require ``unlock()`` or the ``unlocked`` context manager.
  - Every key that is read is recorded as touched; ``untouched`` lists keys
    that were stored but never read, which helps catch config typos.
  """

  def __init__(self, *args, **kwargs):
    # An empty AttrDict starts unlocked; one created with content is locked.
    unlocked = kwargs.pop('_unlocked', not (args or kwargs))
    defaults = kwargs.pop('_defaults', {})
    touched = kwargs.pop('_touched', set())
    # Private state is stored via object-level __setattr__ so it lands on
    # the instance instead of going through our own __setattr__ (which
    # writes into the dict and rejects leading underscores).
    super(AttrDict, self).__setattr__('_unlocked', True)
    super(AttrDict, self).__setattr__('_touched', set())
    super(AttrDict, self).__setattr__('_defaults', {})
    super(AttrDict, self).__init__(*args, **kwargs)
    super(AttrDict, self).__setattr__('_unlocked', unlocked)
    super(AttrDict, self).__setattr__('_defaults', defaults)
    super(AttrDict, self).__setattr__('_touched', touched)

  def __getattr__(self, name):
    try:
      return self[name]
    except KeyError:
      raise AttributeError(name)

  def __setattr__(self, name, value):
    self[name] = value

  def __getitem__(self, name):
    if name.startswith('__'):
      # Keep magic attribute lookups (e.g. from copy/pickle) out of the dict.
      raise AttributeError(name)
    self._touched.add(name)
    if name in self:
      return super(AttrDict, self).__getitem__(name)
    if name in self._defaults:
      return self._defaults[name]
    raise AttributeError(name)

  def __setitem__(self, name, value):
    if name.startswith('_'):
      # This also rejects magic '__' names; a separate '__' check that
      # previously followed this one was unreachable dead code.
      raise AttributeError('Cannot set private attribute {}'.format(name))
    if not self._unlocked:
      message = 'Use obj.unlock() before setting {}'
      raise RuntimeError(message.format(name))
    super(AttrDict, self).__setitem__(name, value)

  def __repr__(self):
    items = [
        '{}: {}'.format(key, self._format_value(value))
        for key, value in self.items()]
    return '{' + ', '.join(items) + '}'

  def get(self, key, default=None):
    """Like dict.get(), but also marks the key as touched."""
    self._touched.add(key)
    if key not in self:
      return default
    return self[key]

  @property
  def untouched(self):
    """Sorted keys that were stored but never read."""
    return sorted(set(self.keys()) - self._touched)

  @property
  @contextlib.contextmanager
  def unlocked(self):
    """Context manager that temporarily allows assignments.

    Bug fix: the dict is now locked again even if the body raises.
    """
    self.unlock()
    try:
      yield
    finally:
      self.lock()

  def lock(self):
    """Recursively forbid assignments on this dict and nested AttrDicts."""
    super(AttrDict, self).__setattr__('_unlocked', False)
    for value in self.values():
      if isinstance(value, AttrDict):
        value.lock()

  def unlock(self):
    """Recursively allow assignments on this dict and nested AttrDicts."""
    super(AttrDict, self).__setattr__('_unlocked', True)
    for value in self.values():
      if isinstance(value, AttrDict):
        value.unlock()

  def summarize(self):
    """Return a multi-line 'key: value' summary of all entries."""
    return '\n'.join(
        '{}: {}'.format(key, self._format_value(value))
        for key, value in self.items())

  def update(self, mapping):
    if not self._unlocked:
      message = 'Use obj.unlock() before updating'
      raise RuntimeError(message)
    super(AttrDict, self).update(mapping)
    return self

  def copy(self, _unlocked=False):
    # Note: defaults and touched state are intentionally not carried over.
    return type(self)(super(AttrDict, self).copy(), _unlocked=_unlocked)

  def save(self, filename):
    """Write the entries to a YAML file, creating parent directories."""
    assert str(filename).endswith('.yaml')
    directory = os.path.dirname(str(filename))
    if directory:  # Bug fix: os.makedirs('') raises for bare filenames.
      os.makedirs(directory, exist_ok=True)
    with open(filename, 'w') as f:
      yaml.dump(collections.OrderedDict(self), f)

  @classmethod
  def load(cls, filename):
    """Read an AttrDict back from a YAML file."""
    assert str(filename).endswith('.yaml')
    with open(filename, 'r') as f:
      # NOTE(review): yaml.Loader can construct arbitrary objects; only load
      # trusted config files here (safe_load rejects custom Python tags).
      return cls(yaml.load(f, Loader=yaml.Loader))

  def _format_value(self, value):
    if isinstance(value, np.ndarray):
      # Bug fix: the template used to be an empty string, so arrays were
      # rendered as '' while min/mean/max were computed and discarded.
      template = '<ndarray shape={} dtype={} min={} mean={} max={}>'
      min_ = self._format_value(value.min())
      mean = self._format_value(value.mean())
      max_ = self._format_value(value.max())
      return template.format(value.shape, value.dtype, min_, mean, max_)
    if isinstance(value, float) and 1e-3 < abs(value) < 1e6:
      return '{:.3f}'.format(value)
    if isinstance(value, float):
      return '{:4.1e}'.format(value)
    if hasattr(value, '__name__'):
      return value.__name__
    return str(value)
148 |
--------------------------------------------------------------------------------
/dreamer/tools/bind.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 |
class bind(object):
  """Partial function application, similar to functools.partial.

  Stored positional arguments are placed before call-time ones, and
  call-time keyword arguments override stored ones.
  """

  def __init__(self, fn, *args, **kwargs):
    self._fn = fn
    self._args = args
    self._kwargs = kwargs

  def __call__(self, *args, **kwargs):
    args_ = self._args + args
    kwargs_ = self._kwargs.copy()
    kwargs_.update(kwargs)
    return self._fn(*args_, **kwargs_)

  def __repr__(self):
    # Bug fix: not every callable has __name__ (e.g. a bind or
    # functools.partial instance), which used to make repr() raise
    # AttributeError instead of returning a string.
    name = getattr(self._fn, '__name__', None)
    return 'bind({})'.format(name if name is not None else repr(self._fn))

  def copy(self, *args, **kwargs):
    """Return a new bind with additional arguments baked in."""
    args_ = self._args + args
    kwargs_ = self._kwargs.copy()
    kwargs_.update(kwargs)
    return bind(self._fn, *args_, **kwargs_)
41 |
--------------------------------------------------------------------------------
/dreamer/tools/chunk_sequence.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 |
21 | from dreamer.tools import nested
22 |
23 |
def chunk_sequence(sequence, chunk_length, randomize=True, num_chunks=None):
  """Split nested tensors with a leading time axis into equal-length chunks.

  Args:
    sequence: Nested structure of tensors shaped [time, ...]. An optional
      'length' entry overrides the time dimension; it is popped from the
      (mutated) input and not chunked itself.
    chunk_length: Number of time steps per chunk.
    randomize: If True, start at a random offset so chunk boundaries vary
      between epochs.
    num_chunks: Fixed number of chunks; inferred from the length if None.

  Returns:
    Nested structure of tensors shaped [num_chunks, chunk_length, ...].
  """
  if 'length' in sequence:
    length = sequence.pop('length')
  else:
    length = tf.shape(nested.flatten(sequence)[0])[0]
  if randomize:
    if not num_chunks:
      # Reserve one chunk of slack so a random offset is always possible.
      num_chunks = tf.maximum(1, length // chunk_length - 1)
    else:
      # Adding 0 * length turns the Python int into a tensor compatible
      # with the tensor arithmetic below.
      num_chunks = num_chunks + 0 * length
    used_length = num_chunks * chunk_length
    max_offset = length - used_length
    # Uniform offset in [0, max_offset]; upper bound is exclusive.
    offset = tf.random_uniform((), 0, max_offset + 1, dtype=tf.int32)
  else:
    if num_chunks is None:
      num_chunks = length // chunk_length
    else:
      num_chunks = num_chunks + 0 * length
    used_length = num_chunks * chunk_length
    offset = 0
  # Trim to a whole number of chunks starting at the chosen offset.
  clipped = nested.map(
      lambda tensor: tensor[offset: offset + used_length],
      sequence)
  chunks = nested.map(
      lambda tensor: tf.reshape(
          tensor, [num_chunks, chunk_length] + tensor.shape[1:].as_list()),
      clipped)
  return chunks
52 |
--------------------------------------------------------------------------------
/dreamer/tools/copy_weights.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 |
21 | from dreamer.tools import filter_variables_lib
22 |
23 |
def soft_copy_weights(source_pattern, target_pattern, amount):
  """Blend source variables into target variables by a given amount.

  An amount of 1 performs a hard copy; smaller amounts perform a soft
  (exponential moving average style) update.
  """
  valid_amount = (0 < amount <= 0.5) or amount == 1
  if not valid_amount:
    message = 'Copy amount should probably be less than half or everything,'
    message += ' not {}'.format(amount)
    raise ValueError(message)
  # Sort both sides by name so corresponding variables line up.
  sources = sorted(
      filter_variables_lib.filter_variables(include=source_pattern),
      key=lambda var: var.name)
  targets = sorted(
      filter_variables_lib.filter_variables(include=target_pattern),
      key=lambda var: var.name)
  assert len(sources) == len(targets)
  updates = []
  for src, dst in zip(sources, targets):
    # A variable must never be paired with itself.
    assert src.name != dst.name
    if amount == 1.0:
      updates.append(dst.assign(src))
    else:
      updates.append(dst.assign((1 - amount) * dst + amount * src))
  return tf.group(*updates)
42 |
--------------------------------------------------------------------------------
/dreamer/tools/count_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import os
20 |
21 | import numpy as np
22 | import tensorflow as tf
23 |
24 |
def count_dataset(directory):
  """Return an op that counts the .npz episode files in a directory."""
  directory = os.path.expanduser(directory)
  if not tf.gfile.Exists(directory):
    message = "Data set directory '{}' does not exist."
    raise ValueError(message.format(directory))
  pattern = os.path.join(directory, '*.npz')
  def count_files():
    # Glob inside the py_func so the count is computed when the op runs,
    # not when the graph is built.
    num_episodes = len(tf.gfile.Glob(pattern))
    return np.array(num_episodes, dtype=np.int32)
  return tf.py_func(count_files, [], tf.int32)
37 |
--------------------------------------------------------------------------------
/dreamer/tools/count_weights.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import re
20 |
21 | import numpy as np
22 | import tensorflow as tf
23 |
24 |
def count_weights(scope=None, exclude=None, graph=None):
  """Count trainable parameters, optionally restricted by scope or regex.

  Raises:
    ValueError: If any matched variable has an undefined shape.
  """
  if scope:
    # Require a trailing slash so 'foo' does not also match 'foobar/...'.
    scope = scope if scope.endswith('/') else scope + '/'
  graph = graph or tf.get_default_graph()
  candidates = graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
  if scope:
    candidates = [var for var in candidates if var.name.startswith(scope)]
  if exclude:
    pattern = re.compile(exclude)
    candidates = [var for var in candidates if not pattern.match(var.name)]
  total = 0
  for var in candidates:
    if not var.shape.is_fully_defined():
      message = "Trainable variable '{}' has undefined shape '{}'."
      raise ValueError(message.format(var.name, var.shape))
    total += int(np.prod(var.shape.as_list()))
  return total
42 |
--------------------------------------------------------------------------------
/dreamer/tools/custom_optimizer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 |
21 | from dreamer.tools import filter_variables_lib
22 |
23 |
class CustomOptimizer(object):
  """Optimizer wrapper adding variable filtering, gradient clipping,
  conditional updates, and summary logging.

  Works with both TF1-style optimizers (compute_gradients/apply_gradients)
  and Keras optimizers (get_gradients).
  """

  def __init__(
      self, optimizer_cls, step, log, learning_rate,
      include=None, exclude=None, clipping=None, schedule=None,
      debug=False, name='custom_optimizer'):
    """Create the optimizer.

    Args:
      optimizer_cls: Optimizer class, called with learning_rate and name.
      step: Global step tensor, passed to the schedule if provided.
      log: Boolean tensor controlling whether summaries are produced.
      learning_rate: Base learning rate, scaled by the schedule if given.
      include: Regex pattern(s) selecting variables to train.
      exclude: Regex pattern(s) selecting variables to skip.
      clipping: Optional global gradient norm to clip to; None disables.
      schedule: Optional callable mapping step to a learning rate factor.
      debug: Enable numeric checks and zero-gradient assertions.
      name: Name used for scopes and the underlying optimizer.
    """
    if schedule:
      learning_rate *= schedule(step)
    self._step = step
    self._log = log
    self._learning_rate = learning_rate
    self._variables = filter_variables_lib.filter_variables(include, exclude)
    # Bug fix: float(None) raises TypeError, which broke the default
    # clipping=None. Only convert when a value was given; None (falsy)
    # disables clipping in minimize() and summarize() below.
    self._clipping = float(clipping) if clipping is not None else None
    self._debug = debug
    self._name = name
    self._optimizer = optimizer_cls(learning_rate=learning_rate, name=name)

  def maybe_minimize(self, condition, loss):
    """Minimize the loss only when condition holds; otherwise no-op.

    Returns:
      Tuple of (summary string tensor, gradient norm tensor).
    """
    with tf.name_scope('optimizer_{}'.format(self._name)):
      update_op, grad_norm = tf.cond(
          condition,
          lambda: self.minimize(loss),
          lambda: (tf.no_op(), 0.0))
      with tf.control_dependencies([update_op]):
        # Only write summaries when the update ran and logging is enabled.
        summary = tf.cond(
            tf.logical_and(condition, self._log),
            lambda: self.summarize(grad_norm), str)
      if self._debug:
        message = 'Zero gradient norm in {} optimizer.'.format(self._name)
        assertion = lambda: tf.assert_greater(grad_norm, 0.0, message=message)
        assert_op = tf.cond(condition, assertion, tf.no_op)
        with tf.control_dependencies([assert_op]):
          summary = tf.identity(summary)
      return summary, grad_norm

  def minimize(self, loss):
    """Compute, optionally clip, and apply gradients for the loss.

    Returns:
      Tuple of (apply op, global gradient norm before clipping).
    """
    if self._debug:
      loss = tf.check_numerics(loss, '{}_loss'.format(self._name))
    if hasattr(self._optimizer, 'compute_gradients'):  # TF 1.0
      gradients, variables = zip(*self._optimizer.compute_gradients(
          loss, self._variables, colocate_gradients_with_ops=True))
    else:  # Keras
      gradients = self._optimizer.get_gradients(loss, self._variables)
      variables = self._variables
    grad_norm = tf.global_norm(gradients)
    if self._clipping:
      gradients, _ = tf.clip_by_global_norm(
          gradients, self._clipping, grad_norm)
    optimize = self._optimizer.apply_gradients(zip(gradients, variables))
    return optimize, grad_norm

  def summarize(self, grad_norm):
    """Create scalar summaries of learning rate and gradient norms."""
    summaries = []
    with tf.name_scope('optimizer_{}'.format(self._name)):
      summaries.append(tf.summary.scalar('learning_rate', self._learning_rate))
      summaries.append(tf.summary.scalar('grad_norm', grad_norm))
      if self._clipping:
        clipped = tf.minimum(grad_norm, self._clipping)
        summaries.append(tf.summary.scalar('clipped_gradient_norm', clipped))
      summary = tf.summary.merge(summaries)
    return summary
87 |
--------------------------------------------------------------------------------
/dreamer/tools/filter_variables_lib.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import re
20 |
21 | import tensorflow as tf
22 |
23 |
def filter_variables(include=None, exclude=None):
  """Select global variables by include/exclude regex patterns.

  Args:
    include: Regex or list of regexes; a variable is kept when any matches
        its name. Defaults to matching everything.
    exclude: Regex or list of regexes; a variable is dropped when any
        matches its name. Defaults to excluding nothing.

  Returns:
    List of the matching variables.

  Raises:
    RuntimeError: If the graph has no variables, if an include pattern
        matches no variable, or if filtering leaves no variables.
  """
  # Normalize both arguments to sequences of patterns.
  if include is None:
    include = (r'.*',)
  if exclude is None:
    exclude = ()
  if not isinstance(include, (tuple, list)):
    include = (include,)
  if not isinstance(exclude, (tuple, list)):
    exclude = (exclude,)
  # Compile regexes.
  include = [re.compile(regex) for regex in include]
  exclude = [re.compile(regex) for regex in exclude]
  variables = tf.global_variables()
  if not variables:
    raise RuntimeError('There are no variables to filter.')
  # Fail fast for include patterns that match nothing. The (potentially
  # long) error message listing all variables is only built on failure.
  for regex in include:
    if not any(regex.match(variable.name) for variable in variables):
      message = "Regex r'{}' does not match any variables in the graph.\n"
      message += 'All variables:\n'
      message += '\n'.join('- {}'.format(var.name) for var in variables)
      raise RuntimeError(message.format(regex.pattern))
  # Keep variables matched by some include pattern and no exclude pattern.
  filtered = [
      variable for variable in variables
      if any(regex.match(variable.name) for regex in include)
      and not any(regex.match(variable.name) for regex in exclude)]
  # Check result.
  if not filtered:
    message = 'No variables left after filtering.'
    message += '\nIncludes:\n' + '\n'.join(regex.pattern for regex in include)
    message += '\nExcludes:\n' + '\n'.join(regex.pattern for regex in exclude)
    raise RuntimeError(message)
  return filtered
62 |
--------------------------------------------------------------------------------
/dreamer/tools/gif_summary.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import numpy as np
20 | import tensorflow as tf
21 | from tensorflow.python.ops import summary_op_util
22 |
23 |
def encode_gif(images, fps):
  """Encode a sequence of image arrays into GIF bytes using ffmpeg.

  Args:
    images: Sequence of uint8 numpy arrays of shape (height, width,
        channels), where channels is 1 (grayscale) or 3 (RGB).
    fps: Frame rate of the resulting animation.

  Returns:
    The encoded GIF as a byte string.

  Raises:
    IOError: If ffmpeg exits with a non-zero return code.
  """
  from subprocess import Popen, PIPE
  h, w, c = images[0].shape
  cmd = [
      'ffmpeg', '-y',
      '-f', 'rawvideo',
      '-vcodec', 'rawvideo',
      '-r', '%.02f' % fps,
      '-s', '%dx%d' % (w, h),
      '-pix_fmt', {1: 'gray', 3: 'rgb24'}[c],
      '-i', '-',
      '-filter_complex',
      '[0:v]split[x][z];[z]palettegen[y];[x]fifo[x];[x][y]paletteuse',
      '-r', '%.02f' % fps,
      '-f', 'gif',
      '-']
  proc = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE)
  for image in images:
    # tobytes() replaces ndarray.tostring(), which was deprecated and
    # removed in NumPy 2.0; both yield the raw C-order buffer.
    proc.stdin.write(image.tobytes())
  out, err = proc.communicate()
  if proc.returncode:
    err = '\n'.join([' '.join(cmd), err.decode('utf8')])
    raise IOError(err)
  del proc
  return out
49 |
50 |
def py_gif_summary(tag, images, max_outputs, fps):
  """Build a serialized `tf.Summary` proto containing GIF-encoded videos.

  Args:
    tag: Summary tag as str or bytes.
    images: Array-like of shape (batch, time, height, width, channels) with
        dtype uint8 and 1 or 3 channels.
    max_outputs: Maximum number of batch entries to encode.
    fps: Frame rate passed to the GIF encoder.

  Returns:
    The serialized summary proto as a byte string.
  """
  is_bytes = isinstance(tag, bytes)
  if is_bytes:
    tag = tag.decode('utf-8')
  images = np.asarray(images)
  if images.dtype != np.uint8:
    raise ValueError('Tensor must have dtype uint8 for gif summary.')
  if images.ndim != 5:
    raise ValueError('Tensor must be 5-D for gif summary.')
  batch_size, _, height, width, channels = images.shape
  if channels not in (1, 3):
    raise ValueError('Tensors must have 1 or 3 channels for gif summary.')
  summ = tf.Summary()
  num_outputs = min(batch_size, max_outputs)
  for i in range(num_outputs):
    image_summ = tf.Summary.Image()
    image_summ.height = height
    image_summ.width = width
    image_summ.colorspace = channels  # 1: grayscale, 3: RGB
    try:
      image_summ.encoded_image_string = encode_gif(images[i], fps)
    except (IOError, OSError) as e:
      print(
          'Unable to encode images to a gif string because either ffmpeg is '
          'not installed or ffmpeg returned an error: %s. Falling back to an '
          'image summary of the first frame in the sequence.', e)
      try:
        # Fallback: encode only the first frame as a PNG still image.
        from PIL import Image  # pylint: disable=g-import-not-at-top
        import io  # pylint: disable=g-import-not-at-top
        with io.BytesIO() as output:
          Image.fromarray(images[i][0]).save(output, 'PNG')
          image_summ.encoded_image_string = output.getvalue()
      except Exception:
        # Last resort: emit an empty image string rather than failing.
        print('Gif summaries requires ffmpeg or PIL to be installed: %s', e)
        image_summ.encoded_image_string = (
            ''.encode('utf-8') if is_bytes else '')
    if num_outputs == 1:
      summ_tag = tag
    else:
      summ_tag = '{}/{}'.format(tag, i)
    summ.value.add(tag=summ_tag, image=image_summ)
  summ_str = summ.SerializeToString()
  return summ_str
94 |
95 |
def gif_summary(name, tensor, max_outputs, fps, collections=None, family=None):
  """Create a graph-mode summary op that renders batched videos as GIFs."""
  tensor = tf.convert_to_tensor(tensor)
  # Float inputs (presumably in [0, 1]) are rescaled to uint8.
  if tensor.dtype in (tf.float32, tf.float64):
    tensor = tf.cast(255.0 * tensor, tf.uint8)
  scope_manager = summary_op_util.summary_scope(name, family, values=[tensor])
  with scope_manager as (tag, scope):
    # The actual GIF encoding happens in Python via py_gif_summary.
    val = tf.py_func(
        py_gif_summary,
        [tag, tensor, max_outputs, fps],
        tf.string,
        stateful=False,
        name=scope)
    summary_op_util.collect(val, collections, [tf.GraphKeys.SUMMARIES])
  return val
110 |
--------------------------------------------------------------------------------
/dreamer/tools/image_strip_summary.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 |
21 |
def image_strip_summary(name, images, max_length=100, max_batch=10):
  """Summarize videos as a single image of horizontally tiled frames.

  Args:
    name: Summary name.
    images: 5-D tensor, presumably (batch, time, height, width, channels).
    max_length: Maximum number of frames shown per sequence.
    max_batch: Maximum number of sequences shown.

  Returns:
    Image summary op.
  """
  if max_batch:
    images = images[:max_batch]
  if max_length:
    images = images[:, :max_length]
  # Normalize integer images to floats in [0, 1].
  if images.dtype == tf.uint8:
    images = tf.cast(images, tf.float32) / 255.0
  frames = tf.shape(images)[1]
  width = tf.shape(images)[3]
  channels = images.shape[-1].value
  # Lay frames out horizontally; sequences are stacked vertically.
  strip = tf.transpose(images, [0, 2, 1, 3, 4])
  strip = tf.reshape(strip, [1, -1, frames * width, channels])
  strip = tf.clip_by_value(strip, 0., 1.)
  return tf.summary.image(name, strip)
35 |
--------------------------------------------------------------------------------
/dreamer/tools/mask.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 |
21 |
def mask(tensor, mask=None, length=None, value=0, debug=False):
  """Replace entries of `tensor` outside the valid region with `value`.

  Exactly one of `mask` (boolean mask over leading dimensions) or `length`
  (per-sequence valid lengths) must be provided.
  """
  provided = sum(1 for option in (mask, length) if option is not None)
  if provided != 1:
    raise KeyError('Exactly one of mask and length must be provided.')
  with tf.name_scope('mask'):
    if mask is None:
      # Build a (batch, time) mask from the per-sequence lengths.
      positions = tf.range(tensor.shape[1].value)
      mask = positions[None, :] < length[:, None]
    batch_dims = mask.shape.ndims
    # Expand the mask with trailing axes until its rank matches the tensor.
    while tensor.shape.ndims > mask.shape.ndims:
      mask = mask[..., None]
    tile_shape = [1] * batch_dims + tensor.shape[batch_dims:].as_list()
    mask = tf.tile(mask, tile_shape)
    result = tf.where(mask, tensor, value * tf.ones_like(tensor))
    if debug:
      result = tf.check_numerics(result, 'masked')
    return result
38 |
--------------------------------------------------------------------------------
/dreamer/tools/metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | from concurrent import futures
20 | import collections
21 | import datetime
22 | import itertools
23 | import json
24 | import pathlib
25 | import re
26 | import uuid
27 |
28 | # import imageio
29 | import numpy as np
30 | import tensorflow as tf
31 |
32 |
class InGraphMetrics:
  """Adapter that exposes a `Metrics` instance as TensorFlow ops.

  Every method wraps the matching `Metrics` call in `tf.py_function`, so
  metrics can be recorded from inside a TensorFlow graph.
  """

  def __init__(self, metrics):
    # Underlying Python-side Metrics object that does the actual recording.
    self._metrics = metrics

  def set_tags(self, **tags):
    """Return an op that sets the given tags on the wrapped metrics."""
    # Sort by key so the tensor order below matches the key order.
    keys, values = zip(*sorted(tags.items(), key=lambda x: x[0]))
    def inner(*values):
      # Convert the incoming tensors back into plain Python values.
      parsed = []
      for index, value in enumerate(values):
        if value.dtype == tf.string:
          parsed.append(value.numpy().decode('utf-8'))
        elif value.dtype in (tf.float32, tf.float64):
          parsed.append(float(value.numpy()))
        elif value.dtype in (tf.int32, tf.int64):
          parsed.append(int(value.numpy()))
        else:
          raise NotImplementedError(value.dtype)
      tags = dict(zip(keys, parsed))
      return self._metrics.set_tags(**tags)
    # String tensors in tf.py_function are only supported on CPU.
    with tf.device('/cpu:0'):
      return tf.py_function(inner, values, [], 'set_tags')

  def reset_tags(self):
    """Return an op that clears all currently set tags."""
    return tf.py_function(
        self._metrics.reset_tags, [], [], 'reset_tags')

  def flush(self):
    """Return an op that writes buffered records via the wrapped metrics."""
    def inner():
      # Discard the return value; the op itself has no outputs.
      _ = self._metrics.flush()
    return tf.py_function(inner, [], [], 'flush')

  def add_scalar(self, name, value):
    """Return an op recording a single scalar tensor under `name`."""
    assert len(value.shape) == 0, (name, value)
    def inner(value):
      self._metrics.add_scalar(name, value.numpy())
    return tf.py_function(inner, [value], [], 'add_scalar_' + name)

  def add_scalars(self, name, value):
    """Return an op recording a rank-1 tensor of scalars under `name`."""
    assert len(value.shape) == 1, (name, value)
    def inner(value):
      self._metrics.add_scalars(name, value.numpy())
    return tf.py_function(inner, [value], [], 'add_scalars_' + name)

  def add_tensor(self, name, value):
    """Return an op recording a full tensor value under `name`."""
    def inner(value):
      self._metrics.add_tensor(name, value.numpy())
    return tf.py_function(inner, [value], [], 'add_tensor_' + name)
83 |
class Metrics:
  """File-based metrics logger.

  Records are appended per metric name as JSON lines to
  `<directory>/<name>/records.jsonl`; tensor payloads are stored in sibling
  files referenced by a 'filename' entry in the record. Flushing can
  optionally happen asynchronously on a thread pool.
  """

  def __init__(self, directory, workers=None):
    """Create the logger.

    Args:
      directory: Root directory for all metric files.
      workers: Optional thread count; enables asynchronous flushing.
    """
    assert workers is None or isinstance(workers, int)
    self._directory = pathlib.Path(directory).expanduser()
    # Buffered records per metric name, drained by flush().
    self._records = collections.defaultdict(collections.deque)
    self._created_directories = set()
    # Pending payloads keyed by their generated filename.
    self._values = {}
    self._tags = {}
    # Writers by file extension; image/video writers are disabled because
    # the imageio import is commented out at the top of this module.
    self._writers = {
        'npy': np.save,
        # 'png': imageio.imwrite,
        # 'jpg': imageio.imwrite,
        # 'bmp': imageio.imwrite,
        # 'gif': functools.partial(imageio.mimwrite, fps=30),
        # 'mp4': functools.partial(imageio.mimwrite, fps=30),
    }
    if workers:
      self._pool = futures.ThreadPoolExecutor(max_workers=workers)
    else:
      self._pool = None
    self._last_futures = []

  @property
  def names(self):
    """All metric names, both buffered in memory and found on disk."""
    names = set(self._records.keys())
    for catalogue in self._directory.glob('**/records.jsonl'):
      names.add(self._path_to_name(catalogue.parent))
    return sorted(names)

  def set_tags(self, **kwargs):
    """Set tags attached to every record added until the next reset."""
    reserved = ('value', 'name')
    if any(key in reserved for key in kwargs):
      message = "Reserved keys '{}' and cannot be used for tags"
      raise KeyError(message.format(', '.join(reserved)))
    for key, value in kwargs.items():
      # Round-trip through JSON to ensure tags are JSON-serializable.
      key = json.loads(json.dumps(key))
      value = json.loads(json.dumps(value))
      self._tags[key] = value

  def reset_tags(self):
    """Clear all currently set tags."""
    self._tags = {}

  def add_scalar(self, name, value):
    """Record a single float value under `name` with the current tags."""
    self._validate_name(name)
    record = self._tags.copy()
    record['value'] = float(value)
    self._records[name].append(record)

  def add_scalars(self, name, values):
    """Record each element of `values` as a separate scalar record."""
    self._validate_name(name)
    for value in values:
      record = self._tags.copy()
      record['value'] = float(value)
      self._records[name].append(record)

  def add_tensor(self, name, value, format='npy'):
    """Record a tensor payload, written as a .npy file on flush."""
    assert format in ('npy',)
    self._validate_name(name)
    record = self._tags.copy()
    record['filename'] = self._random_filename('npy')
    self._values[record['filename']] = value
    self._records[name].append(record)

  def add_image(self, name, value, format='png'):
    """Record an image payload; requires an image writer to be enabled."""
    assert format in ('png', 'jpg', 'bmp')
    self._validate_name(name)
    record = self._tags.copy()
    record['filename'] = self._random_filename(format)
    self._values[record['filename']] = value
    self._records[name].append(record)

  def add_video(self, name, value, format='gif'):
    """Record a video payload; requires a video writer to be enabled."""
    assert format in ('gif', 'mp4')
    self._validate_name(name)
    record = self._tags.copy()
    record['filename'] = self._random_filename(format)
    self._values[record['filename']] = value
    self._records[name].append(record)

  def flush(self, blocking=None):
    """Write buffered records and payloads to disk.

    Args:
      blocking: False requests asynchronous flushing on the worker pool;
          any other value performs the writes synchronously.

    Returns:
      List of futures when flushing asynchronously, otherwise None.

    Raises:
      ValueError: If blocking is False but no worker pool exists.
      RuntimeError: If a previous asynchronous flush failed.
    """
    if blocking is False and not self._pool:
      message = 'Create with workers argument for non-blocking flushing.'
      raise ValueError(message)
    # Surface errors from the previous asynchronous flush, if any.
    for future in self._last_futures or []:
      try:
        future.result()
      except Exception as e:
        message = 'Previous asynchronous flush failed.'
        raise RuntimeError(message) from e
    self._last_futures = None
    jobs = []
    for name in self._records:
      records = list(self._records[name])
      if not records:
        continue
      self._records[name].clear()
      filename = self._name_to_path(name) / 'records.jsonl'
      jobs += [(self._append_catalogue, filename, records)]
      for record in records:
        jobs.append((self._write_file, name, record))
    if self._pool and not blocking:
      futures = []
      for job in jobs:
        futures.append(self._pool.submit(*job))
      self._last_futures = futures.copy()
      return futures
    else:
      for job in jobs:
        job[0](*job[1:])

  def query(self, pattern=None, **tags):
    """Yield records whose name matches `pattern` and whose tags match."""
    pattern = pattern and re.compile(pattern)
    for name in self.names:
      if pattern and not pattern.search(name):
        continue
      for record in self._read_records(name):
        # All queried tags must be present in the record with equal values.
        if {k: v for k, v in record.items() if k in tags} != tags:
          continue
        record['name'] = name
        yield record

  def _validate_name(self, name):
    # Restrict names so they map cleanly onto directory paths.
    assert isinstance(name, str) and name
    if re.search(r'[^a-z0-9/_-]+', name):
      message = (
          "Invalid metric name '{}'. Names must contain only lower case "
          "letters, digits, dashes, underscores, and forward slashes.")
      raise NameError(message.format(name))

  def _name_to_path(self, name):
    # Slashes in metric names become directory separators.
    return self._directory.joinpath(*name.split('/'))

  def _path_to_name(self, path):
    return '/'.join(path.relative_to(self._directory).parts)

  def _read_records(self, name):
    # Combine in-memory records with those already written to disk.
    path = self._name_to_path(name)
    catalogue = path / 'records.jsonl'
    records = self._records[name]
    if catalogue.exists():
      records = itertools.chain(records, self._load_catalogue(catalogue))
    for record in records:
      if 'filename' in record:
        # Resolve stored filenames relative to the metric's directory.
        record['filename'] = path / record['filename']
      yield record

  def _write_file(self, name, record):
    # Write the payload associated with a record, if it has one.
    if 'filename' not in record:
      return
    format = pathlib.Path(record['filename']).suffix.lstrip('.')
    if format not in self._writers:
      raise TypeError('Trying to write unknown format {}'.format(format))
    value = self._values.pop(record['filename'])
    filename = self._name_to_path(name) / record['filename']
    self._ensure_directory(filename.parent)
    self._writers[format](filename, value)

  def _load_catalogue(self, filename):
    # Parse one JSON object per line from the catalogue file.
    rows = [json.loads(line) for line in filename.open('r')]
    message = 'Metrics files do not contain lists of mappings'
    if not isinstance(rows, list):
      raise TypeError(message)
    if not all(isinstance(row, dict) for row in rows):
      raise TypeError(message)
    return rows

  def _append_catalogue(self, filename, records):
    # Append records as JSON lines with deterministically ordered keys.
    self._ensure_directory(filename.parent)
    records = [
        collections.OrderedDict(sorted(record.items(), key=lambda x: x[0]))
        for record in records]
    content = ''.join([json.dumps(record) + '\n' for record in records])
    with filename.open('a') as f:
      f.write(content)

  def _ensure_directory(self, directory):
    if directory in self._created_directories:
      return
    # We first attempt to create the directory and afterwards add it to the set
    # of created directories. This means multiple workers could attempt to
    # create the directory at the same time, which is better than one trying to
    # create a file while another has not yet created the directory.
    directory.mkdir(parents=True, exist_ok=True)
    self._created_directories.add(directory)

  def _random_filename(self, extension):
    # Timestamp plus UUID keeps filenames unique and roughly sortable.
    timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
    identifier = str(uuid.uuid4()).replace('-', '')
    filename = '{}-{}.{}'.format(timestamp, identifier, extension)
    return filename
--------------------------------------------------------------------------------
/dreamer/tools/nested.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 |
18 | _builtin_zip = zip
19 | _builtin_map = map
20 | _builtin_filter = filter
21 |
22 |
def zip_(*structures, **kwargs):
  """Zip corresponding leaves of several nested structures into tuples.

  With a single structure, leaves are returned unwrapped.
  """
  # Named keyword arguments are not allowed after *args in Python 2.
  flatten = kwargs.pop('flatten', False)
  assert not kwargs, 'zip() got unexpected keyword arguments.'
  def combine(*leaves):
    # With one structure there is nothing to pair up.
    return leaves if len(leaves) > 1 else leaves[0]
  # `map` here is this module's nested map (the builtin is saved above as
  # _builtin_map).
  return map(combine, *structures, flatten=flatten)
31 |
32 |
def map_(function, *structures, **kwargs):
  """Apply `function` to corresponding leaves of nested structures.

  Supports dicts, lists, tuples, and namedtuples; when several structures
  are given they must share the same shape. Pass flatten=True to flatten
  the mapped result into a tuple of leaves.
  """
  # Named keyword arguments are not allowed after *args in Python 2.
  flatten = kwargs.pop('flatten', False)
  assert not kwargs, 'map() got unexpected keyword arguments.'

  def recurse(*nodes):
    if not nodes:
      return nodes
    if all(isinstance(node, (tuple, list)) for node in nodes):
      if len(set(len(node) for node in nodes)) > 1:
        raise ValueError('Cannot merge tuples or lists of different length.')
      mapped = tuple(recurse(*group) for group in _builtin_zip(*nodes))
      if hasattr(nodes[0], '_fields'):  # namedtuple
        return type(nodes[0])(*mapped)
      return type(nodes[0])(mapped)  # tuple, list
    if all(isinstance(node, dict) for node in nodes):
      if len(set(frozenset(node.keys()) for node in nodes)) > 1:
        raise ValueError('Cannot merge dicts with different keys.')
      combined = {
          key: recurse(*(node[key] for node in nodes))
          for key in nodes[0]}
      return type(nodes[0])(combined)
    # Mixed or leaf nodes: apply the function directly.
    return function(*nodes)

  result = recurse(*structures)
  if flatten:
    result = flatten_(result)
  return result
62 |
63 |
def flatten_(structure):
  """Return a flat tuple of all leaves in a nested structure.

  Dict entries are traversed in sorted key order.
  """
  if isinstance(structure, dict):
    ordered_keys = sorted(structure)
    return sum((flatten_(structure[key]) for key in ordered_keys), ())
  if isinstance(structure, (tuple, list)):
    return sum((flatten_(element) for element in structure), ())
  # Anything else is a leaf.
  return (structure,)
76 |
77 |
def filter_(predicate, *structures, **kwargs):
  """Keep only leaves of nested structures for which `predicate` is true.

  Containers that become empty after filtering are removed from their
  parents; namedtuple fields are set to None instead. With several
  structures, corresponding leaves are passed to the predicate together
  and kept or dropped as a group.
  """
  # Named keyword arguments are not allowed after *args in Python 2.
  flatten = kwargs.pop('flatten', False)
  assert not kwargs, 'filter() got unexpected keyword arguments.'

  def impl(predicate, *structures):
    if len(structures) == 0:
      return structures
    if all(isinstance(s, (tuple, list)) for s in structures):
      if len(set(len(x) for x in structures)) > 1:
        raise ValueError('Cannot merge tuples or lists of different length.')
      # Only wrap in tuples if more than one structure provided.
      if len(structures) > 1:
        filtered = (impl(predicate, *x) for x in _builtin_zip(*structures))
      else:
        filtered = (impl(predicate, x) for x in structures[0])
      # Remove empty containers and construct result structure.
      if hasattr(structures[0], '_fields'):  # namedtuple
        filtered = (x if x != () else None for x in filtered)
        return type(structures[0])(*filtered)
      else:  # tuple, list
        filtered = (
            x for x in filtered if not isinstance(x, (tuple, list, dict)) or x)
        return type(structures[0])(filtered)
    if all(isinstance(s, dict) for s in structures):
      if len(set(frozenset(x.keys()) for x in structures)) > 1:
        raise ValueError('Cannot merge dicts with different keys.')
      # Only wrap in tuples if more than one structure provided.
      if len(structures) > 1:
        filtered = {
            k: impl(predicate, *(s[k] for s in structures))
            for k in structures[0]}
      else:
        filtered = {k: impl(predicate, v) for k, v in structures[0].items()}
      # Remove empty containers and construct result structure.
      filtered = {
          k: v for k, v in filtered.items()
          if not isinstance(v, (tuple, list, dict)) or v}
      return type(structures[0])(filtered)
    # Leaf case: dropped leaves are signaled by returning an empty tuple.
    if len(structures) > 1:
      return structures if predicate(*structures) else ()
    else:
      return structures[0] if predicate(structures[0]) else ()

  result = impl(predicate, *structures)
  if flatten:
    result = flatten_(result)
  return result
126 |
127 |
# Export the implementations under the familiar builtin names. This
# intentionally shadows the builtins `zip`, `map`, and `filter` within this
# module, so callers write e.g. `nested.map(fn, structure)`.
zip = zip_
map = map_
flatten = flatten_
filter = filter_
132 |
--------------------------------------------------------------------------------
/dreamer/tools/numpy_episodes.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import functools
20 | import os
21 | import random
22 | import time
23 |
24 | from scipy.ndimage import interpolation
25 | import numpy as np
26 | import tensorflow as tf
27 |
28 | from dreamer.tools import attr_dict
29 | from dreamer.tools import chunk_sequence
30 |
31 |
def numpy_episodes(
    train_dir, test_dir, shape, reader=None, loader=None, num_chunks=None,
    preprocess_fn=None, gpu_prefetch=False):
  """Build train and test tf.data pipelines over episodes stored on disk.

  Args:
    train_dir: Directory with training episode .npz files.
    test_dir: Directory with test episode .npz files.
    shape: Tuple of (batch_size, chunk_length).
    reader: Callable mapping a filename to an episode dict; defaults to
        `episode_reader`.
    loader: Generator factory yielding training episodes; defaults to
        `cache_loader`.
    num_chunks: Optional number of chunks drawn per episode.
    preprocess_fn: Optional function mapped over each batch.
    gpu_prefetch: If True, copy batches to '/gpu:0' before prefetching.

  Returns:
    AttrDict with `train` and `test` datasets.
  """
  reader = reader or episode_reader
  loader = loader or cache_loader
  try:
    dtypes, shapes = _read_spec(reader, train_dir)
  except ZeroDivisionError:
    # NOTE(review): falls back to the test directory when reading the spec
    # from the train directory fails — confirm which reader path raises
    # ZeroDivisionError (presumably an empty train directory).
    dtypes, shapes = _read_spec(reader, test_dir)
  train = tf.data.Dataset.from_generator(
      functools.partial(loader, reader, train_dir),
      dtypes, shapes)
  test = tf.data.Dataset.from_generator(
      functools.partial(cache_loader, reader, test_dir, shape[0]),
      dtypes, shapes)

  def prepare(dataset):
    # Shared pipeline tail: split episodes into fixed-length chunks, batch,
    # then optionally preprocess and move to GPU before prefetching. The
    # original duplicated this block verbatim for train and test.
    chunking = lambda x: tf.data.Dataset.from_tensor_slices(
        chunk_sequence.chunk_sequence(x, shape[1], True, num_chunks))
    dataset = dataset.flat_map(chunking)
    dataset = dataset.batch(shape[0], drop_remainder=True)
    if preprocess_fn:
      dataset = dataset.map(preprocess_fn, tf.data.experimental.AUTOTUNE)
    if gpu_prefetch:
      dataset = dataset.apply(tf.data.experimental.copy_to_device('/gpu:0'))
    return dataset.prefetch(tf.data.experimental.AUTOTUNE)

  return attr_dict.AttrDict(train=prepare(train), test=prepare(test))
64 |
65 |
def cache_loader(reader, directory, every):
  """Yield episodes forever, rescanning `directory` for new files after
  every `every` yields and caching everything read so far."""
  cache = {}
  while True:
    sampled = _sample(cache.values(), every)
    for episode in _permuted(sampled, every):
      yield episode
    for filename in tf.gfile.Glob(os.path.join(directory, '*.npz')):
      if filename not in cache:
        cache[filename] = reader(filename)
76 |
77 |
def recent_loader(reader, directory, every):
  """Yield episodes forever, drawing half of each round from files added to
  `directory` since the previous round and half from older files."""
  recent = {}
  cache = {}
  while True:
    sampled = _sample(recent.values(), every // 2)
    sampled += _sample(cache.values(), every // 2)
    for episode in _permuted(sampled, every):
      yield episode
    # Age the recent episodes into the long-term cache, then read files that
    # have not been seen before.
    cache.update(recent)
    recent = {}
    for filename in tf.gfile.Glob(os.path.join(directory, '*.npz')):
      if filename not in cache:
        recent[filename] = reader(filename)
93 |
94 |
def window_loader(reader, directory, window, every):
  """Yield episodes forever, keeping only the `window` most recent files of
  `directory` cached; older files are evicted."""
  cache = {}
  while True:
    sampled = _sample(cache.values(), every)
    for episode in _permuted(sampled, every):
      yield episode
    current = sorted(tf.gfile.Glob(os.path.join(directory, '*.npz')))[-window:]
    for filename in current:
      if filename not in cache:
        cache[filename] = reader(filename)
    # Evict files that fell out of the window.
    for stale in set(cache) - set(current):
      del cache[stale]
109 |
110 |
def reload_loader(reader, directory):
  """Yield episodes forever, re-reading every file from disk in a freshly
  shuffled order on each pass."""
  pattern = os.path.join(os.path.expanduser(directory), '*.npz')
  while True:
    filenames = tf.gfile.Glob(pattern)
    random.shuffle(filenames)
    for filename in filenames:
      yield reader(filename)
118 |
119 |
def dummy_loader(reader, directory):
  """Yield endless random episodes matching the spec of the real data.

  Reads one real episode via `reader` to infer dtypes, shapes, and length,
  then generates random episodes of that spec without touching disk again.
  """
  # Use a clearly named seeded generator; the original bound it to `random`,
  # shadowing the module-level `random` import inside this function.
  rng = np.random.RandomState(seed=0)
  dtypes, shapes, length = _read_spec(reader, directory, True, True)
  while True:
    episode = {}
    for key in dtypes:
      dtype, shape = dtypes[key], (length,) + shapes[key][1:]
      if dtype in (np.float32, np.float64):
        episode[key] = rng.uniform(0, 1, shape).astype(dtype)
      elif dtype in (np.int32, np.int64, np.uint8):
        episode[key] = rng.uniform(0, 255, shape).astype(dtype)
      else:
        raise NotImplementedError('Unsupported dtype {}.'.format(dtype))
    yield episode
134 |
135 |
def episode_reader(
    filename, resize=None, max_length=None, action_noise=None,
    clip_rewards=False, pcont_scale=None):
  """Load one episode from an .npz file and post-process it.

  Args:
    filename: Path to the .npz episode file.
    resize: Optional zoom factor applied to the image observations.
    max_length: Optional maximum number of time steps to keep.
    action_noise: Optional stddev of Gaussian noise added to the actions,
        seeded deterministically from the filename.
    clip_rewards: One of False, 'sign', or 'tanh'.
    pcont_scale: Optional factor multiplied onto the 'pcont' key.

  Returns:
    Dict of numpy arrays with 64-bit types narrowed to 32-bit, plus a
    cumulative 'return' key and a 'reward_mask' key.

  Raises:
    NotImplementedError: For an unrecognized `clip_rewards` mode.
  """
  try:
    with tf.gfile.Open(filename, 'rb') as file_:
      episode = np.load(file_)
  except (IOError, ValueError):
    # Try again one second later, in case the file was still being written.
    time.sleep(1)
    with tf.gfile.Open(filename, 'rb') as file_:
      episode = np.load(file_)
  episode = {key: _convert_type(episode[key]) for key in episode.keys()}
  episode['return'] = np.cumsum(episode['reward'])
  if 'reward_mask' not in episode:
    episode['reward_mask'] = np.ones_like(episode['reward'])[..., None]
  if max_length:
    episode = {key: value[:max_length] for key, value in episode.items()}
  if resize and resize != 1:
    factors = (1, resize, resize, 1)
    episode['image'] = interpolation.zoom(episode['image'], factors)
  if action_noise:
    # Derive a deterministic per-file seed from the filename bytes.
    # Bug fix: np.fromstring on a str is deprecated (removed in newer
    # NumPy); np.frombuffer over the encoded bytes is the equivalent.
    seed = np.frombuffer(filename.encode('utf-8'), dtype=np.uint8)
    episode['action'] += np.random.RandomState(seed).normal(
        0, action_noise, episode['action'].shape)
  if clip_rewards is False:
    pass
  elif clip_rewards == 'sign':
    episode['reward'] = np.sign(episode['reward'])
  elif clip_rewards == 'tanh':
    episode['reward'] = np.tanh(episode['reward'])
  else:
    raise NotImplementedError(clip_rewards)
  if pcont_scale is not None:
    episode['pcont'] *= pcont_scale
  return episode
171 |
172 |
def _read_spec(
    reader, directory, return_length=False, numpy_types=False):
  """Read one episode to infer per-key dtypes and shapes of the dataset.

  The leading (time) dimension of every shape is reported as None. Dtypes
  are converted to TensorFlow dtypes unless `numpy_types` is set. Optionally
  also returns the episode length.
  """
  episodes = reload_loader(reader, directory)
  episode = next(episodes)
  episodes.close()
  dtypes = {}
  shapes = {}
  for key, value in episode.items():
    dtypes[key] = value.dtype if numpy_types else tf.as_dtype(value.dtype)
    shapes[key] = (None,) + value.shape[1:]
  if return_length:
    length = len(episode[next(iter(shapes))])
    return dtypes, shapes, length
  return dtypes, shapes
188 |
189 |
190 | def _convert_type(array):
191 | if array.dtype == np.float64:
192 | return array.astype(np.float32)
193 | if array.dtype == np.int64:
194 | return array.astype(np.int32)
195 | return array
196 |
197 |
198 | def _sample(sequence, amount):
199 | sequence = list(sequence)
200 | amount = min(amount, len(sequence))
201 | return random.sample(sequence, amount)
202 |
203 |
204 | def _permuted(sequence, amount):
205 | sequence = list(sequence)
206 | if not sequence:
207 | return
208 | index = 0
209 | while True:
210 | for element in np.random.permutation(sequence):
211 | if index >= amount:
212 | return
213 | yield element
214 | index += 1
215 |
--------------------------------------------------------------------------------
/dreamer/tools/overshooting.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import functools
20 |
21 | import numpy as np
22 | import tensorflow as tf
23 |
24 | from dreamer.tools import nested
25 | from dreamer.tools import shape
26 |
27 |
def overshooting(
    cell, target, embedded, prev_action, length, amount, posterior=None,
    ignore_input=False):
  """Compute open-loop ("overshooting") predictions from every time step.

  Args:
    cell: RNN cell consuming `(embedded, prev_action, use_obs)` tuples and
        producing `(output, posterior)` state tuples.
    target: Nested structure of target tensors, batch-major.
    embedded: Nested structure of embedded observations with shape
        (batch, time, features) — assumed from the [:, :, :1] and .shape[2]
        accesses below; TODO confirm.
    prev_action: Action tensor aligned with `embedded`.
    amount: Number of open-loop steps (overshooting distances).
    length: Batch of valid sequence lengths.
    posterior: Optional precomputed posterior states; computed here with a
        closed-loop unroll when None.
    ignore_input: If true, the closed-loop unroll masks out observations.

  Returns:
    Tuple `(target, prior, posterior, mask)`; each tensor's second dimension
    is the merged (time step, overshooting distance) axis restored to a
    leading batch dimension by `_restore_batch_dim`.
  """
  # Closed loop unroll to get posterior states, which are the starting points
  # for open loop unrolls. We don't need the last time step, since we have no
  # targets for unrolls from it.
  if posterior is None:
    use_obs = tf.ones(tf.shape(
        nested.flatten(embedded)[0][:, :, :1])[:3], tf.bool)
    use_obs = tf.cond(
        tf.convert_to_tensor(ignore_input),
        lambda: tf.zeros_like(use_obs, tf.bool),
        lambda: use_obs)
    (_, posterior), _ = tf.nn.dynamic_rnn(
        cell, (embedded, prev_action, use_obs), length, dtype=tf.float32,
        swap_memory=True)

  # Arrange inputs for every iteration in the open loop unroll. Every loop
  # iteration below corresponds to one row in the docstring illustration.
  max_length = shape.shape(nested.flatten(embedded)[0])[1]
  first_output = {
      # 'observ': embedded,
      'prev_action': prev_action,
      'posterior': posterior,
      'target': target,
      'mask': tf.sequence_mask(length, max_length, tf.int32),
  }

  # Each scan step shifts every tensor one step left in time, padding with
  # zeros, so row k holds the inputs that are k steps ahead.
  progress_fn = lambda tensor: tf.concat([tensor[:, 1:], 0 * tensor[:, :1]], 1)
  other_outputs = tf.scan(
      lambda past_output, _: nested.map(progress_fn, past_output),
      tf.range(amount), first_output)
  sequences = nested.map(
      lambda lhs, rhs: tf.concat([lhs[None], rhs], 0),
      first_output, other_outputs)

  # Merge batch and time dimensions of steps to compute unrolls from every
  # time step as one batch. The time dimension becomes the number of
  # overshooting distances.
  sequences = nested.map(
      lambda tensor: _merge_dims(tensor, [1, 2]),
      sequences)
  sequences = nested.map(
      lambda tensor: tf.transpose(
          tensor, [1, 0] + list(range(2, tensor.shape.ndims))),
      sequences)
  merged_length = tf.reduce_sum(sequences['mask'], 1)

  # Mask out padding frames; unnecessary if the input is already masked.
  sequences = nested.map(
      lambda tensor: tensor * tf.cast(
          _pad_dims(sequences['mask'], tensor.shape.ndims),
          tensor.dtype),
      sequences)

  # Compute open loop rollouts. Observations are zeroed and use_obs is all
  # false, so the cell predicts purely from state and actions.
  use_obs = tf.zeros(tf.shape(sequences['mask']), tf.bool)[..., None]
  embed_size = nested.flatten(embedded)[0].shape[2].value
  obs = tf.zeros(shape.shape(sequences['mask']) + [embed_size])
  # Start each unroll from the posterior of the previous time step (zeros
  # for the first step).
  prev_state = nested.map(
      lambda tensor: tf.concat([0 * tensor[:, :1], tensor[:, :-1]], 1),
      posterior)
  prev_state = nested.map(
      lambda tensor: _merge_dims(tensor, [0, 1]), prev_state)
  (priors, _), _ = tf.nn.dynamic_rnn(
      cell, (obs, sequences['prev_action'], use_obs),
      merged_length,
      prev_state)

  # Restore batch dimension.
  target, prior, posterior, mask = nested.map(
      functools.partial(_restore_batch_dim, batch_size=shape.shape(length)[0]),
      (sequences['target'], priors, sequences['posterior'], sequences['mask']))

  mask = tf.cast(mask, tf.bool)
  return target, prior, posterior, mask
104 |
105 |
def _merge_dims(tensor, dims):
  """Flatten the consecutive dimensions `dims` of `tensor` into one.

  Args:
    tensor: Tensor, or nested structure of tensors transformed leaf-wise.
    dims: Sequence of consecutive dimension indices to merge.

  Returns:
    Tensor with the given dimensions merged, static shape preserved where
    known.

  Raises:
    ValueError: If the dimensions are not consecutive.
  """
  if isinstance(tensor, (list, tuple, dict)):
    # Bug fix: nested.map takes the function as its first argument; the
    # previous call passed the structure first and would fail at runtime.
    return nested.map(lambda x: _merge_dims(x, dims), tensor)
  tensor = tf.convert_to_tensor(tensor)
  # Bug fix: raise when ANY entry breaks the consecutive run; the previous
  # `.all()` only raised when every entry mismatched, letting most invalid
  # inputs through.
  if (np.array(dims) - min(dims) != np.arange(len(dims))).any():
    raise ValueError('Dimensions to merge must all follow each other.')
  start, end = dims[0], dims[-1]
  output = tf.reshape(tensor, tf.concat([
      tf.shape(tensor)[:start],
      [tf.reduce_prod(tf.shape(tensor)[start: end + 1])],
      tf.shape(tensor)[end + 1:]], axis=0))
  merged = tensor.shape[start: end + 1].as_list()
  output.set_shape(
      tensor.shape[:start].as_list() +
      [None if None in merged else np.prod(merged)] +
      tensor.shape[end + 1:].as_list())
  return output
123 |
124 |
125 | def _pad_dims(tensor, rank):
126 | for _ in range(rank - tensor.shape.ndims):
127 | tensor = tensor[..., None]
128 | return tensor
129 |
130 |
def _restore_batch_dim(tensor, batch_size):
  """Split the merged first dimension of `tensor` back into
  (batch_size, remainder)."""
  current = shape.shape(tensor)
  restored = [batch_size, current[0] // batch_size] + current[1:]
  return tf.reshape(tensor, restored)
135 |
--------------------------------------------------------------------------------
/dreamer/tools/preprocess.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 |
21 |
def preprocess(observ, bits):
  """Quantize and normalize the 'image' observation in place.

  Images are reduced to `bits` bits per channel, scaled to roughly
  [-0.5, 0.5), and dithered with uniform noise to smooth the quantization
  bins. Observations without an 'image' key pass through untouched.
  """
  if 'image' not in observ:
    return observ
  bins = 2 ** bits
  image = tf.cast(observ['image'], tf.float32)
  if bits < 8:
    image = tf.floor(image / 2 ** (8 - bits))
  image = image / bins
  image += tf.random_uniform(tf.shape(image), 0, 1.0 / bins)
  observ['image'] = image - 0.5
  return observ
33 |
34 |
def postprocess(image, bits, dtype=tf.float32):
  """Map a preprocessed image in [-0.5, 0.5) back to displayable values.

  For float32, snaps values back onto the quantization grid in [0, 1); for
  uint8, rescales to [0, 255]. Other dtypes are rejected.
  """
  bins = 2 ** bits
  if dtype == tf.float32:
    return tf.floor(bins * (image + 0.5)) / bins
  if dtype == tf.uint8:
    image = tf.floor(bins * (image + 0.5))
    image = image * (256.0 / bins)
    return tf.cast(tf.clip_by_value(image, 0, 255), tf.uint8)
  raise NotImplementedError(dtype)
47 |
--------------------------------------------------------------------------------
/dreamer/tools/reshape_as.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 |
21 | from dreamer.tools import nested
22 |
23 |
def reshape_as(tensor, reference):
  """Reshape `tensor` to the (possibly partially dynamic) shape of
  `reference`.

  Static dimensions of `reference` are used where known; dynamic shape
  tensors fill in the rest. Nested structures are transformed leaf-wise.
  """
  if isinstance(tensor, (list, tuple, dict)):
    # Bug fix: nested.map takes the function as its first argument; the
    # previous call passed the structure first and would fail at runtime.
    return nested.map(lambda x: reshape_as(x, reference), tensor)
  tensor = tf.convert_to_tensor(tensor)
  reference = tf.convert_to_tensor(reference)
  statics = reference.shape.as_list()
  dynamics = tf.shape(reference)
  shape = [
      static if static is not None else dynamics[index]
      for index, static in enumerate(statics)]
  return tf.reshape(tensor, shape)
35 |
--------------------------------------------------------------------------------
/dreamer/tools/schedule.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import numpy as np
20 | import tensorflow as tf
21 |
22 |
def binary(step, batch_size, after, every, until):
  """Boolean schedule that is true for one batch every `every` steps.

  Activates at step `after`; if `until > 0`, deactivates at step `until`.
  See https://www.desmos.com/calculator/csbhr5cjzz
  """
  offset = step - after
  in_phase = tf.less(offset % every, batch_size)
  active = tf.greater_equal(step, after)
  if until > 0:
    active = tf.logical_and(active, tf.less(step, until))
  result = tf.logical_and(in_phase, active)
  result.set_shape(tf.TensorShape([]))
  return result
33 |
34 |
def linear(step, ramp, min=None, max=None):
  """Linearly ramp from 0 to 1 over `ramp` steps, optionally clamped.

  A negative ramp yields the mirrored schedule 1 -> 0; a zero ramp is the
  constant 1. See https://www.desmos.com/calculator/nrumhgvxql

  Args:
    step: Current global step tensor.
    ramp: Number of ramp steps; the sign selects the direction.
    min: Optional lower clamp, must lie in [0, 1].
    max: Optional upper clamp, must lie in [0, 1].
  """
  if ramp == 0:
    result = tf.constant(1, tf.float32)
  if ramp > 0:
    result = tf.minimum(
        tf.cast(step, tf.float32) / tf.cast(ramp, tf.float32), 1)
  if ramp < 0:
    result = 1 - linear(step, abs(ramp))
  if min is not None and max is not None:
    assert min <= max
  if min is not None:
    min = float(min)
    assert 0 <= min <= 1
    result = tf.maximum(min, result)
  if max is not None:
    max = float(max)
    # Bug fix: this previously asserted `0 <= min <= 1`, which raised a
    # TypeError when only `max` was given and validated the wrong bound.
    assert 0 <= max <= 1
    result = tf.minimum(result, max)
  result.set_shape(tf.TensorShape([]))
  return result
56 |
57 |
def step(step, distance):
  """Step schedule: 1 before `distance` steps then 0, mirrored for negative
  distances; a zero distance is the constant 1.

  See https://www.desmos.com/calculator/sh1zsjtsqg
  """
  if distance == 0:
    result = tf.constant(1, tf.float32)
  elif distance > 0:
    result = tf.cast(tf.less(step, distance), tf.float32)
  else:
    result = 1 - tf.cast(tf.less(step, abs(distance)), tf.float32)
  result.set_shape(tf.TensorShape([]))
  return result
68 |
69 |
def exponential(step, distance, target, min=None, max=None):
  """Exponential schedule reaching `target` after `distance` steps.

  A positive distance decays from 1 toward `target`; a negative distance
  yields the mirrored growing schedule; zero distance is the constant 1.
  See https://www.desmos.com/calculator/egt24luequ

  Args:
    step: Current global step tensor.
    distance: Number of steps until `target` is reached; sign selects the
        direction.
    target: Value approached by the schedule.
    min: Optional lower clamp, must lie in [0, 1].
    max: Optional upper clamp, must lie in [0, 1].
  """
  target = tf.constant(target, tf.float32)
  step = tf.cast(step, tf.float32)
  if distance == 0:
    result = tf.constant(1, tf.float32)
  elif distance > 0:
    distance = tf.constant(abs(distance), tf.float32)
    result = tf.exp(tf.log(target) / distance) ** step
  elif distance < 0:
    distance = tf.constant(abs(distance), tf.float32)
    result = 1 - tf.exp(tf.log(1 - target) / distance) ** step
  if min is not None and max is not None:
    assert min <= max
  if min is not None:
    min = float(min)
    assert 0 <= min <= 1
    result = tf.maximum(min, result)
  if max is not None:
    max = float(max)
    # Bug fix: this previously asserted `0 <= min <= 1`, which raised a
    # TypeError when only `max` was given and validated the wrong bound.
    assert 0 <= max <= 1
    result = tf.minimum(result, max)
  result.set_shape(tf.TensorShape([]))
  return result
94 |
95 |
def linear_reset(step, ramp, after, every):
  """Linear ramp to 1 that restarts every `every` steps once past `after`.

  See https://www.desmos.com/calculator/motbnqhacw
  """
  assert every > ramp, (every, ramp)  # Would never reach max value.
  assert not (every != np.inf and after == np.inf), (every, after)
  step, ramp, after, every = (
      tf.cast(value, tf.float32) for value in (step, ramp, after, every))
  is_before = tf.cast(tf.less(step, after), tf.float32)
  is_after = tf.cast(tf.greater_equal(step, after), tf.float32)
  progress = is_before * step + is_after * ((step - after) % every)
  result = tf.minimum(progress / ramp, 1)
  result.set_shape(tf.TensorShape([]))
  return result
108 |
--------------------------------------------------------------------------------
/dreamer/tools/shape.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 |
21 |
def shape(tensor):
  """Return the shape of `tensor` as a list, using static sizes where known
  and dynamic shape tensors for unknown dimensions."""
  static = tensor.get_shape().as_list()
  dynamic = tf.unstack(tf.shape(tensor))
  assert len(static) == len(dynamic)
  return [
      dynamic_dim if static_dim is None else static_dim
      for static_dim, dynamic_dim in zip(static, dynamic)]
28 |
--------------------------------------------------------------------------------
/dreamer/tools/streaming_mean.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 |
21 |
class StreamingMean(object):
  """Accumulate a streaming mean of tensors in non-trainable TF variables.

  Values are added with `submit()`; the mean is read via `value` or read
  and reset atomically via `clear()`.
  """

  def __init__(self, shape, dtype, name):
    # Sum and count live in non-trainable variables under the scope `name`
    # so the statistic persists across session runs.
    self._dtype = dtype
    with tf.variable_scope(name):
      self._sum = tf.get_variable(
          'sum', shape, dtype,
          tf.constant_initializer(0),
          trainable=False)
      self._count = tf.get_variable(
          'count', (), tf.int32,
          tf.constant_initializer(0),
          trainable=False)

  @property
  def value(self):
    # Current mean; undefined (division by zero) before the first submit.
    return self._sum / tf.cast(self._count, self._dtype)

  @property
  def count(self):
    # Number of submitted elements so far (int32 variable).
    return self._count

  def submit(self, value):
    """Add a value or a batch of values to the running statistics.

    Raises:
      ValueError: If the value shape does not match the tracked shape.
    """
    value = tf.convert_to_tensor(value)
    # Add a batch dimension if necessary.
    if value.shape.ndims == self._sum.shape.ndims:
      value = value[None, ...]
    if str(value.shape[1:]) != str(self._sum.shape):
      message = 'Value shape ({}) does not fit tracked tensor ({}).'
      raise ValueError(message.format(value.shape[1:], self._sum.shape))
    def assign():
      return tf.group(
          self._sum.assign_add(tf.reduce_sum(value, 0)),
          self._count.assign_add(tf.shape(value)[0]))
    # Skip the update entirely when the batch contains zero elements.
    not_empty = tf.cast(tf.reduce_prod(tf.shape(value)), tf.bool)
    return tf.cond(not_empty, assign, tf.no_op)

  def clear(self):
    """Return the current mean and reset the sum and count to zero.

    The control dependencies ensure the mean is computed before the resets
    run, and that fetching the returned tensor forces the resets.
    """
    value = self._sum / tf.cast(self._count, self._dtype)
    with tf.control_dependencies([value]):
      reset_value = self._sum.assign(tf.zeros_like(self._sum))
      reset_count = self._count.assign(0)
    with tf.control_dependencies([reset_value, reset_count]):
      return tf.identity(value)
66 |
--------------------------------------------------------------------------------
/dreamer/tools/summary.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import matplotlib
20 | matplotlib.use('Agg')
21 | import matplotlib.pyplot as plt
22 | import numpy as np
23 | import tensorflow as tf
24 |
25 | from dreamer.tools import count_dataset
26 | from dreamer.tools import gif_summary
27 | from dreamer.tools import image_strip_summary
28 | from dreamer.tools import shape as shapelib
29 |
30 |
def plot_summary(titles, lines, labels, name):
  """Render line plots with matplotlib and wrap them in an image summary.

  Use control dependencies to ensure that only one plot summary is executed
  at a time, since matplotlib is not thread safe.
  """
  def render(lines):
    fig, axes = plt.subplots(
        nrows=len(titles), ncols=1, sharex=True, sharey=False,
        squeeze=False, figsize=(6, 3 * len(lines)))
    axes = axes[:, 0]
    for index, ax in enumerate(axes):
      ax.set_title(titles[index])
      for line, label in zip(lines[index], labels[index]):
        ax.plot(line, label=label)
      if any(labels[index]):
        ax.legend(frameon=False)
    fig.tight_layout()
    fig.canvas.draw()
    width, height = fig.canvas.get_width_height()
    pixels = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    pixels = pixels.reshape((height, width, 3))
    plt.close(fig)
    return pixels
  image = tf.py_func(render, (lines,), tf.uint8)
  return tf.summary.image(name, image[None])
55 |
56 |
def data_summaries(
    data, postprocess_fn, histograms=False, max_batch=6, name='data'):
  """Summarize a batch of input data: optional per-key histograms plus an
  image strip of the first `max_batch` sequences."""
  summaries = []
  with tf.variable_scope(name):
    if histograms:
      for key in data:
        if key == 'image':
          continue
        summaries.append(tf.summary.histogram(key, data[key]))
    image = data['image'][:max_batch]
    if postprocess_fn:
      image = postprocess_fn(image)
    summaries.append(image_strip_summary.image_strip_summary('image', image))
  return summaries
71 |
72 |
def dataset_summaries(directory, name='dataset'):
  """Summarize the number of episodes stored in `directory`."""
  with tf.variable_scope(name):
    episodes = count_dataset.count_dataset(directory)
    return [tf.summary.scalar('episodes', episodes)]
79 |
80 |
def state_summaries(
    cell, prior, posterior, histograms=False, name='state'):
  """Summarize entropy, stddev, and divergence of prior/posterior states."""
  summaries = []
  divergence = cell.divergence_from_states(posterior, prior)
  prior = cell.dist_from_state(prior)
  posterior = cell.dist_from_state(posterior)
  prior_entropy = prior.entropy()
  posterior_entropy = posterior.entropy()
  nan_to_num = lambda x: tf.where(tf.is_nan(x), tf.zeros_like(x), x)
  with tf.variable_scope(name):
    # Emit the same set of summaries for both distributions.
    for prefix, dist, entropy in (
        ('prior', prior, prior_entropy),
        ('posterior', posterior, posterior_entropy)):
      if histograms:
        summaries.append(tf.summary.histogram(
            prefix + '_entropy_hist', nan_to_num(entropy)))
      summaries.append(tf.summary.scalar(
          prefix + '_entropy', tf.reduce_mean(entropy)))
      summaries.append(tf.summary.scalar(
          prefix + '_std', tf.reduce_mean(dist.stddev())))
    summaries.append(tf.summary.scalar(
        'divergence', tf.reduce_mean(divergence)))
  return summaries
108 |
109 |
def dist_summaries(dists, obs, name='dist_summaries'):
  """Summarize prediction distributions and their fit to observations.

  Args:
    dists: Dict mapping observation keys to distribution objects.
    obs: Dict of ground-truth observations; keys also present in `dists`
        additionally get log-prob and absolute-error summaries.
    name: Variable scope for the summaries.
  """
  summaries = []
  with tf.variable_scope(name):
    # Use a distinct loop variable; the original reused `name`, shadowing
    # the function parameter inside the loop.
    for key, dist in dists.items():
      mode = tf.cast(dist.mode(), tf.float32)
      mode_mean, mode_var = tf.nn.moments(mode, list(range(mode.shape.ndims)))
      mode_std = tf.sqrt(mode_var)
      summaries.append(tf.summary.scalar(key + '_mode_mean', mode_mean))
      summaries.append(tf.summary.scalar(key + '_mode_std', mode_std))
      std = dist.stddev()
      std_mean, std_var = tf.nn.moments(std, list(range(std.shape.ndims)))
      std_std = tf.sqrt(std_var)
      summaries.append(tf.summary.scalar(key + '_std_mean', std_mean))
      summaries.append(tf.summary.scalar(key + '_std_std', std_std))
      if hasattr(dist, 'distribution') and hasattr(
          dist.distribution, 'distribution'):
        inner = dist.distribution.distribution
        inner_std = tf.reduce_mean(inner.stddev())
        summaries.append(tf.summary.scalar(key + '_inner_std', inner_std))
      if key in obs:
        log_prob = tf.reduce_mean(dist.log_prob(obs[key]))
        summaries.append(tf.summary.scalar(key + '_log_prob', log_prob))
        abs_error = tf.reduce_mean(tf.abs(mode - obs[key]))
        summaries.append(tf.summary.scalar(key + '_abs_error', abs_error))
  return summaries
135 |
136 |
def image_summaries(dist, target, name='image', max_batch=6):
  """Summarize video predictions as an image strip and an animated GIF.

  The GIF shows, stacked vertically per batch entry: ground truth, model
  prediction, and the rescaled prediction error; batch entries are placed
  side by side. At most `max_batch` batch entries are visualized.
  """
  with tf.variable_scope(name):
    prediction = dist.mode()[:max_batch]
    truth = target[:max_batch]
    # Map the signed error from [-1, 1] into the displayable [0, 1] range.
    error = ((prediction - truth) + 1) / 2
    summaries = [image_strip_summary.image_strip_summary(
        'prediction', prediction)]
    # Concat prediction and target vertically.
    frames = tf.concat([truth, prediction, error], 2)
    # Stack batch entries horizontally.
    frames = tf.transpose(frames, [1, 2, 0, 3, 4])
    dims = shapelib.shape(frames)
    frames = tf.reshape(
        frames, [dims[0], dims[1], dims[2] * dims[3], dims[4]])
    summaries.append(gif_summary.gif_summary(
        'gif', frames[None], max_outputs=1, fps=20))
  return summaries
161 |
162 |
def objective_summaries(objectives, name='objectives'):
  """Create one scalar summary per training objective value."""
  with tf.variable_scope(name):
    return [
        tf.summary.scalar(objective.name, objective.value)
        for objective in objectives]
169 |
170 |
def prediction_summaries(dists, data, state, name='state'):
  """Plot per-key predicted trajectories against the ground truth.

  For every non-image head whose target is present in `data`, plots the
  predicted and true trajectories of the first batch element, followed by
  one figure containing the log probabilities of all plotted keys.

  Args:
    dists: Dict mapping observation keys to predicted distributions.
    data: Dict of ground-truth tensors.
    state: Unused; kept for call-site compatibility.
    name: Name of the enclosing variable scope.

  Returns:
    List of summary tensors.
  """
  summaries = []
  with tf.variable_scope(name):
    # Predictions.
    log_probs = {}
    for key, dist in dists.items():
      if key in ('image',):
        continue
      if key not in data:
        continue
      # We only look at the first example in the batch.
      log_prob = dist.log_prob(data[key])[0]
      prediction = dist.mode()[0]
      truth = data[key][0]
      # Ensure that there is a feature dimension.
      if prediction.shape.ndims == 1:
        prediction = prediction[:, None]
        truth = truth[:, None]
      prediction = tf.unstack(tf.transpose(prediction, (1, 0)))
      truth = tf.unstack(tf.transpose(truth, (1, 0)))
      lines = list(zip(prediction, truth))
      titles = ['{} {}'.format(key.title(), i) for i in range(len(lines))]
      labels = [['Prediction', 'Truth']] * len(lines)
      plot_name = '{}_trajectory'.format(key)
      # The control dependencies are needed because rendering in matplotlib
      # uses global state, so rendering two plots in parallel interferes.
      with tf.control_dependencies(summaries):
        summaries.append(plot_summary(titles, lines, labels, plot_name))
      log_probs[key] = log_prob
    # Guard against unpacking an empty dict, which would raise a ValueError
    # when every key was skipped above.
    if log_probs:
      log_probs = sorted(log_probs.items(), key=lambda x: x[0])
      titles, lines = zip(*log_probs)
      titles = [title.title() for title in titles]
      lines = [[line] for line in lines]
      labels = [[None]] * len(titles)
      plot_name = 'logprobs'
      with tf.control_dependencies(summaries):
        summaries.append(plot_summary(titles, lines, labels, plot_name))
  return summaries
210 |
--------------------------------------------------------------------------------
/dreamer/tools/target_network.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 |
21 | from dreamer.tools import schedule as schedule_lib
22 | from dreamer.tools import copy_weights
23 |
24 |
def track_network(
    trainer, batch_size, source_pattern, target_pattern, every, amount):
  """Move target network weights towards the source network on a schedule.

  At global step zero the target variables are hard-copied from the source;
  afterwards, a soft update by `amount` is applied every `every` steps, and
  only during the train phase.
  """
  hard_copy = lambda: copy_weights.soft_copy_weights(
      source_pattern, target_pattern, 1.0)
  soft_copy = lambda: copy_weights.soft_copy_weights(
      source_pattern, target_pattern, amount)
  init_op = tf.cond(tf.equal(trainer.global_step, 0), hard_copy, tf.no_op)
  schedule = schedule_lib.binary(trainer.step, batch_size, 0, every, -1)
  with tf.control_dependencies([init_op]):
    is_training = tf.equal(trainer.phase, 'train')
    return tf.cond(
        tf.logical_and(is_training, schedule), soft_copy, tf.no_op)
39 |
--------------------------------------------------------------------------------
/dreamer/tools/unroll.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 |
21 | from dreamer.tools import nested
22 | from dreamer.tools import shape
23 |
24 |
def closed_loop(cell, embedded, prev_action, debug=False):
  """Filter the whole sequence with observations available at every step."""
  # Observation mask is all ones: every step is observed.
  observed = tf.ones(tf.shape(embedded[:, :, :1])[:3], tf.bool)
  (prior, posterior), _ = tf.nn.dynamic_rnn(
      cell, (embedded, prev_action, observed), dtype=tf.float32)
  if debug:
    # Assert the posterior covers the same number of time steps as input.
    length_check = tf.assert_equal(
        tf.shape(nested.flatten(posterior)[0])[1], tf.shape(embedded)[1])
    with tf.control_dependencies([length_check]):
      prior = nested.map(tf.identity, prior)
      posterior = nested.map(tf.identity, posterior)
  return prior, posterior
35 |
36 |
def open_loop(cell, embedded, prev_action, context=1, debug=False):
  """Filter a context prefix, then predict the remainder without inputs."""
  observed = tf.ones(tf.shape(embedded[:, :context, :1])[:3], tf.bool)
  (_, closed_state), last_state = tf.nn.dynamic_rnn(
      cell, (embedded[:, :context], prev_action[:, :context], observed),
      dtype=tf.float32)
  # After the context, the cell gets zeroed embeddings and a zero mask, so
  # it must rely purely on its predicted latent state.
  unobserved = tf.zeros(tf.shape(embedded[:, context:, :1])[:3], tf.bool)
  (_, open_state), _ = tf.nn.dynamic_rnn(
      cell, (0 * embedded[:, context:], prev_action[:, context:], unobserved),
      initial_state=last_state)
  # Stitch the filtered prefix and the predicted suffix along time.
  state = nested.map(
      lambda x, y: tf.concat([x, y], 1),
      closed_state, open_state)
  if debug:
    length_check = tf.assert_equal(
        tf.shape(nested.flatten(state)[0])[1], tf.shape(embedded)[1])
    with tf.control_dependencies([length_check]):
      state = nested.map(tf.identity, state)
  return state
54 |
55 |
def planned(
    cell, objective_fn, embedded, prev_action, planner, context=1, length=20,
    amount=1000, debug=False):
  """Filter a context prefix, then plan `length` steps from its last state.

  Args:
    cell: Latent dynamics RNN cell.
    objective_fn: Objective used by the planner to score candidate plans.
    embedded: Embedded observations of shape (batch, time, ...).
    prev_action: Previous actions of shape (batch, time, ...).
    planner: Callable implementing the planning algorithm; returns a tuple
      whose last two elements are the planned states and predicted return.
    context: Number of observed steps to filter before planning.
    length: Planning horizon in steps.
    amount: Number of candidate action sequences given to the planner.
    debug: If True, assert the combined sequence has `context + length`
      time steps.

  Returns:
    Tuple of the concatenated (filtered + planned) state and the planner's
    predicted return.
  """
  use_obs = tf.ones(tf.shape(embedded[:, :context, :1])[:3], tf.bool)
  (_, closed_state), last_state = tf.nn.dynamic_rnn(
      cell, (embedded[:, :context], prev_action[:, :context], use_obs),
      dtype=tf.float32)
  _, plan_state, return_ = planner(
      cell, objective_fn, last_state,
      obs_shape=shape.shape(embedded)[2:],
      action_shape=shape.shape(prev_action)[2:],
      horizon=length, amount=amount)
  # Concatenate filtered context states and planned states along time.
  state = nested.map(
      lambda x, y: tf.concat([x, y], 1),
      closed_state, plan_state)
  if debug:
    with tf.control_dependencies([tf.assert_equal(
        tf.shape(nested.flatten(state)[0])[1], context + length)]):
      state = nested.map(tf.identity, state)
      return_ = tf.identity(return_)
  return state, return_
77 |
--------------------------------------------------------------------------------
/dreamer/training/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | from . import utility
20 | from .define_model import define_model
21 | from .define_summaries import define_summaries
22 | from .running import Experiment
23 | from .trainer import Trainer
24 |
--------------------------------------------------------------------------------
/dreamer/training/define_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import datetime
20 | import functools
21 |
22 | import tensorflow as tf
23 |
24 | from dreamer import tools
25 | from dreamer.training import define_summaries
26 | from dreamer.training import utility
27 |
28 |
def define_model(logdir, metrics, data, trainer, config):
  """Build the complete training graph: model, heads, optimizers, collection.

  NOTE(review): this function exposes its local variables to downstream code
  via `tools.AttrDict(locals())` (twice, below). Local variable *names* are
  therefore part of the interface — do not rename locals without checking
  `utility.compute_objectives` and `define_summaries`.

  Args:
    logdir: Log directory; only reaches downstream code through the captured
      locals.
    metrics: Metrics object used for tagging and flushing logged values.
    data: Dict of batched input tensors, including 'action'.
    trainer: Trainer providing step, global_step, phase, and log tensors.
    config: AttrDict of model and training options.

  Returns:
    Tuple (score, summaries, cleanups): simulation score tensor, merged
    summary string tensor, and a list of cleanup callables.
  """
  print('Build TensorFlow compute graph.')
  dependencies = []
  cleanups = []
  step = trainer.step
  global_step = trainer.global_step
  phase = trainer.phase
  # Wall-clock tag, evaluated at session-run time through py_func.
  timestamp = tf.py_func(
      lambda: datetime.datetime.utcnow().strftime('%Y%m%dT%H%M%S'),
      [], tf.string)
  dependencies.append(metrics.set_tags(
      global_step=global_step, step=step, phase=phase,
      time=timestamp))

  # Instantiate network blocks. Note, this initialization would be expensive
  # when using tf.function since it would run at every step.
  # Some cell constructors require the action size; retry with it on failure.
  try:
    cell = config.cell()
  except TypeError:
    cell = config.cell(action_size=data['action'].shape[-1].value)
  kwargs = dict(create_scope_now_=True)
  encoder = tf.make_template('encoder', config.encoder, **kwargs)
  heads = tools.AttrDict(_unlocked=True)
  # Features for a single dummy step, used to build head variables early.
  raw_dummy_features = cell.features_from_state(
      cell.zero_state(1, tf.float32))[:, None]
  for key, head in config.heads.items():
    name = 'head_{}'.format(key)
    kwargs = dict(create_scope_now_=True)
    if key in data:
      kwargs['data_shape'] = data[key].shape[2:].as_list()
    if key == 'action_target':
      kwargs['data_shape'] = data['action'].shape[2:].as_list()
    if key == 'cpc':
      # The CPC head predicts embeddings rather than raw observations.
      kwargs['data_shape'] = [cell.feature_size]
      dummy_features = encoder(data)[:1, :1]
    else:
      dummy_features = raw_dummy_features
    heads[key] = tf.make_template(name, head, **kwargs)
    heads[key](dummy_features)  # Initialize weights.

  # Update target networks.
  if 'value_target' in heads:
    dependencies.append(tools.track_network(
        trainer, config.batch_shape[0],
        r'.*/head_value/.*', r'.*/head_value_target/.*',
        config.value_target_period, config.value_target_update))
  if 'value_target_2' in heads:
    dependencies.append(tools.track_network(
        trainer, config.batch_shape[0],
        r'.*/head_value/.*', r'.*/head_value_target_2/.*',
        config.value_target_period, config.value_target_update))
  if 'action_target' in heads:
    dependencies.append(tools.track_network(
        trainer, config.batch_shape[0],
        r'.*/head_action/.*', r'.*/head_action_target/.*',
        config.action_target_period, config.action_target_update))

  # Apply and optimize model.
  embedded = encoder(data)
  with tf.control_dependencies(dependencies):
    embedded = tf.identity(embedded)
  # Capture all locals defined so far by name for downstream consumers.
  graph = tools.AttrDict(locals())
  prior, posterior = tools.unroll.closed_loop(
      cell, embedded, data['action'], config.debug)
  objectives = utility.compute_objectives(
      posterior, prior, data, graph, config)
  summaries, grad_norms = utility.apply_optimizers(
      objectives, trainer, config)
  dependencies += summaries

  # Active data collection.
  with tf.variable_scope('collection'):
    with tf.control_dependencies(dependencies):  # Make sure to train first.
      for name, params in config.train_collects.items():
        schedule = tools.schedule.binary(
            step, config.batch_shape[0],
            params.steps_after, params.steps_every, params.steps_until)
        summary, _ = tf.cond(
            tf.logical_and(tf.equal(trainer.phase, 'train'), schedule),
            functools.partial(
                utility.simulate, metrics, config, params, graph, cleanups,
                gif_summary=False, name=name),
            lambda: (tf.constant(''), tf.constant(0.0)),
            name='should_collect_' + name)
        summaries.append(summary)
        dependencies.append(summary)

  # Compute summaries.
  # Re-capture locals so summaries see objectives, posterior, etc.
  graph = tools.AttrDict(locals())
  summary, score = tf.cond(
      trainer.log,
      lambda: define_summaries.define_summaries(graph, config, cleanups),
      lambda: (tf.constant(''), tf.zeros((0,), tf.float32)),
      name='summaries')
  summaries = tf.summary.merge([summaries, summary])
  dependencies.append(utility.print_metrics(
      {ob.name: ob.value for ob in objectives},
      step, config.print_metrics_every, 2, 'objectives'))
  dependencies.append(utility.print_metrics(
      grad_norms, step, config.print_metrics_every, 2, 'grad_norms'))
  dependencies.append(tf.cond(trainer.log, metrics.flush, tf.no_op))
  with tf.control_dependencies(dependencies):
    score = tf.identity(score)
  return score, summaries, cleanups
133 |
--------------------------------------------------------------------------------
/dreamer/training/define_summaries.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 | from tensorflow_probability import distributions as tfd
21 |
22 | from dreamer import tools
23 | from dreamer.training import utility
24 | from dreamer.tools import summary
25 |
26 |
def define_summaries(graph, config, cleanups):
  """Build the summary sub-graph evaluated on logging steps.

  Produces general statistics, closed-loop and open-loop model summaries,
  and (in the test phase only) full environment simulations.

  Args:
    graph: AttrDict of tensors and modules captured in define_model; local
      names there (cell, embedded, data, heads, ...) are accessed by key.
    config: AttrDict of configuration options.
    cleanups: List that simulation cleanup callables are appended to.

  Returns:
    Tuple (summaries, score): merged summary string tensor and the mean
    simulation return as a length-1 vector.
  """
  summaries = []
  # plot_summaries = []  # Control dependencies for non thread-safe matplot.
  heads = graph.heads.copy(_unlocked=True)
  # Persistent state to compute steps/second between consecutive log steps.
  last_time = tf.Variable(lambda: tf.timestamp(), trainable=False)
  last_step = tf.Variable(lambda: 0.0, trainable=False, dtype=tf.float64)

  def transform(dist):
    # Replace the image distribution by a unit Normal around the
    # post-processed, clipped mean for display-friendly statistics.
    mean = config.postprocess_fn(dist.mean())
    mean = tf.clip_by_value(mean, 0.0, 1.0)
    return tfd.Independent(tfd.Normal(mean, 1.0), len(dist.event_shape))
  heads['image'] = lambda features: transform(graph.heads['image'](features))
  heads.pop('cpc', None)  # Not applied to RNN states.

  with tf.variable_scope('general'):
    summaries += summary.data_summaries(graph.data, config.postprocess_fn)
    summaries += summary.dataset_summaries(config.train_dir, 'train_dataset')
    summaries += summary.dataset_summaries(config.test_dir, 'test_dataset')
    summaries += summary.objective_summaries(graph.objectives)
    summaries.append(tf.summary.scalar('step', graph.step))
    new_time, new_step = tf.timestamp(), tf.cast(graph.global_step, tf.float64)
    delta_time, delta_step = new_time - last_time, new_step - last_step
    # Order matters: compute the deltas before updating the stored values.
    with tf.control_dependencies([delta_time, delta_step]):
      assign_ops = [last_time.assign(new_time), last_step.assign(new_step)]
      with tf.control_dependencies(assign_ops):
        summaries.append(tf.summary.scalar(
            'steps_per_second', delta_step / delta_time))
        summaries.append(tf.summary.scalar(
            'seconds_per_step', delta_time / delta_step))

  with tf.variable_scope('closedloop'):
    prior, posterior = tools.unroll.closed_loop(
        graph.cell, graph.embedded, graph.data['action'], config.debug)
    summaries += summary.state_summaries(graph.cell, prior, posterior)
    # with tf.variable_scope('prior'):
    #   prior_features = graph.cell.features_from_state(prior)
    #   prior_dists = {
    #       name: head(prior_features)
    #       for name, head in heads.items()}
    #   summaries += summary.dist_summaries(prior_dists, graph.data)
    #   summaries += summary.image_summaries(
    #       prior_dists['image'], config.postprocess_fn(graph.data['image']))
    # with tf.variable_scope('posterior'):
    #   posterior_features = graph.cell.features_from_state(posterior)
    #   posterior_dists = {
    #       name: head(posterior_features)
    #       for name, head in heads.items()}
    #   summaries += summary.dist_summaries(
    #       posterior_dists, graph.data)
    #   summaries += summary.image_summaries(
    #       posterior_dists['image'],
    #       config.postprocess_fn(graph.data['image']))

  with tf.variable_scope('openloop'):
    state = tools.unroll.open_loop(
        graph.cell, graph.embedded, graph.data['action'],
        config.open_loop_context, config.debug)
    state_features = graph.cell.features_from_state(state)
    state_dists = {name: head(state_features) for name, head in heads.items()}
    if 'cpc' in graph.heads:
      state_dists['cpc'] = graph.heads.cpc(graph.embedded)
    summaries += summary.dist_summaries(state_dists, graph.data)
    summaries += summary.image_summaries(
        state_dists['image'], config.postprocess_fn(graph.data['image']))
    summaries += summary.state_summaries(graph.cell, state, posterior)
    # with tf.control_dependencies(plot_summaries):
    #   plot_summary = summary.prediction_summaries(
    #       state_dists, graph.data, state)
    #   plot_summaries += plot_summary
    #   summaries += plot_summary

  with tf.variable_scope('simulation'):
    sim_returns = []
    for name, params in config.test_collects.items():
      # These are expensive and equivalent for train and test phases, so only
      # do one of them.
      # The lambdas close over loop variables, but tf.cond invokes them
      # immediately during graph construction, so late binding is safe here.
      sim_summary, sim_return = tf.cond(
          tf.equal(graph.phase, 'test'),
          lambda: utility.simulate(
              graph.metrics, config, params, graph, cleanups,
              gif_summary=True, name=name),
          lambda: ('', 0.0),
          name='should_simulate_' + params.task.name)
      summaries.append(sim_summary)
      sim_returns.append(sim_return)

  summaries = tf.summary.merge(summaries)
  score = tf.reduce_mean(sim_returns)[None]
  return summaries, score
116 |
--------------------------------------------------------------------------------
/dreamer/training/running.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import datetime
20 | import itertools
21 | import os
22 | import pickle
23 | import six
24 | import sys
25 | import threading
26 | import time
27 | import traceback
28 | import uuid
29 |
30 | import numpy as np
31 | import tensorflow as tf
32 |
33 |
class StopExperiment(Exception):
  """Signals that no further runs should be processed; stops iteration."""
  pass
36 |
37 |
class WorkerConflict(Exception):
  """Signals that another worker owns or has taken over the current run."""
  pass
40 |
41 |
class SkipRun(Exception):
  """Signals that this run should be skipped and the next one attempted."""
  pass
44 |
45 |
class Experiment(object):
  """Iterates over the runs of an experiment, one `Run` at a time.

  Each run owns a zero-padded numbered log directory under `basedir`.
  Multiple workers can process the same experiment concurrently; they
  coordinate via the ping files managed by `Run`.
  """

  def __init__(
      self, basedir, process_fn, start_fn=None, resume_fn=None,
      num_runs=None, worker_name=None, ping_every=30, resume_runs=True):
    self._basedir = basedir
    self._process_fn = process_fn
    self._start_fn = start_fn
    self._resume_fn = resume_fn
    self._num_runs = num_runs
    self._worker_name = worker_name or str(uuid.uuid4())
    self._ping_every = ping_every
    # A run counts as stale once its ping is older than two intervals.
    self._ping_stale = ping_every and 2 * ping_every
    self._resume_runs = resume_runs

  def __iter__(self):
    for current_run in self._generate_run_numbers():
      logdir = self._basedir and os.path.join(
          self._basedir, '{:05}'.format(current_run))
      try:
        yield Run(
            logdir, self._process_fn, self._start_fn, self._resume_fn,
            self._worker_name, self._ping_every, self._ping_stale,
            self._resume_runs)
      except SkipRun:
        continue
      except StopExperiment:
        print('Stopping.')
        break
    print('All runs completed.')

  def _generate_run_numbers(self):
    """Yield 1-based run numbers in the order this worker should try them."""
    if self._num_runs:
      # Don't wait initially and see if there are runs that are already stale.
      order = np.random.permutation(range(self._num_runs))
      for number in order:
        yield number + 1
      # At the end, wait for all dead runs to become stale, and pick them up.
      # This is necessary for complete runs of workers that died very
      # recently.
      if self._ping_stale:
        time.sleep(self._ping_stale)
      for number in order:
        yield number + 1
    else:
      # For infinite runs, we want to always finish started jobs first, so
      # wait for them to become stale before counting from the beginning.
      if self._ping_stale:
        time.sleep(self._ping_stale)
      for number in itertools.count():
        yield number + 1
97 |
98 |
class Run(object):
  """One numbered experiment run, claimed and processed by a single worker.

  Coordination happens through files in the run's log directory: PING holds
  the owning worker's name and last heartbeat time, DONE marks completion,
  and FAIL stores a traceback. A daemon thread refreshes the PING file while
  the run is being processed.
  """

  def __init__(
      self, logdir, process_fn, start_fn, resume_fn, worker_name,
      ping_every=30, ping_stale=60, reuse_if_exists=True):
    """Claim the run for this worker or raise SkipRun.

    Args:
      logdir: Directory for coordination files; falsy disables persistence.
      process_fn: Generator function called as process_fn(logdir, *args).
      start_fn: Optional callable invoked when the run starts fresh.
      resume_fn: Optional callable invoked when an existing run is resumed.
      worker_name: Unique name identifying this worker.
      ping_every: Heartbeat interval in seconds; falsy disables the thread.
      ping_stale: Age in seconds after which a ping counts as stale.
      reuse_if_exists: Whether stale existing runs may be resumed.

    Raises:
      SkipRun: If the run is owned by another worker or already done.
    """
    # The logdir may be falsy when runs are not persisted; guard it because
    # os.path.expanduser raises a TypeError on None.
    self._logdir = logdir and os.path.expanduser(logdir)
    self._process_fn = process_fn
    self._worker_name = worker_name
    self._ping_every = ping_every
    self._ping_stale = ping_stale
    self._logger = self._create_logger()
    try:
      if self._should_start():
        self._claim()
        self._logger.info('Start.')
        self._init_fn = start_fn
      elif reuse_if_exists and self._should_resume():
        self._claim()
        self._logger.info('Resume.')
        self._init_fn = resume_fn
      else:
        raise SkipRun
    except WorkerConflict:
      self._logger.info('Leave to other worker.')
      raise SkipRun
    self._thread = None
    self._running = [True]
    self._thread = threading.Thread(target=self._store_ping_thread)
    self._thread.daemon = True  # Terminate with main thread.
    self._thread.start()

  def __iter__(self):
    """Yield the values produced by process_fn, handling run bookkeeping."""
    try:
      args = self._init_fn and self._init_fn(self._logdir)
      if args is None:
        args = ()
      if not isinstance(args, tuple):
        args = (args,)
      for value in self._process_fn(self._logdir, *args):
        if not self._running[0]:
          break
        yield value
      self._logger.info('Done.')
      self._store_done()
    except WorkerConflict:
      # Was `self._logging.warn(...)`, which raised an AttributeError since
      # no `_logging` attribute exists.
      self._logger.warning('Unexpected takeover.')
      raise SkipRun
    except Exception as e:
      exc_info = sys.exc_info()
      self._handle_exception(e)
      six.reraise(*exc_info)
    finally:
      self._running[0] = False
      self._thread and self._thread.join()

  def _should_start(self):
    """Return True if this run has never been started before."""
    if not self._logdir:
      return True
    if tf.gfile.Exists(os.path.join(self._logdir, 'PING')):
      return False
    if tf.gfile.Exists(os.path.join(self._logdir, 'DONE')):
      return False
    return True

  def _should_resume(self):
    """Return True if this run was started, is unfinished, and is stale."""
    if not self._logdir:
      return False
    if tf.gfile.Exists(os.path.join(self._logdir, 'DONE')):
      # Already done.
      return False
    if not tf.gfile.Exists(os.path.join(self._logdir, 'PING')):
      # Not started yet.
      return False
    last_worker, last_ping = self._read_ping()
    if last_worker != self._worker_name and last_ping < self._ping_stale:
      # Already in progress by another worker.
      return False
    return True

  def _claim(self):
    """Write our ping, wait one interval, and verify we still own the run.

    Raises:
      WorkerConflict: If another worker overwrote the ping in the meantime.
    """
    if not self._logdir:
      return False
    self._store_ping(overwrite=True)
    if self._ping_every:
      time.sleep(self._ping_every)
    if self._read_ping()[0] != self._worker_name:
      raise WorkerConflict
    self._store_ping()

  def _store_done(self):
    """Create the DONE marker file."""
    if not self._logdir:
      return
    with tf.gfile.Open(os.path.join(self._logdir, 'DONE'), 'w') as file_:
      file_.write('\n')

  def _store_fail(self, message):
    """Store the failure message in the FAIL marker file."""
    if not self._logdir:
      return
    with tf.gfile.Open(os.path.join(self._logdir, 'FAIL'), 'w') as file_:
      file_.write(message + '\n')

  def _read_ping(self):
    """Return (worker_name, seconds_since_ping) from the PING file.

    Raises:
      WorkerConflict: If the file cannot be read, e.g. mid-overwrite.
    """
    if not tf.gfile.Exists(os.path.join(self._logdir, 'PING')):
      return None, None
    try:
      with tf.gfile.Open(os.path.join(self._logdir, 'PING'), 'rb') as file_:
        last_worker, last_ping = pickle.load(file_)
      duration = (datetime.datetime.utcnow() - last_ping).total_seconds()
      return last_worker, duration
    except (EOFError, IOError, tf.errors.NotFoundError):
      raise WorkerConflict

  def _store_ping(self, overwrite=False):
    """Refresh the PING file with our name and the current time.

    Raises:
      WorkerConflict: If another worker owns the ping and overwrite is off.
    """
    if not self._logdir:
      return
    try:
      last_worker, _ = self._read_ping()
      if last_worker is None:
        self._logger.info("Create directory '{}'.".format(self._logdir))
        tf.gfile.MakeDirs(self._logdir)
      elif last_worker != self._worker_name and not overwrite:
        raise WorkerConflict
      with tf.gfile.Open(os.path.join(self._logdir, 'PING'), 'wb') as file_:
        pickle.dump((self._worker_name, datetime.datetime.utcnow()), file_)
    except (EOFError, IOError, tf.errors.NotFoundError):
      raise WorkerConflict

  def _store_ping_thread(self):
    """Periodically refresh the PING file until the run stops."""
    if not self._ping_every:
      return
    try:
      last_write = time.time()
      # Was `self._store_ping(self._logdir)`, which passed the logdir string
      # as the truthy `overwrite` flag and silently disabled takeover
      # detection; the except clause below expects WorkerConflict.
      self._store_ping()
      while self._running[0]:
        if time.time() >= last_write + self._ping_every:
          last_write = time.time()
          self._store_ping()
        # Only wait short times to quickly react to abort.
        time.sleep(0.01)
    except WorkerConflict:
      self._running[0] = False

  def _handle_exception(self, exception):
    """Log the current exception and mark the run as done and failed."""
    message = ''.join(traceback.format_exception(*sys.exc_info()))
    self._logger.warning('Exception:\n{}'.format(message))
    self._logger.warning('Failed.')
    try:
      self._store_done()
      self._store_fail(message)
    except Exception:
      message = ''.join(traceback.format_exception(*sys.exc_info()))
      template = 'Exception in exception handler:\n{}'
      self._logger.warning(template.format(message))

  def _create_logger(self):
    """Return a print-based logger prefixed with worker and run names."""
    run_name = self._logdir and os.path.basename(self._logdir)
    methods = {}
    for name in 'debug info warning'.split():
      methods[name] = lambda unused_self, message: print(
          'Worker {} run {}: {}'.format(self._worker_name, run_name, message))
    return type('PrefixedLogger', (object,), methods)()
261 |
--------------------------------------------------------------------------------
/dreamer/training/trainer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import collections
20 | import os
21 |
22 | import tensorflow as tf
23 |
24 | from dreamer import tools
25 |
26 |
# Static description of one training/evaluation phase, as registered through
# Trainer.add_phase() and consumed by the step loop in Trainer.iterate().
_Phase = collections.namedtuple(
    'Phase',
    'name, writer, op, batch_size, steps, feed, report_every, log_every,'
    'checkpoint_every, restore_every')
31 |
32 |
33 | class Trainer(object):
34 |
  def __init__(self, logdir, config=None):
    """Create a trainer that multiplexes phases over one shared graph.

    Args:
      logdir: Directory for summary writers and checkpoints; may be falsy.
      config: Optional AttrDict with trainer options (e.g.
        checkpoint_to_load).
    """
    self._logdir = logdir
    self._global_step = tf.train.get_or_create_global_step()
    # Placeholders fed on every session step to parameterize the graph.
    self._step = tf.placeholder(tf.int32, name='step')
    self._phase = tf.placeholder(tf.string, name='phase')
    self._log = tf.placeholder(tf.bool, name='log')
    self._report = tf.placeholder(tf.bool, name='report')
    self._reset = tf.placeholder(tf.bool, name='reset')
    self._phases = []
    # Checkpointing.
    self._loaders = []
    self._savers = []
    self._logdirs = []
    self._checkpoints = []
    self._config = config or tools.AttrDict()
50 |
  @property
  def global_step(self):
    """Global step variable, shared across all phases."""
    return self._global_step
54 |
  @property
  def step(self):
    """Placeholder fed with the current phase step (int32)."""
    return self._step
58 |
  @property
  def phase(self):
    """Placeholder fed with the current phase name (string)."""
    return self._phase
62 |
  @property
  def log(self):
    """Placeholder fed with whether this step should write summaries."""
    return self._log
66 |
  @property
  def reset(self):
    """Placeholder fed with whether state should be reset this step."""
    return self._reset
70 |
71 | def add_saver(
72 | self, include=r'.*', exclude=r'.^', logdir=None, load=True, save=True,
73 | checkpoint=None):
74 | variables = tools.filter_variables(include, exclude)
75 | saver = tf.train.Saver(variables, max_to_keep=1)
76 | if load:
77 | self._loaders.append(saver)
78 | if save:
79 | self._savers.append(saver)
80 | self._logdirs.append(logdir or self._logdir)
81 | if checkpoint is None and self._config.checkpoint_to_load:
82 | self._checkpoints.append(
83 | os.path.join(self._logdirs[-1], self._config.checkpoint_to_load))
84 | else:
85 | self._checkpoints.append(checkpoint)
86 |
87 | def add_phase(
88 | self, name, steps, score, summary, batch_size=1,
89 | report_every=None, log_every=None, checkpoint_every=None,
90 | restore_every=None, feed=None):
91 | score = tf.convert_to_tensor(score, tf.float32)
92 | summary = tf.convert_to_tensor(summary, tf.string)
93 | feed = feed or {}
94 | if not score.shape.ndims:
95 | score = score[None]
96 | writer = self._logdir and tf.summary.FileWriter(
97 | os.path.join(self._logdir, name),
98 | tf.get_default_graph(), flush_secs=30)
99 | op = self._define_step(name, batch_size, score, summary)
100 | self._phases.append(_Phase(
101 | name, writer, op, batch_size, int(steps), feed, report_every,
102 | log_every, checkpoint_every, restore_every))
103 |
104 | def run(self, max_step=None, sess=None, unused_saver=None):
105 | for _ in self.iterate(max_step, sess):
106 | pass
107 |
108 | def iterate(self, max_step=None, sess=None):
109 | sess = sess or self._create_session()
110 | with sess:
111 | self._initialize_variables(
112 | sess, self._loaders, self._logdirs, self._checkpoints)
113 | sess.graph.finalize()
114 | while True:
115 | global_step = sess.run(self._global_step)
116 | if max_step and global_step >= max_step:
117 | break
118 | phase, epoch, steps_in = self._find_current_phase(global_step)
119 | phase_step = epoch * phase.steps + steps_in
120 | if steps_in % phase.steps < phase.batch_size:
121 | message = '\n' + ('-' * 50) + '\n'
122 | message += 'Epoch {} phase {} (phase step {}, global step {}).'
123 | print(message.format(epoch + 1, phase.name, phase_step, global_step))
124 | # Populate book keeping tensors.
125 | phase.feed[self._step] = phase_step
126 | phase.feed[self._phase] = phase.name
127 | phase.feed[self._reset] = (steps_in < phase.batch_size)
128 | phase.feed[self._log] = phase.writer and self._is_every_steps(
129 | phase_step, phase.batch_size, phase.log_every)
130 | phase.feed[self._report] = self._is_every_steps(
131 | phase_step, phase.batch_size, phase.report_every)
132 | summary, mean_score, global_step = sess.run(phase.op, phase.feed)
133 | if self._is_every_steps(
134 | phase_step, phase.batch_size, phase.checkpoint_every):
135 | for saver in self._savers:
136 | self._store_checkpoint(sess, saver, global_step)
137 | if self._is_every_steps(
138 | phase_step, phase.batch_size, phase.report_every):
139 | print('Score {}.'.format(mean_score))
140 | yield mean_score
141 | if summary and phase.writer:
142 | # We want smaller phases to catch up at the beginnig of each epoch so
143 | # that their graphs are aligned.
144 | longest_phase = max(phase_.steps for phase_ in self._phases)
145 | summary_step = epoch * longest_phase + steps_in
146 | phase.writer.add_summary(summary, summary_step)
147 | if self._is_every_steps(
148 | phase_step, phase.batch_size, phase.restore_every):
149 | self._initialize_variables(
150 | sess, self._loaders, self._logdirs, self._checkpoints)
151 |
152 | def _is_every_steps(self, phase_step, batch, every):
153 | if not every:
154 | return False
155 | covered_steps = range(phase_step, phase_step + batch)
156 | return any((step + 1) % every == 0 for step in covered_steps)
157 |
158 | def _find_current_phase(self, global_step):
159 | epoch_size = sum(phase.steps for phase in self._phases)
160 | epoch = int(global_step // epoch_size)
161 | steps_in = global_step % epoch_size
162 | for phase in self._phases:
163 | if steps_in < phase.steps:
164 | return phase, epoch, steps_in
165 | steps_in -= phase.steps
166 |
167 | def _define_step(self, name, batch_size, score, summary):
168 | with tf.variable_scope('phase_{}'.format(name)):
169 | score_mean = tools.StreamingMean((), tf.float32, 'score_mean')
170 | score.set_shape((None,))
171 | with tf.control_dependencies([score, summary]):
172 | submit_score = score_mean.submit(score)
173 | with tf.control_dependencies([submit_score]):
174 | mean_score = tf.cond(self._report, score_mean.clear, float)
175 | summary = tf.cond(
176 | self._report,
177 | lambda: tf.summary.merge([summary, tf.summary.scalar(
178 | name + '/score', mean_score, family='trainer')]),
179 | lambda: summary)
180 | next_step = self._global_step.assign_add(batch_size)
181 | with tf.control_dependencies([summary, mean_score, next_step]):
182 | return (
183 | tf.identity(summary),
184 | tf.identity(mean_score),
185 | tf.identity(next_step))
186 |
187 | def _create_session(self):
188 | config = tf.ConfigProto()
189 | config.gpu_options.allow_growth = True
190 | try:
191 | return tf.Session('local', config=config)
192 | except tf.errors.NotFoundError:
193 | return tf.Session(config=config)
194 |
195 | def _initialize_variables(self, sess, savers, logdirs, checkpoints):
196 | sess.run(tf.group(
197 | tf.local_variables_initializer(),
198 | tf.global_variables_initializer()))
199 | assert len(savers) == len(logdirs) == len(checkpoints)
200 | for i, (saver, logdir, checkpoint) in enumerate(
201 | zip(savers, logdirs, checkpoints)):
202 | logdir = os.path.expanduser(logdir)
203 | state = tf.train.get_checkpoint_state(logdir)
204 | if checkpoint:
205 | checkpoint = os.path.join(logdir, checkpoint)
206 | if not checkpoint and state and state.model_checkpoint_path:
207 | checkpoint = state.model_checkpoint_path
208 | if checkpoint:
209 | saver.restore(sess, checkpoint)
210 |
211 | def _store_checkpoint(self, sess, saver, global_step):
212 | if not self._logdir or not saver:
213 | return
214 | tf.gfile.MakeDirs(self._logdir)
215 | filename = os.path.join(self._logdir, 'model.ckpt')
216 | saver.save(sess, filename, global_step)
217 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import setuptools
20 |
21 |
# Package definition for the dreamer research code.
setuptools.setup(
    name='dreamer',
    version='1.0.0',
    description=(
        'Dream to Explore: Learning Behaviors by Latent Imagination'),
    license='Apache 2.0',
    url='http://github.com/google-research/dreamer',
    install_requires=[
        'dm_control',
        'gym',
        'matplotlib',
        'ruamel.yaml',
        'scikit-image',
        'scipy',
        # Pinned versions — the code uses the TF1 graph API (placeholders,
        # sessions), so newer TensorFlow releases are presumably incompatible.
        'tensorflow-gpu==1.13.1',
        'tensorflow_probability==0.6.0',
    ],
    packages=setuptools.find_packages(),
    classifiers=[
        'Programming Language :: Python :: 3',
        'License :: OSI Approved :: Apache Software License',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'Intended Audience :: Science/Research',
    ],
)
47 |
--------------------------------------------------------------------------------
/tests/test_dreamer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import pathlib
20 | import sys
21 |
22 | sys.path.append(str(pathlib.Path(__file__).parent.parent))
23 |
24 | import tensorflow as tf
25 | from dreamer import tools
26 | from dreamer.scripts import train
27 |
28 |
class DreamerTest(tf.test.TestCase):
  """Smoke tests that train each agent variant for a few steps.

  Every test builds a small debug configuration on the dummy task and runs
  the full training entry point once, checking only that it completes.
  """

  def _run_training(self, defaults, **params):
    """Run train.main once with the shared debug settings plus overrides.

    Args:
      defaults: List of named default configs, e.g. ['dreamer', 'debug'].
      **params: Additional config overrides merged into the params AttrDict.
    """
    args = tools.AttrDict(
        logdir=self.get_temp_dir(),
        num_runs=1,
        params=tools.AttrDict(
            defaults=defaults,
            tasks=['dummy'],
            isolate_envs='none',
            max_steps=30,
            **params),
        ping_every=0,
        resume_runs=False)
    train.main(args)

  def test_dreamer(self):
    self._run_training(
        ['dreamer', 'debug'],
        train_planner='policy_sample',
        test_planner='policy_mode',
        planner_objective='reward_value',
        action_head=True,
        value_head=True,
        imagination_horizon=3)

  def test_dreamer_discrete(self):
    self._run_training(
        ['dreamer', 'debug'],
        train_planner='policy_sample',
        test_planner='policy_mode',
        planner_objective='reward_value',
        action_head=True,
        value_head=True,
        imagination_horizon=3,
        action_head_dist='onehot_score',
        action_noise_type='epsilon_greedy')

  def test_dreamer_target(self):
    self._run_training(
        ['dreamer', 'debug'],
        train_planner='policy_sample',
        test_planner='policy_mode',
        planner_objective='reward_value',
        action_head=True,
        value_head=True,
        value_target_head=True,
        imagination_horizon=3)

  def test_no_value(self):
    self._run_training(['actor', 'debug'], imagination_horizon=3)

  def test_planet(self):
    self._run_training(['planet', 'debug'], planner_horizon=3)
118 |
119 |
# Allow running this test file directly as a script.
if __name__ == '__main__':
  tf.test.main()
122 |
--------------------------------------------------------------------------------
/tests/test_isolate_envs.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import pathlib
20 | import sys
21 |
22 | sys.path.append(str(pathlib.Path(__file__).parent.parent))
23 |
24 | import tensorflow as tf
25 | from dreamer import tools
26 | from dreamer.scripts import train
27 |
28 |
class IsolateEnvsTest(tf.test.TestCase):
  """Smoke tests for running environments in isolated worker threads."""

  def _run_training(self, tasks, **params):
    """Run train.main once with thread isolation and the given task.

    Args:
      tasks: List of task names to train on.
      **params: Additional config overrides merged into the params AttrDict.
    """
    args = tools.AttrDict(
        logdir=self.get_temp_dir(),
        num_runs=1,
        params=tools.AttrDict(
            defaults=['dreamer', 'debug'],
            tasks=tasks,
            isolate_envs='thread',
            max_steps=30,
            **params),
        ping_every=0,
        resume_runs=False)
    train.main(args)

  def test_dummy_thread(self):
    self._run_training(['dummy'])

  def test_dm_control_thread(self):
    self._run_training(['cup_catch'])

  def test_atari_thread(self):
    self._run_training(
        ['atari_pong'],
        action_head_dist='onehot_score',
        action_noise_type='epsilon_greedy')

  # Disabled: requires a DeepMind Lab installation.
  # def test_dmlab_thread(self):
  #   args = tools.AttrDict(
  #       logdir=self.get_temp_dir(),
  #       num_runs=1,
  #       params=tools.AttrDict(
  #           defaults=['dreamer', 'debug'],
  #           tasks=['dmlab_collect'],
  #           isolate_envs='thread',
  #           action_head_dist='onehot_score',
  #           action_noise_type='epsilon_greedy',
  #           max_steps=30),
  #       ping_every=0,
  #       resume_runs=False)
  #   train.main(args)
86 |
87 |
# Allow running this test file directly as a script.
if __name__ == '__main__':
  tf.test.main()
90 |
--------------------------------------------------------------------------------
/tests/test_nested.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import collections
20 | import pathlib
21 | import sys
22 |
23 | sys.path.append(str(pathlib.Path(__file__).parent.parent))
24 |
25 | import tensorflow as tf
26 | from dreamer.tools import nested
27 |
28 |
class ZipTest(tf.test.TestCase):
  """Tests for nested.zip, which zips leaves across parallel structures."""

  def test_scalar(self):
    self.assertEqual(42, nested.zip(42))
    self.assertEqual((13, 42), nested.zip(13, 42))

  def test_empty(self):
    self.assertEqual({}, nested.zip({}, {}))

  def test_base_case(self):
    self.assertEqual((1, 2, 3), nested.zip(1, 2, 3))

  def test_shallow_list(self):
    first, second, third = [1, 2, 3], [4, 5, 6], [7, 8, 9]
    zipped = nested.zip(first, second, third)
    self.assertEqual([(1, 4, 7), (2, 5, 8), (3, 6, 9)], zipped)

  def test_shallow_tuple(self):
    first, second, third = (1, 2, 3), (4, 5, 6), (7, 8, 9)
    zipped = nested.zip(first, second, third)
    self.assertEqual(((1, 4, 7), (2, 5, 8), (3, 6, 9)), zipped)

  def test_shallow_dict(self):
    first = {'a': 1, 'b': 2, 'c': 3}
    second = {'a': 4, 'b': 5, 'c': 6}
    third = {'a': 7, 'b': 8, 'c': 9}
    zipped = nested.zip(first, second, third)
    self.assertEqual({'a': (1, 4, 7), 'b': (2, 5, 8), 'c': (3, 6, 9)}, zipped)

  def test_single(self):
    structure = [[1, 2], 3]
    self.assertEqual(structure, nested.zip(structure))

  def test_mixed_structures(self):
    left = [(1, 2), 3, {'foo': [4]}]
    right = [(5, 6), 7, {'foo': [8]}]
    zipped = nested.zip(left, right)
    self.assertEqual([((1, 5), (2, 6)), (3, 7), {'foo': [(4, 8)]}], zipped)

  def test_different_types(self):
    numbers = [1, 2, 3]
    letters = 'a b c'.split()
    zipped = nested.zip(numbers, letters)
    self.assertEqual([(1, 'a'), (2, 'b'), (3, 'c')], zipped)

  def test_use_type_of_first(self):
    # The container type of the first argument determines the output type.
    zipped = nested.zip((1, 2, 3), [4, 5, 6], [7, 8, 9])
    self.assertEqual(((1, 4, 7), (2, 5, 8), (3, 6, 9)), zipped)

  def test_namedtuple(self):
    Foo = collections.namedtuple('Foo', 'value')
    foo, bar = Foo(42), Foo(13)
    self.assertEqual(Foo((42, 13)), nested.zip(foo, bar))
90 |
91 |
class MapTest(tf.test.TestCase):
  """Tests for nested.map, which applies a function to every leaf."""

  def test_scalar(self):
    self.assertEqual(42, nested.map(lambda value: value, 42))

  def test_empty(self):
    self.assertEqual({}, nested.map(lambda value: value, {}))

  def test_shallow_list(self):
    self.assertEqual(
        [2, 4, 6], nested.map(lambda value: 2 * value, [1, 2, 3]))

  def test_shallow_dict(self):
    data = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
    self.assertEqual(data, nested.map(lambda value: value, data))

  def test_mixed_structure(self):
    structure = [(1, 2), 3, {'foo': [4]}]
    doubled = nested.map(lambda value: 2 * value, structure)
    self.assertEqual([(2, 4), 6, {'foo': [8]}], doubled)

  def test_mixed_types(self):
    self.assertEqual(
        [14, 'foofoo'], nested.map(lambda value: value * 2, [7, 'foo']))

  def test_multiple_lists(self):
    first, second, third = [1, 2, 3], [4, 5, 6], [7, 8, 9]
    summed = nested.map(lambda x, y, z: x + y + z, first, second, third)
    self.assertEqual([12, 15, 18], summed)

  def test_namedtuple(self):
    Foo = collections.namedtuple('Foo', 'value')
    left, right = [Foo(42)], [Foo(13)]
    swapped = nested.map(lambda x, y: (y, x), left, right)
    self.assertEqual([Foo((13, 42))], swapped)
    summed = nested.map(lambda x, y: x + y, left, right)
    self.assertEqual([Foo(55)], summed)
129 |
130 |
class FlattenTest(tf.test.TestCase):
  """Tests for nested.flatten, which collects leaves into a flat tuple."""

  def test_scalar(self):
    self.assertEqual((42,), nested.flatten(42))

  def test_empty(self):
    self.assertEqual((), nested.flatten({}))

  def test_base_case(self):
    self.assertEqual((1,), nested.flatten(1))

  def test_convert_type(self):
    # Lists come back as a tuple.
    self.assertEqual((1, 2, 3), nested.flatten([1, 2, 3]))

  def test_mixed_structure(self):
    structure = [(1, 2), 3, {'foo': [4]}]
    self.assertEqual((1, 2, 3, 4), nested.flatten(structure))

  def test_value_ordering(self):
    self.assertEqual((1, 2, 3), nested.flatten({'a': 1, 'b': 2, 'c': 3}))
150 |
151 |
class FilterTest(tf.test.TestCase):
  """Tests for nested.filter, which keeps leaves matching a predicate."""

  def test_empty(self):
    self.assertEqual({}, nested.filter(lambda value: True, {}))
    self.assertEqual({}, nested.filter(lambda value: False, {}))

  def test_base_case(self):
    self.assertEqual((), nested.filter(lambda value: False, 1))

  def test_single_dict(self):
    is_even = lambda value: value % 2 == 0
    data = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
    self.assertEqual({'b': 2, 'd': 4}, nested.filter(is_even, data))

  def test_multiple_lists(self):
    first, second, third = [1, 2, 3], [4, 5, 6], [7, 8, 9]
    any_multiple_of_four = lambda *values: any(x % 4 == 0 for x in values)
    kept = nested.filter(any_multiple_of_four, first, second, third)
    self.assertEqual([(1, 4, 7), (2, 5, 8)], kept)

  def test_multiple_dicts(self):
    first = {'a': 1, 'b': 2, 'c': 3}
    second = {'a': 4, 'b': 5, 'c': 6}
    third = {'a': 7, 'b': 8, 'c': 9}
    any_multiple_of_four = lambda *values: any(x % 4 == 0 for x in values)
    kept = nested.filter(any_multiple_of_four, first, second, third)
    self.assertEqual({'a': (1, 4, 7), 'b': (2, 5, 8)}, kept)

  def test_mixed_structure(self):
    is_even = lambda value: value % 2 == 0
    data = [(1, 2), 3, {'foo': [4]}]
    self.assertEqual([(2,), {'foo': [4]}], nested.filter(is_even, data))

  def test_remove_empty_containers(self):
    data = [(1, 2, 3), 4, {'foo': [5, 6], 'bar': 7}]
    self.assertEqual([], nested.filter(lambda value: False, data))

  def test_namedtuple(self):
    Foo = collections.namedtuple('Foo', 'value1, value2')
    self.assertEqual(
        Foo(1, None), nested.filter(lambda value: value == 1, Foo(1, 2)))

  def test_namedtuple_multiple(self):
    Foo = collections.namedtuple('Foo', 'value1, value2')
    kept = nested.filter(lambda x, y: x + y > 3, Foo(1, 2), Foo(2, 3))
    self.assertEqual(Foo(None, (2, 3)), kept)

  def test_namedtuple_nested(self):
    Foo = collections.namedtuple('Foo', 'value1, value2')
    foo = Foo(1, [1, 2, 3])
    self.assertEqual(
        Foo(None, [2, 3]), nested.filter(lambda value: value > 1, foo))
206 |
--------------------------------------------------------------------------------
/tests/test_overshooting.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import pathlib
20 | import sys
21 |
22 | sys.path.append(str(pathlib.Path(__file__).parent.parent))
23 |
24 | import tensorflow as tf
25 | from dreamer import models
26 | from dreamer.tools import overshooting
27 |
28 |
class _MockCell(models.Base):
  """Deterministic latent cell whose state is just a single 'obs' entry.

  The prior adds the previous action to the state and the posterior copies
  the observation, so tests can compute expected values by hand.
  """

  def __init__(self, obs_size):
    self._obs_size = obs_size
    super(_MockCell, self).__init__(
        tf.make_template('transition', self._transition),
        tf.make_template('posterior', self._posterior))

  @property
  def state_size(self):
    """Shape description of the state dictionary."""
    return {'obs': self._obs_size}

  def _transition(self, prev_state, prev_action, zero_obs):
    # Accept nested actions by using only their first component.
    action = (
        prev_action[0] if isinstance(prev_action, (tuple, list))
        else prev_action)
    return {'obs': prev_state['obs'] + action}

  def _posterior(self, prev_state, prev_action, obs):
    # Accept nested observations by using only their first component.
    current = obs[0] if isinstance(obs, (tuple, list)) else obs
    return {'obs': current}
50 |
51 |
class OvershootingTest(tf.test.TestCase):
  """Tests for the latent overshooting unroll helper."""

  def test_example(self):
    """Check masks, priors, and posteriors on a hand-computed example.

    Uses _MockCell, whose prior adds the action to the previous state and
    whose posterior copies the observation, so all expected values can be
    written out explicitly. The second sequence has length 2 to exercise
    masking of the padded steps.
    """
    obs = tf.constant([
        [10, 20, 30, 40, 50, 60],
        [70, 80, 0, 0, 0, 0],
    ], dtype=tf.float32)[:, :, None]
    prev_action = tf.constant([
        [0.0, 0.1, 0.2, 0.3, 0.4, 0.5],
        [9.0, 0.7, 0, 0, 0, 0],
    ], dtype=tf.float32)[:, :, None]
    length = tf.constant([6, 2], dtype=tf.int32)
    cell = _MockCell(1)
    # Overshoot up to 3 steps ahead of each posterior state.
    _, prior, posterior, mask = overshooting(
        cell, obs, obs, prev_action, length, 3)
    prior = tf.squeeze(prior['obs'], 3)
    posterior = tf.squeeze(posterior['obs'], 3)
    mask = tf.to_int32(mask)
    with self.test_session():
      # Each column corresponds to a different state step, and each row
      # corresponds to a different overshooting distance from there.
      self.assertAllEqual([
          [1, 1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1, 0],
          [1, 1, 1, 1, 0, 0],
          [1, 1, 1, 0, 0, 0],
      ], mask.eval()[0].T)
      self.assertAllEqual([
          [1, 1, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0],
      ], mask.eval()[1].T)
      # Priors accumulate the actions along the overshooting distance.
      self.assertAllClose([
          [0.0, 10.1, 20.2, 30.3, 40.4, 50.5],
          [0.1, 10.3, 20.5, 30.7, 40.9, 0],
          [0.3, 10.6, 20.9, 31.2, 0, 0],
          [0.6, 11.0, 21.4, 0, 0, 0],
      ], prior.eval()[0].T)
      # Posteriors are the future observations themselves.
      self.assertAllClose([
          [10, 20, 30, 40, 50, 60],
          [20, 30, 40, 50, 60, 0],
          [30, 40, 50, 60, 0, 0],
          [40, 50, 60, 0, 0, 0],
      ], posterior.eval()[0].T)
      self.assertAllClose([
          [9.0, 70.7, 0, 0, 0, 0],
          [9.7, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0],
      ], prior.eval()[1].T)
      self.assertAllClose([
          [70, 80, 0, 0, 0, 0],
          [80, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0],
      ], posterior.eval()[1].T)

  def test_nested(self):
    """Smoke test: nested tuples of observations and actions are accepted."""
    obs = (tf.ones((3, 50, 1)), tf.ones((3, 50, 2)), tf.ones((3, 50, 3)))
    prev_action = (tf.ones((3, 50, 1)), tf.ones((3, 50, 2)))
    length = tf.constant([49, 50, 3], dtype=tf.int32)
    cell = _MockCell(1)
    overshooting(cell, obs, obs, prev_action, length, 3)
116 |
117 |
# Allow running this test file directly as a script.
if __name__ == '__main__':
  tf.test.main()
120 |
--------------------------------------------------------------------------------
/tests/test_running.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Dreamer Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import os
20 | import pathlib
21 | import pickle
22 | import sys
23 | import threading
24 | import time
25 |
26 | sys.path.append(str(pathlib.Path(__file__).parent.parent))
27 |
28 | import numpy as np
29 | import tensorflow as tf
30 | from dreamer.training import running
31 |
32 |
class TestExperiment(tf.test.TestCase):
  """End-to-end tests of the experiment runner with cooperating workers."""

  def _spawn_workers(self, target, *extra_args):
    # Launch 20 named worker threads and block until all of them finish.
    threads = [
        threading.Thread(target=target, args=extra_args + (str(name),))
        for name in range(20)]
    for thread in threads:
      thread.start()
    for thread in threads:
      thread.join()

  def _assert_marker_counts(self, basedir, expected):
    # Check how many run directories contain each marker file.
    for pattern, count in expected:
      filepaths = tf.gfile.Glob(os.path.join(basedir, '*/' + pattern))
      self.assertEqual(count, len(filepaths))

  def _assert_all_numbers_complete(self, basedir):
    # Every finished run must have stored exactly ten numbers.
    for filepath in tf.gfile.Glob(os.path.join(basedir, '*/numbers')):
      with tf.gfile.GFile(filepath, 'rb') as file_:
        self.assertEqual(10, len(pickle.load(file_)))

  def test_no_kills(self):
    basedir = os.path.join(tf.test.get_temp_dir(), 'test_no_kills')
    self._spawn_workers(_worker_normal, basedir)
    self._assert_marker_counts(basedir, [
        ('DONE', 100), ('PING', 100), ('started', 100),
        ('resumed', 0), ('failed', 0), ('numbers', 100)])
    self._assert_all_numbers_complete(basedir)

  def test_dying_workers(self):
    basedir = os.path.join(tf.test.get_temp_dir(), 'test_dying_workers')
    # First wave of workers dies part-way through; the second wave must
    # resume and finish every run.
    self._spawn_workers(_worker_dying, basedir, 15)
    self._spawn_workers(_worker_normal, basedir)
    self._assert_marker_counts(basedir, [
        ('DONE', 100), ('PING', 100), ('FAIL', 0), ('started', 100),
        ('resumed', 20), ('numbers', 100)])
    self._assert_all_numbers_complete(basedir)
91 |
92 |
def _worker_normal(basedir, worker_name):
  """Worker that drains every run it is assigned until completion."""
  experiment = running.Experiment(
      basedir, _process_fn, _start_fn, _resume_fn,
      num_runs=100, worker_name=worker_name, ping_every=1.0)
  for run in experiment:
    for _ in run:
      pass
100 |
101 |
def _worker_dying(basedir, die_at_step, worker_name):
  """Worker that simulates a crash after a fixed number of scores."""
  experiment = running.Experiment(
      basedir, _process_fn, _start_fn, _resume_fn,
      num_runs=100, worker_name=worker_name, ping_every=1.0)
  steps_done = 0
  for run in experiment:
    for _ in run:
      steps_done += 1
      if steps_done >= die_at_step:
        return
112 |
113 |
def _start_fn(logdir):
  """Create the 'started' marker and an empty number list for a new run."""
  for marker in ('DONE', 'started', 'resumed'):
    assert not tf.gfile.Exists(os.path.join(logdir, marker))
  with tf.gfile.GFile(os.path.join(logdir, 'started'), 'w') as handle:
    handle.write('\n')
  with tf.gfile.GFile(os.path.join(logdir, 'numbers'), 'wb') as handle:
    pickle.dump([], handle)
  return []
123 |
124 |
def _resume_fn(logdir):
  """Mark the run as resumed and restore its previously stored numbers."""
  assert not tf.gfile.Exists(os.path.join(logdir, 'DONE'))
  assert tf.gfile.Exists(os.path.join(logdir, 'started'))
  with tf.gfile.GFile(os.path.join(logdir, 'resumed'), 'w') as handle:
    handle.write('\n')
  with tf.gfile.GFile(os.path.join(logdir, 'numbers'), 'rb') as handle:
    stored = pickle.load(handle)
  if len(stored) != 5:
    raise Exception('Expected to be resumed in the middle for this test.')
  return stored
135 |
136 |
def _process_fn(logdir, numbers):
  """Generate random scores, persisting progress after each one."""
  assert tf.gfile.Exists(os.path.join(logdir, 'started'))
  while len(numbers) < 10:
    sample = np.random.uniform(0, 0.1)
    time.sleep(sample)
    numbers.append(sample)
    with tf.gfile.GFile(os.path.join(logdir, 'numbers'), 'wb') as handle:
      pickle.dump(numbers, handle)
    yield sample
146 |
147 |
# Allow running this test file directly as a script.
if __name__ == '__main__':
  tf.test.main()
150 |
--------------------------------------------------------------------------------