├── .gitignore ├── AUTHORS ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── dreamer ├── __init__.py ├── control │ ├── __init__.py │ ├── batch_env.py │ ├── dummy_env.py │ ├── mpc_agent.py │ ├── planning.py │ ├── random_episodes.py │ ├── simulate.py │ ├── temporal_difference.py │ └── wrappers.py ├── distributions │ ├── __init__.py │ ├── one_hot.py │ └── tanh_normal.py ├── models │ ├── __init__.py │ ├── base.py │ └── rssm.py ├── networks │ ├── __init__.py │ ├── basic.py │ ├── conv.py │ └── proprio.py ├── scripts │ ├── __init__.py │ ├── configs.py │ ├── fileutil_fetch.py │ ├── objectives.py │ ├── sync.py │ ├── tasks.py │ └── train.py ├── tools │ ├── __init__.py │ ├── attr_dict.py │ ├── bind.py │ ├── chunk_sequence.py │ ├── copy_weights.py │ ├── count_dataset.py │ ├── count_weights.py │ ├── custom_optimizer.py │ ├── filter_variables_lib.py │ ├── gif_summary.py │ ├── image_strip_summary.py │ ├── mask.py │ ├── metrics.py │ ├── nested.py │ ├── numpy_episodes.py │ ├── overshooting.py │ ├── preprocess.py │ ├── reshape_as.py │ ├── schedule.py │ ├── shape.py │ ├── streaming_mean.py │ ├── summary.py │ ├── target_network.py │ └── unroll.py └── training │ ├── __init__.py │ ├── define_model.py │ ├── define_summaries.py │ ├── running.py │ ├── trainer.py │ └── utility.py ├── setup.py └── tests ├── test_dreamer.py ├── test_isolate_envs.py ├── test_nested.py ├── test_overshooting.py └── test_running.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *.egg-info 4 | /dist 5 | /logdir 6 | MUJOCO_LOG.TXT 7 | -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | # This is the list of PlaNet authors for copyright purposes. 2 | # 3 | # This does not necessarily list everyone who has contributed code, since in 4 | # some cases, their employer may be the copyright holder. 
To see the full list 5 | # of contributors, see the revision history in source control. 6 | Google LLC 7 | Danijar Hafner 8 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows 28 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/). 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 
9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. 
For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dream to Control
2 |
3 | Danijar Hafner, Timothy Lillicrap, Jimmy Ba, Mohammad Norouzi
4 |
5 | **Note:** This is the original implementation. To build upon Dreamer, we
6 | recommend the newer implementation of [Dreamer in TensorFlow
7 | 2](https://github.com/danijar/dreamer). It is substantially simpler
8 | and faster while replicating the results.
9 |
10 |
11 |
12 | Implementation of Dreamer, the reinforcement learning agent introduced in
13 | [Dream to Control: Learning Behaviors by Latent Imagination][paper]. Dreamer
14 | learns long-horizon behaviors from images purely by latent imagination. For
15 | this, it backpropagates value estimates through trajectories imagined in the
16 | compact latent space of a learned world model. Dreamer solves visual control
17 | tasks using substantially fewer episodes than strong model-free agents.
18 |
19 | If you find this open source release useful, please reference in your paper:
20 |
21 | ```
22 | @article{hafner2019dreamer,
23 |   title={Dream to Control: Learning Behaviors by Latent Imagination},
24 |   author={Hafner, Danijar and Lillicrap, Timothy and Ba, Jimmy and Norouzi, Mohammad},
25 |   journal={arXiv preprint arXiv:1912.01603},
26 |   year={2019}
27 | }
28 | ```
29 |
30 | ## Method
31 |
32 | ![Dreamer model diagram](https://imgur.com/JrXC4rh.png)
33 |
34 | Dreamer learns a world model from past experience that can predict into the
35 | future. It then learns action and value models in its compact latent space. The
36 | value model optimizes Bellman consistency of imagined trajectories. The action
37 | model maximizes value estimates by propagating their analytic gradients back
38 | through imagined trajectories. When interacting with the environment, it simply
39 | executes the action model.
40 |
41 | Find out more:
42 |
43 | - [Project website][website]
44 | - [PDF paper][paper]
45 |
46 | [website]: https://danijar.com/dreamer
47 | [paper]: https://arxiv.org/pdf/1912.01603.pdf
48 |
49 | ## Instructions
50 |
51 | To train an agent, install the dependencies and then run one of these commands:
52 |
53 | ```sh
54 | python3 -m dreamer.scripts.train --logdir ./logdir/debug \
55 |   --params '{defaults: [dreamer, debug], tasks: [dummy]}' \
56 |   --num_runs 1000 --resume_runs False
57 | ```
58 |
59 | ```sh
60 | python3 -m dreamer.scripts.train --logdir ./logdir/control \
61 |   --params '{defaults: [dreamer], tasks: [walker_run]}'
62 | ```
63 |
64 | ```sh
65 | python3 -m dreamer.scripts.train --logdir ./logdir/atari \
66 |   --params '{defaults: [dreamer, pcont, discrete, atari], tasks: [atari_boxing]}'
67 | ```
68 |
69 | ```sh
70 | python3 -m dreamer.scripts.train --logdir ./logdir/dmlab \
71 |   --params '{defaults: [dreamer, discrete], tasks: [dmlab_collect]}'
72 | ```
73 |
74 | The available tasks are listed in `scripts/tasks.py`.
The hyper parameters can 75 | be found in `scripts/configs.py`. 76 | 77 | Tips: 78 | 79 | - Add `debug` to the list of defaults to use a smaller config and reach 80 | the code you're developing more quickly. 81 | - Add the flags `--resume_runs False` and `--num_runs 1000` 82 | to automatically create unique logdirs. 83 | - To train the baseline without value function, add `value_head: False` to the 84 | params. 85 | - To train PlaNet, add `train_planner: cem, test_planner: cem, 86 | planner_objective: reward, action_head: False, value_head: False, 87 | imagination_horizon: 0` to the params. 88 | 89 | ## Dependencies 90 | 91 | The code was tested under Ubuntu 18 and uses these packages: 92 | tensorflow-gpu==1.13.1, tensorflow_probability==0.6.0, dm_control (`egl` 93 | [rendering option][rendering] recommended), gym, imageio, matplotlib, 94 | ruamel.yaml, scikit-image, scipy. 95 | 96 | [rendering]: https://github.com/deepmind/dm_control#rendering 97 | 98 | Disclaimer: This is not an official Google product. 99 | -------------------------------------------------------------------------------- /dreamer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Dreamer Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | from . import control 20 | from . import models 21 | from . import networks 22 | from . import scripts 23 | from . import tools 24 | from . import training 25 | -------------------------------------------------------------------------------- /dreamer/control/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Dreamer Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | from . 
import planning 20 | from .batch_env import PyBatchEnv 21 | from .batch_env import TFBatchEnv 22 | from .dummy_env import DummyEnv 23 | from .mpc_agent import MPCAgent 24 | from .random_episodes import random_episodes 25 | from .simulate import create_batch_env 26 | from .simulate import create_env 27 | from .simulate import simulate 28 | from .temporal_difference import discounted_return 29 | from .temporal_difference import fixed_step_return 30 | from .temporal_difference import lambda_return 31 | -------------------------------------------------------------------------------- /dreamer/control/batch_env.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Dreamer Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import gym 20 | import numpy as np 21 | import tensorflow as tf 22 | 23 | 24 | class TFBatchEnv(object): 25 | 26 | def __init__(self, envs, blocking): 27 | self._batch_env = PyBatchEnv(envs, blocking, flatten=True) 28 | spaces = self._batch_env.observation_space.spaces 29 | self._dtypes = [self._parse_dtype(spaces[key]) for key in self._keys[:-2]] 30 | self._dtypes += [tf.float32, tf.bool] # Reward and done flag. 
31 | self._shapes = [self._parse_shape(spaces[key]) for key in self._keys[:-2]] 32 | self._shapes += [(), ()] # Reward and done flag. 33 | 34 | def __getattr__(self, name): 35 | return getattr(self._batch_env, name) 36 | 37 | def __len__(self): 38 | return len(self._batch_env) 39 | 40 | def __getitem__(self, index): 41 | return self._batch_env[index] 42 | 43 | def step(self, action): 44 | output = tf.py_func( 45 | self._batch_env.step, [action], self._dtypes, name='step') 46 | return self._process_output(output, len(self._batch_env)) 47 | 48 | def reset(self, indices=None): 49 | if indices is None: 50 | indices = tf.range(len(self._batch_env)) 51 | output = tf.py_func( 52 | self._batch_env.reset, [indices], self._dtypes, name='reset') 53 | return self._process_output(output, None) 54 | 55 | def _process_output(self, output, batch_size): 56 | for tensor, shape in zip(output, self._shapes): 57 | tensor.set_shape((batch_size,) + shape) 58 | return {key: tensor for key, tensor in zip(self._keys, output)} 59 | 60 | def _parse_dtype(self, space): 61 | if isinstance(space, gym.spaces.Discrete): 62 | return tf.int32 63 | if isinstance(space, gym.spaces.Box): 64 | if space.low.dtype == np.uint8: 65 | return tf.uint8 66 | else: 67 | return tf.float32 68 | raise NotImplementedError() 69 | 70 | def _parse_shape(self, space): 71 | if isinstance(space, gym.spaces.Discrete): 72 | return () 73 | if isinstance(space, gym.spaces.Box): 74 | return space.shape 75 | raise NotImplementedError("Unsupported space '{}.'".format(space)) 76 | 77 | 78 | class PyBatchEnv(object): 79 | 80 | def __init__(self, envs, blocking, flatten=False): 81 | observ_space = envs[0].observation_space 82 | if not all(env.observation_space == observ_space for env in envs): 83 | raise ValueError('All environments must use the same observation space.') 84 | action_space = envs[0].action_space 85 | if not all(env.action_space == action_space for env in envs): 86 | raise ValueError('All environments must use the 
same observation space.') 87 | self._envs = envs 88 | self._blocking = blocking 89 | self._flatten = flatten 90 | self._keys = list(sorted(observ_space.spaces.keys())) + ['reward', 'done'] 91 | 92 | def __len__(self): 93 | return len(self._envs) 94 | 95 | def __getitem__(self, index): 96 | return self._envs[index] 97 | 98 | def __getattr__(self, name): 99 | return getattr(self._envs[0], name) 100 | 101 | def step(self, actions): 102 | for index, (env, action) in enumerate(zip(self._envs, actions)): 103 | if not env.action_space.contains(action): 104 | message = 'Invalid action for batch index {}: {}' 105 | raise ValueError(message.format(index, action)) 106 | if self._blocking: 107 | transitions = [ 108 | env.step(action) 109 | for env, action in zip(self._envs, actions)] 110 | else: 111 | transitions = [ 112 | env.step(action, blocking=False) 113 | for env, action in zip(self._envs, actions)] 114 | transitions = [transition() for transition in transitions] 115 | outputs = {key: [] for key in self._keys} 116 | for observ, reward, done, _ in transitions: 117 | for key, value in observ.items(): 118 | outputs[key].append(np.array(value)) 119 | outputs['reward'].append(np.array(reward, np.float32)) 120 | outputs['done'].append(np.array(done, np.bool)) 121 | outputs = {key: np.stack(value) for key, value in outputs.items()} 122 | if self._flatten: 123 | outputs = tuple(outputs[key] for key in self._keys) 124 | return outputs 125 | 126 | def reset(self, indices=None): 127 | if indices is None: 128 | indices = range(len(self._envs)) 129 | if self._blocking: 130 | observs = [self._envs[index].reset() for index in indices] 131 | else: 132 | observs = [self._envs[index].reset(blocking=False) for index in indices] 133 | observs = [observ() for observ in observs] 134 | outputs = {key: [] for key in self._keys} 135 | for observ in observs: 136 | for key, value in observ.items(): 137 | outputs[key].append(np.array(value)) 138 | outputs['reward'].append(np.array(0.0, np.float32)) 
139 | outputs['done'].append(np.array(False, np.bool)) 140 | outputs = {key: np.stack(value) for key, value in outputs.items()} 141 | if self._flatten: 142 | outputs = tuple(outputs[key] for key in self._keys) 143 | return outputs 144 | 145 | def close(self): 146 | for env in self._envs: 147 | if hasattr(env, 'close'): 148 | env.close() 149 | -------------------------------------------------------------------------------- /dreamer/control/dummy_env.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Dreamer Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import gym 20 | import numpy as np 21 | 22 | 23 | class DummyEnv(object): 24 | 25 | def __init__(self): 26 | self._random = np.random.RandomState(seed=0) 27 | self._step = None 28 | 29 | @property 30 | def observation_space(self): 31 | low = np.zeros([64, 64, 3], dtype=np.float32) 32 | high = np.ones([64, 64, 3], dtype=np.float32) 33 | spaces = {'image': gym.spaces.Box(low, high)} 34 | return gym.spaces.Dict(spaces) 35 | 36 | @property 37 | def action_space(self): 38 | low = -np.ones([5], dtype=np.float32) 39 | high = np.ones([5], dtype=np.float32) 40 | return gym.spaces.Box(low, high) 41 | 42 | def reset(self): 43 | self._step = 0 44 | obs = self.observation_space.sample() 45 | return obs 46 | 47 | def step(self, action): 48 | obs = self.observation_space.sample() 49 | reward = self._random.uniform(0, 1) 50 | self._step += 1 51 | done = self._step >= 1000 52 | info = {} 53 | return obs, reward, done, info 54 | -------------------------------------------------------------------------------- /dreamer/control/mpc_agent.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Dreamer Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import numpy as np 20 | import tensorflow as tf 21 | from tensorflow_probability import distributions as tfd 22 | 23 | from dreamer.tools import nested 24 | 25 | 26 | class MPCAgent(object): 27 | 28 | def __init__(self, batch_env, step, is_training, should_log, config): 29 | self._step = step # Trainer step, not environment step. 30 | self._is_training = is_training 31 | self._should_log = should_log 32 | self._config = config 33 | self._cell = config.cell 34 | self._num_envs = len(batch_env) 35 | state = self._cell.zero_state(self._num_envs, tf.float32) 36 | var_like = lambda x: tf.get_local_variable( 37 | x.name.split(':')[0].replace('/', '_') + '_var', 38 | shape=x.shape, 39 | initializer=lambda *_, **__: tf.zeros_like(x), use_resource=True) 40 | self._state = nested.map(var_like, state) 41 | batch_action_shape = (self._num_envs,) + batch_env.action_space.shape 42 | self._prev_action = tf.get_local_variable( 43 | 'prev_action_var', shape=batch_action_shape, 44 | initializer=lambda *_, **__: tf.zeros(batch_action_shape), 45 | use_resource=True) 46 | 47 | def reset(self, agent_indices): 48 | state = nested.map( 49 | lambda tensor: tf.gather(tensor, agent_indices), 50 | self._state) 51 | reset_state = nested.map( 52 | lambda var, val: tf.scatter_update(var, agent_indices, 0 * val), 53 | self._state, state, flatten=True) 54 | reset_prev_action = self._prev_action.assign( 55 | tf.zeros_like(self._prev_action)) 56 | return tf.group(reset_prev_action, *reset_state) 57 | 58 | def step(self, agent_indices, observ): 59 | observ = self._config.preprocess_fn(observ) 60 | # Converts observ to sequence. 
61 | observ = nested.map(lambda x: x[:, None], observ) 62 | embedded = self._config.encoder(observ)[:, 0] 63 | state = nested.map( 64 | lambda tensor: tf.gather(tensor, agent_indices), 65 | self._state) 66 | prev_action = self._prev_action + 0 67 | with tf.control_dependencies([prev_action]): 68 | use_obs = tf.ones(tf.shape(agent_indices), tf.bool)[:, None] 69 | _, state = self._cell((embedded, prev_action, use_obs), state) 70 | action = self._config.planner( 71 | self._cell, self._config.objective, state, 72 | embedded.shape[1:].as_list(), 73 | prev_action.shape[1:].as_list()) 74 | action = action[:, 0] 75 | if self._config.exploration: 76 | expl = self._config.exploration 77 | scale = tf.cast(expl.scale, tf.float32)[None] # Batch dimension. 78 | if expl.schedule: 79 | scale *= expl.schedule(self._step) 80 | if expl.factors: 81 | scale *= np.array(expl.factors) 82 | if expl.type == 'additive_normal': 83 | action = tfd.Normal(action, scale[:, None]).sample() 84 | elif expl.type == 'epsilon_greedy': 85 | random_action = tf.one_hot( 86 | tfd.Categorical(0 * action).sample(), action.shape[-1]) 87 | switch = tf.cast(tf.less( 88 | tf.random.uniform((self._num_envs,)), 89 | scale), tf.float32)[:, None] 90 | action = switch * random_action + (1 - switch) * action 91 | else: 92 | raise NotImplementedError(expl.type) 93 | action = tf.clip_by_value(action, -1, 1) 94 | remember_action = self._prev_action.assign(action) 95 | remember_state = nested.map( 96 | lambda var, val: tf.scatter_update(var, agent_indices, val), 97 | self._state, state, flatten=True) 98 | with tf.control_dependencies(remember_state + (remember_action,)): 99 | return tf.identity(action) 100 | -------------------------------------------------------------------------------- /dreamer/control/planning.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Dreamer Authors. All rights reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import tensorflow as tf 20 | 21 | from dreamer import tools 22 | 23 | 24 | def cross_entropy_method( 25 | cell, objective, state, obs_shape, action_shape, horizon, graph, 26 | beams=1000, topk=100, iterations=10, min_action=-1, max_action=1): 27 | obs_shape, action_shape = tuple(obs_shape), tuple(action_shape) 28 | batch = tools.shape(tools.nested.flatten(state)[0])[0] 29 | initial_state = tools.nested.map(lambda tensor: tf.tile( 30 | tensor, [beams] + [1] * (tensor.shape.ndims - 1)), state) 31 | extended_batch = tools.shape(tools.nested.flatten(initial_state)[0])[0] 32 | use_obs = tf.zeros([extended_batch, horizon, 1], tf.bool) 33 | obs = tf.zeros((extended_batch, horizon) + obs_shape) 34 | 35 | def iteration(index, mean, stddev): 36 | # Sample action proposals from belief. 37 | normal = tf.random_normal((batch, beams, horizon) + action_shape) 38 | action = normal * stddev[:, None] + mean[:, None] 39 | action = tf.clip_by_value(action, min_action, max_action) 40 | # Evaluate proposal actions. 
41 | action = tf.reshape( 42 | action, (extended_batch, horizon) + action_shape) 43 | (_, state), _ = tf.nn.dynamic_rnn( 44 | cell, (0 * obs, action, use_obs), initial_state=initial_state) 45 | return_ = objective(state) 46 | return_ = tf.reshape(return_, (batch, beams)) 47 | # Re-fit belief to the best ones. 48 | _, indices = tf.nn.top_k(return_, topk, sorted=False) 49 | indices += tf.range(batch)[:, None] * beams 50 | best_actions = tf.gather(action, indices) 51 | mean, variance = tf.nn.moments(best_actions, 1) 52 | stddev = tf.sqrt(variance + 1e-6) 53 | return index + 1, mean, stddev 54 | 55 | mean = tf.zeros((batch, horizon) + action_shape) 56 | stddev = tf.ones((batch, horizon) + action_shape) 57 | _, mean, std = tf.while_loop( 58 | lambda index, mean, stddev: index < iterations, iteration, 59 | (0, mean, stddev), back_prop=False) 60 | return mean 61 | 62 | 63 | def action_head_policy( 64 | cell, objective, state, obs_shape, action_shape, graph, config, strategy): 65 | features = cell.features_from_state(state) 66 | policy = graph.heads.action(features) 67 | if strategy == 'sample': 68 | action = policy.sample() 69 | elif strategy == 'mode': 70 | action = policy.mode() 71 | else: 72 | raise NotImplementedError(strategy) 73 | plan = action[:, None, :] 74 | return plan 75 | -------------------------------------------------------------------------------- /dreamer/control/random_episodes.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Dreamer Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | from dreamer import control 20 | 21 | 22 | def random_episodes( 23 | env_ctor, num_episodes, num_steps, outdir=None, isolate_envs='none'): 24 | # If using environment processes or threads, we should also use them here to 25 | # avoid loading their dependencies into the global name space. This way, 26 | # their imports will be isolated from the main process and later created envs 27 | # do not inherit them via global state but import their own copies. 28 | env, _ = control.create_env(env_ctor, isolate_envs) 29 | env = control.wrappers.CollectDataset(env, outdir) 30 | episodes = [] if outdir else None 31 | while num_episodes > 0 or num_steps > 0: 32 | policy = lambda env, obs: env.action_space.sample() 33 | done = False 34 | obs = env.reset() 35 | while not done: 36 | action = policy(env, obs) 37 | obs, _, done, info = env.step(action) 38 | episode = env._get_episode() 39 | episodes.append(episode) 40 | num_episodes -= 1 41 | num_steps -= len(episode['reward']) 42 | try: 43 | env.close() 44 | except AttributeError: 45 | pass 46 | return episodes 47 | -------------------------------------------------------------------------------- /dreamer/control/simulate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Dreamer Authors. All rights reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import tensorflow as tf 20 | 21 | from dreamer.control import batch_env 22 | from dreamer.control import wrappers 23 | 24 | 25 | def create_batch_env(env_ctor, num_envs, isolate_envs): 26 | envs, blockings = zip(*[ 27 | create_env(env_ctor, isolate_envs) 28 | for _ in range(num_envs)]) 29 | assert all(blocking == blockings[0] for blocking in blockings) 30 | return batch_env.TFBatchEnv(envs, blockings[0]) 31 | 32 | 33 | def create_env(env_ctor, isolate_envs): 34 | if isolate_envs == 'none': 35 | env = env_ctor() 36 | blocking = True 37 | elif isolate_envs == 'thread': 38 | env = wrappers.Async(env_ctor, 'thread') 39 | blocking = False 40 | elif isolate_envs == 'process': 41 | env = wrappers.Async(env_ctor, 'process') 42 | blocking = False 43 | else: 44 | raise NotImplementedError(isolate_envs) 45 | return env, blocking 46 | 47 | 48 | def simulate(agent, env, episodes=None, steps=None): 49 | 50 | def pred(step, episode, *args, **kwargs): 51 | if episodes is None: 52 | return tf.less(step, steps) 53 | if steps is None: 54 | return tf.less(episode, episodes) 55 | return tf.logical_or(tf.less(episode, episodes), tf.less(step, steps)) 56 | 57 | def reset(mask, scores, lengths, previous): 58 | indices = tf.where(mask)[:, 0] 59 | reset_agent = 
agent.reset(indices) 60 | values = env.reset(indices) 61 | # This would be shorter but gives an internal TensorFlow error. 62 | # previous = { 63 | # key: tf.tensor_scatter_update( 64 | # previous[key], indices[:, None], values[key]) 65 | # for key in previous} 66 | idx = tf.cast(indices[:, None], tf.int32) 67 | previous = { 68 | key: tf.where( 69 | mask, 70 | tf.scatter_nd(idx, values[key], tf.shape(previous[key])), 71 | previous[key]) 72 | for key in previous} 73 | scores = tf.where(mask, tf.zeros_like(scores), scores) 74 | lengths = tf.where(mask, tf.zeros_like(lengths), lengths) 75 | with tf.control_dependencies([reset_agent]): 76 | return tf.identity(scores), tf.identity(lengths), previous 77 | 78 | def body(step, episode, scores, lengths, previous, ta_out, ta_sco, ta_len): 79 | # Reset episodes and agents if necessary. 80 | reset_mask = tf.cond( 81 | tf.equal(step, 0), 82 | lambda: tf.ones(len(env), tf.bool), 83 | lambda: previous['done']) 84 | reset_mask.set_shape([len(env)]) 85 | scores, lengths, previous = tf.cond( 86 | tf.reduce_any(reset_mask), 87 | lambda: reset(reset_mask, scores, lengths, previous), 88 | lambda: (scores, lengths, previous)) 89 | step_indices = tf.range(len(env)) 90 | action = agent.step(step_indices, previous) 91 | values = env.step(action) 92 | # Update book keeping variables. 93 | done_indices = tf.cast(tf.where(values['done']), tf.int32) 94 | step += tf.shape(step_indices)[0] 95 | episode += tf.shape(done_indices)[0] 96 | scores += values['reward'] 97 | lengths += tf.shape(step_indices)[0] 98 | # Write transitions, scores, and lengths to tensor arrays. 
99 | ta_out = { 100 | key: array.write(array.size(), values[key]) 101 | for key, array in ta_out.items()} 102 | ta_sco = tf.cond( 103 | tf.greater(tf.shape(done_indices)[0], 0), 104 | lambda: ta_sco.write( 105 | ta_sco.size(), tf.gather(scores, done_indices)[:, 0]), 106 | lambda: ta_sco) 107 | ta_len = tf.cond( 108 | tf.greater(tf.shape(done_indices)[0], 0), 109 | lambda: ta_len.write( 110 | ta_len.size(), tf.gather(lengths, done_indices)[:, 0]), 111 | lambda: ta_len) 112 | return step, episode, scores, lengths, values, ta_out, ta_sco, ta_len 113 | 114 | initial = env.reset() 115 | ta_out = { 116 | key: tf.TensorArray( 117 | value.dtype, 0, True, element_shape=value.shape) 118 | for key, value in initial.items()} 119 | ta_sco = tf.TensorArray( 120 | tf.float32, 0, True, element_shape=[None], infer_shape=False) 121 | ta_len = tf.TensorArray( 122 | tf.int32, 0, True, element_shape=[None], infer_shape=False) 123 | zero_scores = tf.zeros(len(env), tf.float32, name='scores') 124 | zero_lengths = tf.zeros(len(env), tf.int32, name='lengths') 125 | ta_out, ta_sco, ta_len = tf.while_loop( 126 | pred, body, 127 | (0, 0, zero_scores, zero_lengths, initial, ta_out, ta_sco, ta_len), 128 | parallel_iterations=1, back_prop=False)[-3:] 129 | transitions = {key: array.stack() for key, array in ta_out.items()} 130 | transitions = { 131 | key: tf.transpose(value, [1, 0] + list(range(2, value.shape.ndims))) 132 | for key, value in transitions.items()} 133 | scores = ta_sco.concat() 134 | lengths = ta_len.concat() 135 | return scores, lengths, transitions 136 | -------------------------------------------------------------------------------- /dreamer/control/temporal_difference.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Dreamer Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import tensorflow as tf 20 | 21 | 22 | def discounted_return(reward, pcont, bootstrap, axis, stop_gradient=True): 23 | if isinstance(pcont, (float, int)): 24 | if pcont == 1 and bootstrap is None: 25 | return tf.reduce_sum(reward, axis) 26 | if pcont == 1: 27 | return tf.reduce_sum(reward, axis) + bootstrap 28 | pcont = pcont * tf.ones_like(reward) 29 | # Bring the aggregation dimension front. 30 | dims = list(range(reward.shape.ndims)) 31 | dims = [axis] + dims[1:axis] + [0] + dims[axis + 1:] 32 | reward = tf.transpose(reward, dims) 33 | pcont = tf.transpose(pcont, dims) 34 | if bootstrap is None: 35 | bootstrap = tf.zeros_like(reward[-1]) 36 | return_ = tf.scan( 37 | fn=lambda agg, cur: cur[0] + cur[1] * agg, 38 | elems=(reward, pcont), 39 | initializer=bootstrap, 40 | back_prop=not stop_gradient, 41 | reverse=True) 42 | return_ = tf.transpose(return_, dims) 43 | if stop_gradient: 44 | return_ = tf.stop_gradient(return_) 45 | return return_ 46 | 47 | 48 | def lambda_return( 49 | reward, value, bootstrap, pcont, lambda_, axis, stop_gradient=True): 50 | # Setting lambda=1 gives a discounted Monte Carlo return. 51 | # Setting lambda=0 gives a fixed 1-step return. 52 | assert reward.shape.ndims == value.shape.ndims, (reward.shape, value.shape) 53 | # Bring the aggregation dimension front. 
54 | dims = list(range(reward.shape.ndims)) 55 | dims = [axis] + dims[1:axis] + [0] + dims[axis + 1:] 56 | if isinstance(pcont, (int, float)): 57 | pcont = pcont * tf.ones_like(reward) 58 | reward = tf.transpose(reward, dims) 59 | value = tf.transpose(value, dims) 60 | pcont = tf.transpose(pcont, dims) 61 | if bootstrap is None: 62 | bootstrap = tf.zeros_like(value[-1]) 63 | next_values = tf.concat([value[1:], bootstrap[None]], 0) 64 | inputs = reward + pcont * next_values * (1 - lambda_) 65 | return_ = tf.scan( 66 | fn=lambda agg, cur: cur[0] + cur[1] * lambda_ * agg, 67 | elems=(inputs, pcont), 68 | initializer=bootstrap, 69 | back_prop=not stop_gradient, 70 | reverse=True) 71 | return_ = tf.transpose(return_, dims) 72 | if stop_gradient: 73 | return_ = tf.stop_gradient(return_) 74 | return return_ 75 | 76 | 77 | def fixed_step_return( 78 | reward, value, discount, steps, axis, stop_gradient=True): 79 | # Brings the aggregation dimension front. 80 | dims = list(range(reward.shape.ndims)) 81 | dims = [axis] + dims[1:axis] + [0] + dims[axis + 1:] 82 | reward = tf.transpose(reward, dims) 83 | length = tf.shape(reward)[0] 84 | _, return_ = tf.while_loop( 85 | cond=lambda i, p: i < steps + 1, 86 | body=lambda i, p: (i + 1, reward[steps - i: length - i] + discount * p), 87 | loop_vars=[tf.constant(1), tf.zeros_like(reward[steps:])], 88 | back_prop=not stop_gradient) 89 | if value is not None: 90 | return_ += discount ** steps * tf.transpose(value, dims)[steps:] 91 | return_ = tf.transpose(return_, dims) 92 | if stop_gradient: 93 | return_ = tf.stop_gradient(return_) 94 | return return_ 95 | -------------------------------------------------------------------------------- /dreamer/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Dreamer Authors. All rights reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | from .one_hot import OneHot 20 | from .tanh_normal import TanhNormal 21 | -------------------------------------------------------------------------------- /dreamer/distributions/one_hot.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Dreamer Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import tensorflow as tf 20 | from tensorflow_probability import distributions as tfd 21 | 22 | 23 | class OneHot(object): 24 | 25 | def __init__(self, logits=None, probs=None, gradient='score'): 26 | self._gradient = gradient 27 | self._dist = tfd.Categorical(logits=logits, probs=probs) 28 | self._num_classes = self.mean().shape[-1].value 29 | 30 | @property 31 | def name(self): 32 | return 'OneHotDistribution' 33 | 34 | def __getattr__(self, name): 35 | return getattr(self._dist, name) 36 | 37 | def prob(self, events): 38 | indices = tf.argmax(events, axis=-1) 39 | return self._dist.prob(indices) 40 | 41 | def log_prob(self, events): 42 | indices = tf.argmax(events, axis=-1) 43 | return self._dist.log_prob(indices) 44 | 45 | def mean(self): 46 | return self._dist.probs_parameter() 47 | 48 | def stddev(self): 49 | values = tf.one_hot(tf.range(self._num_classes), self._num_classes) 50 | distances = tf.reduce_sum((values - self.mean()[..., None, :]) ** 2, -1) 51 | return tf.sqrt(distances) 52 | 53 | def mode(self): 54 | return tf.one_hot(self._dist.mode(), self._num_classes) 55 | 56 | def sample(self, amount=None): 57 | amount = [amount] if amount else [] 58 | sample = tf.one_hot(self._dist.sample(*amount), self._num_classes) 59 | if self._gradient == 'score': # Implemented as DiCE. 60 | logp = self.log_prob(sample)[..., None] 61 | sample *= tf.exp(logp - tf.stop_gradient(logp)) 62 | elif self._gradient == 'straight': # Gradient for all classes. 
63 | probs = self._dist.probs_parameter() 64 | sample += probs - tf.stop_gradient(probs) 65 | else: 66 | raise NotImplementedError(self._gradient) 67 | return sample 68 | -------------------------------------------------------------------------------- /dreamer/distributions/tanh_normal.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Dreamer Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import tensorflow as tf 20 | import tensorflow_probability as tfp 21 | from tensorflow_probability import distributions as tfd 22 | 23 | 24 | class TanhNormal(object): 25 | 26 | def __init__(self, mean, std, samples=100): 27 | dist = tfd.Normal(mean, std) 28 | dist = tfd.TransformedDistribution(dist, TanhBijector()) 29 | dist = tfd.Independent(dist, 1) 30 | self._dist = dist 31 | self._samples = samples 32 | 33 | @property 34 | def name(self): 35 | return 'TanhNormalDistribution' 36 | 37 | def __getattr__(self, name): 38 | return getattr(self._dist, name) 39 | 40 | def mean(self): 41 | samples = self._dist.sample(self._samples) 42 | return tf.reduce_mean(samples, 0) 43 | 44 | def stddev(self): 45 | samples = self._dist.sample(self._samples) 46 | mean = tf.reduce_mean(samples, 0, keep_dims=True) 47 | return tf.reduce_mean(tf.pow(samples - mean, 2), 0) 48 | 49 | def mode(self): 50 | samples = self._dist.sample(self._samples) 51 | logprobs = self._dist.log_prob(samples) 52 | mask = tf.one_hot(tf.argmax(logprobs, axis=0), self._samples, axis=0) 53 | return tf.reduce_sum(samples * mask[..., None], 0) 54 | 55 | def entropy(self): 56 | sample = self._dist.sample(self._samples) 57 | logprob = self.log_prob(sample) 58 | return -tf.reduce_mean(logprob, 0) 59 | 60 | 61 | class TanhBijector(tfp.bijectors.Bijector): 62 | 63 | def __init__(self, validate_args=False, name='tanh'): 64 | super(TanhBijector, self).__init__( 65 | forward_min_event_ndims=0, 66 | validate_args=validate_args, 67 | name=name) 68 | 69 | def _forward(self, x): 70 | return tf.nn.tanh(x) 71 | 72 | def _inverse(self, y): 73 | precision = 0.99999997 74 | clipped = tf.where( 75 | tf.less_equal(tf.abs(y), 1.), 76 | tf.clip_by_value(y, -precision, precision), y) 77 | # y = tf.stop_gradient(clipped) + y - tf.stop_gradient(y) 78 | return tf.atanh(clipped) 79 | 80 | def 
_forward_log_det_jacobian(self, x): 81 | log2 = tf.math.log(tf.constant(2.0, dtype=x.dtype)) 82 | return 2.0 * (log2 - x - tf.nn.softplus(-2.0 * x)) 83 | -------------------------------------------------------------------------------- /dreamer/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Dreamer Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | from .base import Base 20 | from .rssm import RSSM 21 | -------------------------------------------------------------------------------- /dreamer/models/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Dreamer Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import tensorflow as tf 20 | from tensorflow_probability import distributions as tfd 21 | 22 | from dreamer import tools 23 | 24 | 25 | class Base(tf.nn.rnn_cell.RNNCell): 26 | 27 | def __init__(self, transition_tpl, posterior_tpl, reuse=None): 28 | super(Base, self).__init__(_reuse=reuse) 29 | self._posterior_tpl = posterior_tpl 30 | self._transition_tpl = transition_tpl 31 | self._debug = False 32 | 33 | @property 34 | def state_size(self): 35 | raise NotImplementedError 36 | 37 | @property 38 | def updates(self): 39 | return [] 40 | 41 | @property 42 | def losses(self): 43 | return [] 44 | 45 | @property 46 | def output_size(self): 47 | return (self.state_size, self.state_size) 48 | 49 | def zero_state(self, batch_size, dtype): 50 | return tools.nested.map( 51 | lambda size: tf.zeros([batch_size, size], dtype), 52 | self.state_size) 53 | 54 | def features_from_state(self, state): 55 | raise NotImplementedError 56 | 57 | def dist_from_state(self, state, mask=None): 58 | raise NotImplementedError 59 | 60 | def divergence_from_states(self, lhs, rhs, mask=None): 61 | lhs = self.dist_from_state(lhs, mask) 62 | rhs = self.dist_from_state(rhs, mask) 63 | divergence = tfd.kl_divergence(lhs, rhs) 64 | if mask is not None: 65 | divergence = tools.mask(divergence, mask) 66 | return divergence 67 | 68 | def call(self, inputs, prev_state): 69 | obs, prev_action, use_obs = inputs 70 | if self._debug: 71 | with tf.control_dependencies([tf.assert_equal(use_obs, use_obs[0, 0])]): 72 | use_obs = tf.identity(use_obs) 73 | use_obs = use_obs[0, 0] 74 | zero_obs = tools.nested.map(tf.zeros_like, obs) 75 | prior = self._transition_tpl(prev_state, prev_action, zero_obs) 76 | posterior = tf.cond( 77 | use_obs, 78 | lambda: 
class RSSM(base.Base):
  """Recurrent state-space model (RSSM) transition cell.

  The latent state is a dict with a deterministic part (the GRU ``belief``
  plus its ``rnn_state``) and a stochastic part (a diagonal Gaussian with
  ``mean``, ``stddev``, and a drawn ``sample``).  The prior
  (``_transition``) predicts the next state from the previous state and
  action only; the posterior (``_posterior``) additionally conditions on
  the current observation embedding.  Both are wrapped in
  ``tf.make_template`` so their variables are created once and reused on
  every call.
  """

  def __init__(
      self, state_size, belief_size, embed_size,
      future_rnn=True, mean_only=False, min_stddev=0.1, activation=tf.nn.elu,
      num_layers=1):
    # state_size: width of the stochastic latent (mean/stddev/sample).
    # belief_size: width of the deterministic GRU belief and rnn_state.
    # embed_size: width of the hidden dense layers.
    # future_rnn: if True, compute the prior statistics from the updated
    #   belief; otherwise from the pre-RNN hidden activations.
    # mean_only: if True, use the mean instead of sampling (deterministic).
    # min_stddev: lower bound added to the softplus stddev for stability.
    # num_layers: number of dense layers before (and, with future_rnn,
    #   after) the GRU cell.
    self._state_size = state_size
    self._belief_size = belief_size
    self._embed_size = embed_size
    self._future_rnn = future_rnn
    self._cell = tf.contrib.rnn.GRUBlockCell(self._belief_size)
    self._kwargs = dict(units=self._embed_size, activation=activation)
    self._mean_only = mean_only
    self._min_stddev = min_stddev
    self._num_layers = num_layers
    # NOTE(review): base.Base presumably stores these templates as
    # self._transition_tpl / self._posterior_tpl (used below) — confirm
    # in dreamer/models/base.py.
    super(RSSM, self).__init__(
        tf.make_template('transition', self._transition),
        tf.make_template('posterior', self._posterior))

  @property
  def state_size(self):
    """Sizes of the state dict entries, keyed like the state itself."""
    return {
        'mean': self._state_size,
        'stddev': self._state_size,
        'sample': self._state_size,
        'belief': self._belief_size,
        'rnn_state': self._belief_size,
    }

  @property
  def feature_size(self):
    """Width of the vector produced by features_from_state()."""
    return self._belief_size + self._state_size

  def dist_from_state(self, state, mask=None):
    """Build the diagonal Gaussian over the stochastic part of the state.

    Masked-out entries get stddev 1, which makes their contribution to a
    KL between two masked distributions zero.
    """
    if mask is not None:
      stddev = tools.mask(state['stddev'], mask, value=1)
    else:
      stddev = state['stddev']
    dist = tfd.MultivariateNormalDiag(state['mean'], stddev)
    return dist

  def features_from_state(self, state):
    """Concatenate stochastic sample and deterministic belief."""
    return tf.concat([state['sample'], state['belief']], -1)

  def divergence_from_states(self, lhs, rhs, mask=None):
    """KL divergence between two state distributions, optionally masked."""
    lhs = self.dist_from_state(lhs, mask)
    rhs = self.dist_from_state(rhs, mask)
    divergence = tfd.kl_divergence(lhs, rhs)
    if mask is not None:
      divergence = tools.mask(divergence, mask)
    return divergence

  def _transition(self, prev_state, prev_action, zero_obs):
    """Prior: predict the next state from previous state and action.

    zero_obs is never read here; the parameter only keeps the signature
    identical to _posterior (which passes tf.zeros_like(obs)).
    """
    hidden = tf.concat([prev_state['sample'], prev_action], -1)
    for _ in range(self._num_layers):
      hidden = tf.layers.dense(hidden, **self._kwargs)
    belief, rnn_state = self._cell(hidden, prev_state['rnn_state'])
    if self._future_rnn:
      # Derive the stochastic statistics from the updated belief.
      hidden = belief
    for _ in range(self._num_layers):
      hidden = tf.layers.dense(hidden, **self._kwargs)
    mean = tf.layers.dense(hidden, self._state_size, None)
    stddev = tf.layers.dense(hidden, self._state_size, tf.nn.softplus)
    stddev += self._min_stddev
    if self._mean_only:
      sample = mean
    else:
      sample = tfd.MultivariateNormalDiag(mean, stddev).sample()
    return {
        'mean': mean,
        'stddev': stddev,
        'sample': sample,
        'belief': belief,
        'rnn_state': rnn_state,
    }

  def _posterior(self, prev_state, prev_action, obs):
    """Posterior: like the prior, but also conditioned on the observation."""
    # Run the prior first to advance the deterministic belief/rnn_state.
    prior = self._transition_tpl(prev_state, prev_action, tf.zeros_like(obs))
    hidden = tf.concat([prior['belief'], obs], -1)
    for _ in range(self._num_layers):
      hidden = tf.layers.dense(hidden, **self._kwargs)
    mean = tf.layers.dense(hidden, self._state_size, None)
    stddev = tf.layers.dense(hidden, self._state_size, tf.nn.softplus)
    stddev += self._min_stddev
    if self._mean_only:
      sample = mean
    else:
      sample = tfd.MultivariateNormalDiag(mean, stddev).sample()
    return {
        'mean': mean,
        'stddev': stddev,
        'sample': sample,
        'belief': prior['belief'],
        'rnn_state': prior['rnn_state'],
    }
def feed_forward(
    features, data_shape, num_layers=2, activation=tf.nn.relu,
    mean_activation=None, stop_gradient=False, trainable=True, units=100,
    std=1.0, low=-1.0, high=1.0, dist='normal', min_std=1e-2, init_std=1.0):
  """Dense network head that parameterizes a distribution over data_shape.

  Applies `num_layers` dense layers to `features`, then a linear layer for
  the mean (and optionally one for a learned stddev when std == 'learned'),
  and wraps the result in the distribution named by `dist`.  The leading
  axes of `features` are preserved.
  """
  hidden = tf.stop_gradient(features) if stop_gradient else features
  for _ in range(num_layers):
    hidden = tf.layers.dense(hidden, units, activation, trainable=trainable)
  flat_size = int(np.prod(data_shape))
  batch_dims = tools.shape(features)[:-1]
  mean = tf.layers.dense(hidden, flat_size, mean_activation, trainable=trainable)
  mean = tf.reshape(mean, batch_dims + data_shape)
  if std == 'learned':
    std = tf.layers.dense(hidden, flat_size, None, trainable=trainable)
    # Shift so that softplus(0 + shift) == init_std.
    softplus_shift = np.log(np.exp(init_std) - 1)
    std = tf.nn.softplus(std + softplus_shift) + min_std
    std = tf.reshape(std, batch_dims + data_shape)
  if dist == 'normal':
    out = tfd.Independent(tfd.Normal(mean, std), len(data_shape))
  elif dist == 'deterministic':
    out = tfd.Independent(tfd.Deterministic(mean), len(data_shape))
  elif dist == 'binary':
    out = tfd.Independent(tfd.Bernoulli(mean), len(data_shape))
  elif dist == 'trunc_normal':
    # https://www.desmos.com/calculator/rnksmhtgui
    out = tfd.Independent(
        tfd.TruncatedNormal(mean, std, low, high), len(data_shape))
  elif dist == 'tanh_normal':
    # https://www.desmos.com/calculator/794s8kf0es
    out = distributions.TanhNormal(mean, std)
  elif dist == 'tanh_normal_tanh':
    # https://www.desmos.com/calculator/794s8kf0es
    mean = 5.0 * tf.tanh(mean / 5.0)  # Soft-clip the mean to (-5, 5).
    out = distributions.TanhNormal(mean, std)
  elif dist == 'onehot_score':
    out = distributions.OneHot(mean, gradient='score')
  elif dist == 'onehot_straight':
    out = distributions.OneHot(mean, gradient='straight')
  else:
    raise NotImplementedError(dist)
  return out
def encoder(obs):
  """Convolutional encoder mapping image observations to flat embeddings.

  obs['image'] is assumed to be (batch, time, height, width, channels);
  the first two axes are folded together for the conv stack and restored
  afterwards.  The assert below implies 64x64 inputs, which the four
  stride-2 convolutions reduce to a 1024-dim vector — confirm against the
  environment wrappers.
  """
  kwargs = dict(strides=2, activation=tf.nn.relu)
  # Merge batch and time axes so conv2d sees a 4D tensor.
  hidden = tf.reshape(obs['image'], [-1] + obs['image'].shape[2:].as_list())
  hidden = tf.layers.conv2d(hidden, 32, 4, **kwargs)
  hidden = tf.layers.conv2d(hidden, 64, 4, **kwargs)
  hidden = tf.layers.conv2d(hidden, 128, 4, **kwargs)
  hidden = tf.layers.conv2d(hidden, 256, 4, **kwargs)
  hidden = tf.layers.flatten(hidden)
  assert hidden.shape[1:].as_list() == [1024], hidden.shape.as_list()
  # Restore the (batch, time) leading axes.
  hidden = tf.reshape(hidden, tools.shape(obs['image'])[:2] + [
      np.prod(hidden.shape[1:].as_list())])
  return hidden


def decoder(features, data_shape, std=1.0):
  """Transposed-convolution decoder producing a pixel-wise Normal.

  Returns an Independent Normal with fixed scale `std` over images of
  shape `data_shape`, preserving the leading axes of `features`.
  """
  kwargs = dict(strides=2, activation=tf.nn.relu)
  hidden = tf.layers.dense(features, 1024, None)
  # Treat the embedding as a 1x1 spatial map for the deconv stack.
  hidden = tf.reshape(hidden, [-1, 1, 1, hidden.shape[-1].value])
  hidden = tf.layers.conv2d_transpose(hidden, 128, 5, **kwargs)
  hidden = tf.layers.conv2d_transpose(hidden, 64, 5, **kwargs)
  hidden = tf.layers.conv2d_transpose(hidden, 32, 6, **kwargs)
  # Final layer: linear activation, output channels match the image.
  mean = tf.layers.conv2d_transpose(hidden, data_shape[-1], 6, strides=2)
  assert mean.shape[1:].as_list() == data_shape, mean.shape
  mean = tf.reshape(mean, tools.shape(features)[:-1] + data_shape)
  return tfd.Independent(tfd.Normal(mean, std), len(data_shape))
def encoder(obs, keys=None, num_layers=3, units=300, activation=tf.nn.relu):
  """Dense encoder for non-image (proprioceptive) observations.

  Concatenates the selected observation keys (every key except 'image' by
  default), applies `num_layers` dense layers of width `units`, and keeps
  the leading two axes of the input intact.
  """
  if not keys:
    keys = [k for k in obs.keys() if k != 'image']
  inputs = tf.concat([obs[k] for k in keys], -1)
  # Fold the leading two axes together so dense layers see a 2D tensor.
  hidden = tf.reshape(inputs, [-1] + inputs.shape[2:].as_list())
  for _ in range(num_layers):
    hidden = tf.layers.dense(hidden, units, activation)
  # Restore the leading axes around the new feature dimension.
  return tf.reshape(
      hidden, tools.shape(inputs)[:2] + [hidden.shape[1].value])
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | from . import configs 20 | from . import tasks 21 | from . import train 22 | -------------------------------------------------------------------------------- /dreamer/scripts/fileutil_fetch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Dreamer Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def execute_commands(commands, parallel):
  """Run baked `sh` commands with at most `parallel` in flight at once.

  Returns a tuple (outputs, failures): the decoded stdout of each command
  that completed, and the count of commands that exited non-zero.
  """
  semaphore = threading.Semaphore(parallel)
  def done_fn(cmd, success, exit_code):
    # Invoked by sh from a worker thread when a background command ends.
    # NOTE(review): this reads cmd._foo, which is only attached after the
    # command is launched below; a command that finishes extremely fast
    # could race with that assignment — confirm sh's callback timing.
    print(cmd._foo)
    semaphore.release()
  running = []
  for command in commands:
    semaphore.acquire()  # Blocks until one of the `parallel` slots frees up.
    running.append(command(_bg=True, _done=done_fn))
    # Carry the progress label over to the running command for done_fn.
    running[-1]._foo = command._foo
  failures = 0
  outputs = []
  for command in running:
    try:
      command.wait()
      outputs.append(command.stdout.decode('utf-8'))
    except sh.ErrorReturnCode as e:
      print(e)
      failures += 1
  print('')
  return outputs, failures


def main(args):
  """Mirror files matching args.pattern from args.indir into args.outdir."""
  pattern = re.compile(args.pattern)
  filenames = sh.fileutil.ls('-R', args.indir)
  filenames = [filename.strip() for filename in filenames]
  filenames = [filename for filename in filenames if pattern.search(filename)]
  print('Found', len(filenames), 'filenames.')
  commands = []
  for index, filename in enumerate(filenames):
    relative = pathlib.Path(filename).relative_to(args.indir)
    destination = args.outdir / relative
    # Skip files that already exist unless --overwrite is set.
    if not args.overwrite and destination.exists():
      continue
    destination.parent.mkdir(parents=True, exist_ok=True)
    flags = [filename, destination]
    if args.overwrite:
      flags = ['-f'] + flags
    command = sh.fileutil.cp.bake(*flags)
    # Progress label printed by execute_commands when the copy finishes.
    command._foo = f'{index + 1}/{len(filenames)} {relative}'
    commands.append(command)
  print(f'Executing {args.parallel} in parallel.')
  execute_commands(commands, args.parallel)


if __name__ == '__main__':
  # Parse 'True'/'False' strings strictly; argparse's type=bool would treat
  # any non-empty string as truthy.
  boolean = lambda x: bool(['False', 'True'].index(x))
  parser = argparse.ArgumentParser()
  parser.add_argument('--indir', type=pathlib.Path, required=True)
  parser.add_argument('--outdir', type=pathlib.Path, required=True)
  parser.add_argument('--parallel', type=int, default=100)
  # parser.add_argument('--pattern', type=str, default='.*/records.yaml$')
  parser.add_argument(
      '--pattern', type=str,
      default='.*/(return|length)/records.jsonl$')
  parser.add_argument('--subdir', type=boolean, default=True)
  parser.add_argument('--overwrite', type=boolean, default=False)
  _args = parser.parse_args()
  if _args.subdir:
    # Nest the outputs under the input directory's name.
    _args.outdir /= _args.indir.stem
  main(_args)
def reward(state, graph, params):
  """Planning objective: sum of predicted mean rewards along the horizon.

  `params` is unused here; it is kept so all objectives share a signature.
  """
  features = graph.cell.features_from_state(state)
  reward = graph.heads.reward(features).mean()
  return tf.reduce_sum(reward, 1)


def reward_value(state, graph, params):
  """Planning objective: lambda-return mixing reward and value predictions.

  The value contribution is scaled by a schedule over
  `objective_value_ramp` training steps — presumably ramping from 0 to 1;
  confirm in tools.schedule.linear.
  """
  features = graph.cell.features_from_state(state)
  reward = graph.heads.reward(features).mean()
  value = graph.heads.value(features).mean()
  # Scale values by the schedule so early training can rely on rewards only.
  value *= tools.schedule.linear(
      graph.step, params.get('objective_value_ramp', 0))
  return_ = td.lambda_return(
      reward[:, :-1], value[:, :-1], value[:, -1],
      params.get('planner_discount', 0.99),
      params.get('planner_lambda', 0.95),
      axis=1)
  # Lambda-return of the first time step, one scalar per batch entry.
  return return_[:, 0]
def find_source_files(directory):
  """Return the paths of all Python files anywhere under `directory`.

  Bug fix: the previous version globbed ``directory + '/**/*.py'`` without
  ``recursive=True``, so ``**`` behaved like a plain ``*`` and files nested
  more than one directory deep were silently missed; the top-level pattern
  (``directory + '*.py'``) also only worked when `directory` ended with a
  path separator.
  """
  pattern = os.path.join(directory, '**', '*.py')
  # With recursive=True, '**' also matches zero directories, so this single
  # pattern covers top-level files too.  Sort for deterministic ordering.
  return sorted(set(glob.glob(pattern, recursive=True)))


def copy_source_tree(source_dir, target_dir):
  """Make target_dir's Python files an exact mirror of source_dir's."""
  for source in find_source_files(source_dir):
    target = os.path.join(target_dir, os.path.relpath(source, source_dir))
    os.makedirs(os.path.dirname(target), exist_ok=True)
    if os.path.exists(target):
      print('Override', os.path.relpath(target, target_dir))
    else:
      print('Add', os.path.relpath(target, target_dir))
    shutil.copy(source, target)
  # Remove files that exist in the target but no longer in the source.
  for target in find_source_files(target_dir):
    source = os.path.join(source_dir, os.path.relpath(target, target_dir))
    if not os.path.exists(source):
      print('Remove', os.path.relpath(target, target_dir))
      os.remove(target)


def infer_headers(directory):
  """Infer the leading '#' comment header shared by files in `directory`.

  Reads one source file and returns its leading block of comment lines
  (blank lines inside the block are kept).  Returns [] when that file has
  no header.  Raises RuntimeError when the directory has no Python files.
  """
  try:
    filename = find_source_files(directory)[0]
  except IndexError:
    raise RuntimeError('No code files found in {}.'.format(directory))
  header = []
  with open(filename, 'r') as f:
    for index, line in enumerate(f):
      if index == 0 and not line.startswith('#'):
        break
      # Stop at the first non-comment, non-blank line.
      if not line.startswith('#') and line.strip(' \n'):
        break
      header.append(line)
  return header


def add_headers(directory, header):
  """Prepend `header` lines to every Python file under `directory`."""
  for filename in find_source_files(directory):
    with open(filename, 'r') as f:
      text = f.readlines()
    with open(filename, 'w') as f:
      f.write(''.join(header + text))


def remove_headers(directory, header):
  """Strip a leading `header` from every Python file that starts with it."""
  for filename in find_source_files(directory):
    with open(filename, 'r') as f:
      text = f.readlines()
    if text[:len(header)] == header:
      text = text[len(header):]
    with open(filename, 'w') as f:
      f.write(''.join(text))


def main(args):
  """Sync the source tree into the target tree, translating headers."""
  print('Inferring headers.\n')
  source_header = infer_headers(args.source)
  print('{} Source header {}\n{}{}\n'.format(
      '-' * 32, '-' * 32, ''.join(source_header), '-' * 79))
  target_header = infer_headers(args.target)
  print('{} Target header {}\n{}{}\n'.format(
      '-' * 32, '-' * 32, ''.join(target_header), '-' * 79))
  print('Synchronizing directories.')
  copy_source_tree(args.source, args.target)
  if target_header and not source_header:
    print('Adding headers.')
    add_headers(args.target, target_header)
  if source_header and not target_header:
    print('Removing headers.')
    remove_headers(args.target, source_header)
  print('Done.')


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument('--source', type=os.path.expanduser, required=True)
  parser.add_argument('--target', type=os.path.expanduser, required=True)
  main(parser.parse_args())
def process(logdir, args):
  """Run a single training run in `logdir`, yielding evaluation scores.

  Builds the config, collects initial episodes, constructs the episode
  dataset pipeline and the in-graph model, then delegates to the training
  loop, forwarding each score it produces.
  """
  with args.params.unlocked:
    args.params.logdir = logdir
  config = configs.make_config(args.params)
  logdir = pathlib.Path(logdir)
  metrics = tools.Metrics(logdir / 'metrics', workers=5)
  training.utility.collect_initial_episodes(metrics, config)
  # Start from a clean graph for every run.
  tf.reset_default_graph()
  dataset = tools.numpy_episodes.numpy_episodes(
      config.train_dir, config.test_dir, config.batch_shape,
      reader=config.data_reader,
      loader=config.data_loader,
      num_chunks=config.num_chunks,
      preprocess_fn=config.preprocess_fn,
      gpu_prefetch=config.gpu_prefetch)
  metrics = tools.InGraphMetrics(metrics)
  build_graph = tools.bind(training.define_model, logdir, metrics)
  for score in training.utility.train(build_graph, dataset, logdir, config):
    yield score


def main(args):
  """Run (and possibly resume) the configured number of experiment runs."""
  experiment = training.Experiment(
      args.logdir,
      process_fn=functools.partial(process, args=args),
      num_runs=args.num_runs,
      ping_every=args.ping_every,
      resume_runs=args.resume_runs)
  for run in experiment:
    # Drain the generator; scores are recorded by the training utilities.
    for unused_score in run:
      pass


if __name__ == '__main__':
  # Parse 'True'/'False' strings strictly; argparse's type=bool would treat
  # any non-empty string as truthy.
  boolean = lambda x: bool(['False', 'True'].index(x))
  parser = argparse.ArgumentParser()
  parser.add_argument('--logdir', type=pathlib.Path, required=True)
  parser.add_argument('--params', default='{}')
  parser.add_argument('--num_runs', type=int, default=1)
  parser.add_argument('--ping_every', type=int, default=0)
  parser.add_argument('--resume_runs', type=boolean, default=True)
  parser.add_argument('--dmlab_runfiles_path', default=None)
  args_, remaining = parser.parse_known_args()
  # Params arrive as YAML; '#' stands in for ',' so values survive shells.
  params_ = args_.params.replace('#', ',').replace('\\', '')
  args_.params = tools.AttrDict(yaml.safe_load(params_))
  if args_.dmlab_runfiles_path:
    with args_.params.unlocked:
      args_.params.dmlab_runfiles_path = args_.dmlab_runfiles_path
    assert args_.params.dmlab_runfiles_path  # Mark as accessed.
  args_.logdir = args_.logdir and os.path.expanduser(args_.logdir)
  remaining.insert(0, sys.argv[0])
  tf.app.run(lambda _: main(args_), remaining)
import unroll 25 | from .attr_dict import AttrDict 26 | from .bind import bind 27 | from .copy_weights import soft_copy_weights 28 | from .count_dataset import count_dataset 29 | from .count_weights import count_weights 30 | from .custom_optimizer import CustomOptimizer 31 | from .filter_variables_lib import filter_variables 32 | from .gif_summary import gif_summary 33 | from .image_strip_summary import image_strip_summary 34 | from .mask import mask 35 | from .metrics import InGraphMetrics 36 | from .metrics import Metrics 37 | from .overshooting import overshooting 38 | from .reshape_as import reshape_as 39 | from .shape import shape 40 | from .streaming_mean import StreamingMean 41 | from .target_network import track_network 42 | -------------------------------------------------------------------------------- /dreamer/tools/attr_dict.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Dreamer Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
class AttrDict(dict):
  """Dictionary with attribute access, locking, defaults, and usage tracking.

  Keys can be read and written as attributes (``d.key``).  The dict can be
  locked to forbid writes; key reads are recorded in ``_touched`` so unused
  configuration entries can be reported via ``untouched``.

  Fixes over the previous version:
  - ``unlocked`` re-locks in a ``finally`` block, so an exception inside
    the ``with`` body can no longer leave the dict permanently unlocked.
  - ``_format_value`` used an empty template for arrays, rendering every
    ndarray as ``''``; it now shows shape, dtype, and value statistics.
  - ``save`` no longer crashes for filenames without a directory part.
  """

  def __init__(self, *args, **kwargs):
    # A dict constructed from existing content starts locked; an empty one
    # starts unlocked so it can be filled in.
    unlocked = kwargs.pop('_unlocked', not (args or kwargs))
    defaults = kwargs.pop('_defaults', {})
    touched = kwargs.pop('_touched', set())
    # Use the parent's __setattr__ for private attributes, since our own
    # __setattr__ routes everything into the dictionary.
    super(AttrDict, self).__setattr__('_unlocked', True)
    super(AttrDict, self).__setattr__('_touched', set())
    super(AttrDict, self).__setattr__('_defaults', {})
    super(AttrDict, self).__init__(*args, **kwargs)
    super(AttrDict, self).__setattr__('_unlocked', unlocked)
    super(AttrDict, self).__setattr__('_defaults', defaults)
    super(AttrDict, self).__setattr__('_touched', touched)

  def __getattr__(self, name):
    try:
      return self[name]
    except KeyError:
      raise AttributeError(name)

  def __setattr__(self, name, value):
    self[name] = value

  def __getitem__(self, name):
    # Raise AttributeError (not KeyError) for dunder lookups so protocols
    # that probe for magic attributes (copy, pickle) behave correctly.
    if name.startswith('__'):
      raise AttributeError(name)
    self._touched.add(name)
    if name in self:
      return super(AttrDict, self).__getitem__(name)
    if name in self._defaults:
      return self._defaults[name]
    # NOTE: AttributeError on purpose, so attribute access via __getattr__
    # reports missing keys with the conventional exception type.
    raise AttributeError(name)

  def __setitem__(self, name, value):
    # A single check suffices: dunder names also start with '_', so the
    # previous, unreachable second check for '__' was removed.
    if name.startswith('_'):
      raise AttributeError('Cannot set private attribute {}'.format(name))
    if not self._unlocked:
      message = 'Use obj.unlock() before setting {}'
      raise RuntimeError(message.format(name))
    super(AttrDict, self).__setitem__(name, value)

  def __repr__(self):
    items = []
    for key, value in self.items():
      items.append('{}: {}'.format(key, self._format_value(value)))
    return '{' + ', '.join(items) + '}'

  def get(self, key, default=None):
    """Like dict.get, but records the key as touched.

    NOTE(review): unlike __getitem__, this does not consult _defaults;
    confirm whether that asymmetry is intended.
    """
    self._touched.add(key)
    if key not in self:
      return default
    return self[key]

  @property
  def untouched(self):
    """Keys that were never read — useful to spot unused config values."""
    return sorted(set(self.keys()) - self._touched)

  @property
  @contextlib.contextmanager
  def unlocked(self):
    """Context manager that temporarily unlocks the dict for writes."""
    self.unlock()
    try:
      yield
    finally:
      # Bug fix: always re-lock, even when the with body raises.
      self.lock()

  def lock(self):
    """Recursively forbid writes on this dict and nested AttrDicts."""
    super(AttrDict, self).__setattr__('_unlocked', False)
    for value in self.values():
      if isinstance(value, AttrDict):
        value.lock()

  def unlock(self):
    """Recursively allow writes on this dict and nested AttrDicts."""
    super(AttrDict, self).__setattr__('_unlocked', True)
    for value in self.values():
      if isinstance(value, AttrDict):
        value.unlock()

  def summarize(self):
    """Return one 'key: value' line per entry."""
    items = []
    for key, value in self.items():
      items.append('{}: {}'.format(key, self._format_value(value)))
    return '\n'.join(items)

  def update(self, mapping):
    """Locked-aware dict.update that returns self for chaining."""
    if not self._unlocked:
      message = 'Use obj.unlock() before updating'
      raise RuntimeError(message)
    super(AttrDict, self).update(mapping)
    return self

  def copy(self, _unlocked=False):
    """Shallow copy; the copy is locked unless _unlocked is True."""
    return type(self)(super(AttrDict, self).copy(), _unlocked=_unlocked)

  def save(self, filename):
    """Write the contents as YAML; `filename` must end in '.yaml'."""
    assert str(filename).endswith('.yaml')
    directory = os.path.dirname(str(filename))
    # Bug fix: os.makedirs('') raises for filenames without a directory.
    if directory:
      os.makedirs(directory, exist_ok=True)
    with open(filename, 'w') as f:
      yaml.dump(collections.OrderedDict(self), f)

  @classmethod
  def load(cls, filename):
    """Load a YAML file previously written by save()."""
    assert str(filename).endswith('.yaml')
    with open(filename, 'r') as f:
      return cls(yaml.load(f, Loader=yaml.Loader))

  def _format_value(self, value):
    """Compact human-readable rendering used by repr() and summarize()."""
    if isinstance(value, np.ndarray):
      # Bug fix: the template here was the empty string, so every array
      # rendered as ''.  Summarize by shape, dtype, and value statistics.
      template = '<ndarray shape={} dtype={} min={} mean={} max={}>'
      min_ = self._format_value(value.min())
      mean = self._format_value(value.mean())
      max_ = self._format_value(value.max())
      return template.format(value.shape, value.dtype, min_, mean, max_)
    if isinstance(value, float) and 1e-3 < abs(value) < 1e6:
      return '{:.3f}'.format(value)
    if isinstance(value, float):
      return '{:4.1e}'.format(value)
    if hasattr(value, '__name__'):
      return value.__name__
    return str(value)
class bind(object):
  """Partial application helper, similar to functools.partial.

  Positional arguments given at construction time are prepended to the call
  arguments; keyword arguments given at call time override the bound ones.
  Unlike functools.partial, copy() produces a new bind with extra arguments
  baked in.
  """

  def __init__(self, fn, *args, **kwargs):
    # Remember the callable and the arguments to prepend on every call.
    self._fn = fn
    self._args = args
    self._kwargs = kwargs

  def __call__(self, *args, **kwargs):
    merged_kwargs = dict(self._kwargs)
    merged_kwargs.update(kwargs)
    return self._fn(*(self._args + args), **merged_kwargs)

  def __repr__(self):
    return 'bind({})'.format(self._fn.__name__)

  def copy(self, *args, **kwargs):
    """Return a new bind with the additional arguments also bound."""
    merged_kwargs = dict(self._kwargs)
    merged_kwargs.update(kwargs)
    return bind(self._fn, *(self._args + args), **merged_kwargs)
def chunk_sequence(sequence, chunk_length, randomize=True, num_chunks=None):
  """Split a nested structure of sequence tensors into fixed-length chunks.

  Args:
    sequence: nested structure (dict-like at the top level) of tensors whose
      leading axis is time; an optional 'length' entry overrides the length
      inferred from the first tensor's shape.
    chunk_length: number of steps per chunk.
    randomize: if True, start at a random offset so that chunk boundaries
      differ between epochs.
    num_chunks: optionally fix the number of chunks; otherwise derived from
      the sequence length.

  Returns:
    The same nested structure with each tensor reshaped to
    [num_chunks, chunk_length, ...].
  """
  if 'length' in sequence:
    length = sequence.pop('length')
  else:
    # Infer the time dimension from any tensor in the structure.
    length = tf.shape(nested.flatten(sequence)[0])[0]
  if randomize:
    if not num_chunks:
      # Leave one chunk of slack so a random offset always fits.
      num_chunks = tf.maximum(1, length // chunk_length - 1)
    else:
      # Adding 0 * length converts the Python int into a graph tensor.
      num_chunks = num_chunks + 0 * length
    used_length = num_chunks * chunk_length
    max_offset = length - used_length
    # Random start offset, inclusive of max_offset.
    offset = tf.random_uniform((), 0, max_offset + 1, dtype=tf.int32)
  else:
    if num_chunks is None:
      num_chunks = length // chunk_length
    else:
      # Adding 0 * length converts the Python int into a graph tensor.
      num_chunks = num_chunks + 0 * length
    used_length = num_chunks * chunk_length
    offset = 0
  # Crop to a whole number of chunks starting at the chosen offset.
  clipped = nested.map(
      lambda tensor: tensor[offset: offset + used_length],
      sequence)
  # Fold the time axis into [num_chunks, chunk_length].
  chunks = nested.map(
      lambda tensor: tf.reshape(
          tensor, [num_chunks, chunk_length] + tensor.shape[1:].as_list()),
      clipped)
  return chunks
def soft_copy_weights(source_pattern, target_pattern, amount):
  """Interpolate target variables towards matching source variables.

  With amount == 1 the values are copied exactly; with 0 < amount <= 0.5 a
  soft (Polyak-style) update is applied. Variables are matched by sorted
  name order, so both patterns must select parallel variable sets.

  Returns:
    A grouped TF op that performs all assignments.
  """
  valid_amount = (0 < amount <= 0.5) or amount == 1
  if not valid_amount:
    message = 'Copy amount should probably be less than half or everything,'
    message += ' not {}'.format(amount)
    raise ValueError(message)
  source_vars = sorted(
      filter_variables_lib.filter_variables(include=source_pattern),
      key=lambda var: var.name)
  target_vars = sorted(
      filter_variables_lib.filter_variables(include=target_pattern),
      key=lambda var: var.name)
  assert len(source_vars) == len(target_vars)
  updates = []
  for source, target in zip(source_vars, target_vars):
    # Matching a variable with itself indicates overlapping patterns.
    assert source.name != target.name
    if amount == 1.0:
      updates.append(target.assign(source))
    else:
      updates.append(target.assign((1 - amount) * target + amount * source))
  return tf.group(*updates)
def count_dataset(directory):
  """Return an int32 TF op that counts the .npz episode files in `directory`.

  The count happens at op execution time via tf.py_func, so newly written
  episodes are picked up on every run. Raises ValueError at graph build
  time when the directory does not exist.
  """
  directory = os.path.expanduser(directory)
  if not tf.gfile.Exists(directory):
    message = "Data set directory '{}' does not exist."
    raise ValueError(message.format(directory))
  pattern = os.path.join(directory, '*.npz')
  def count_episodes():
    # Glob at run time so the op reflects the current directory contents.
    return np.array(len(tf.gfile.Glob(pattern)), dtype=np.int32)
  return tf.py_func(count_episodes, [], tf.int32)
def count_weights(scope=None, exclude=None, graph=None):
  """Count the trainable parameters of a graph.

  Args:
    scope: optional name scope prefix to restrict the count to.
    exclude: optional regex; variables whose names match are skipped.
    graph: graph to inspect; defaults to the current default graph.

  Returns:
    Total number of scalar parameters as a Python int.

  Raises:
    ValueError: if any selected variable has a not fully defined shape.
  """
  if scope:
    # Normalize so that 'foo' matches variables under 'foo/' only.
    scope = scope if scope.endswith('/') else scope + '/'
  graph = graph or tf.get_default_graph()
  variables = graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
  if scope:
    variables = [var for var in variables if var.name.startswith(scope)]
  if exclude:
    pattern = re.compile(exclude)
    variables = [var for var in variables if not pattern.match(var.name)]
  total = 0
  for variable in variables:
    if not variable.shape.is_fully_defined():
      message = "Trainable variable '{}' has undefined shape '{}'."
      raise ValueError(message.format(variable.name, variable.shape))
    total += int(np.prod(variable.shape.as_list()))
  return total
class CustomOptimizer(object):
  """Wraps a TF optimizer with variable filtering, optional learning-rate
  scheduling, gradient clipping, conditional updates, and summaries."""

  def __init__(
      self, optimizer_cls, step, log, learning_rate,
      include=None, exclude=None, clipping=None, schedule=None,
      debug=False, name='custom_optimizer'):
    """Create the optimizer wrapper.

    Args:
      optimizer_cls: optimizer class, constructed with learning_rate/name.
      step: global step tensor, fed to the optional schedule.
      log: boolean tensor gating summary creation.
      learning_rate: base learning rate; multiplied by schedule(step) if a
        schedule is given.
      include: regex pattern(s) selecting variables to optimize.
      exclude: regex pattern(s) of variables to skip.
      clipping: optional global gradient-norm clip value; None disables.
      schedule: optional callable mapping step to a learning-rate factor.
      debug: enable numeric checks and a zero-gradient-norm assertion.
      name: name scope used for ops and summaries.
    """
    if schedule:
      learning_rate *= schedule(step)
    self._step = step
    self._log = log
    self._learning_rate = learning_rate
    self._variables = filter_variables_lib.filter_variables(include, exclude)
    # Bug fix: the original unconditionally called float(clipping), which
    # raises TypeError for the default clipping=None. Only convert when a
    # clip value was actually provided.
    self._clipping = float(clipping) if clipping is not None else None
    self._debug = debug
    self._name = name
    self._optimizer = optimizer_cls(learning_rate=learning_rate, name=name)

  def maybe_minimize(self, condition, loss):
    """Apply an update only when `condition` holds.

    Returns:
      Tuple of (summary string tensor, gradient norm). The summary is empty
      when the update was skipped or logging is disabled.
    """
    with tf.name_scope('optimizer_{}'.format(self._name)):
      update_op, grad_norm = tf.cond(
          condition,
          lambda: self.minimize(loss),
          lambda: (tf.no_op(), 0.0))
      with tf.control_dependencies([update_op]):
        # `str` returns '' and serves as the empty-summary branch.
        summary = tf.cond(
            tf.logical_and(condition, self._log),
            lambda: self.summarize(grad_norm), str)
      if self._debug:
        message = 'Zero gradient norm in {} optimizer.'.format(self._name)
        assertion = lambda: tf.assert_greater(grad_norm, 0.0, message=message)
        assert_op = tf.cond(condition, assertion, tf.no_op)
        with tf.control_dependencies([assert_op]):
          summary = tf.identity(summary)
      return summary, grad_norm

  def minimize(self, loss):
    """Compute, optionally clip, and apply gradients.

    Returns:
      Tuple of (apply op, global gradient norm before clipping).
    """
    if self._debug:
      loss = tf.check_numerics(loss, '{}_loss'.format(self._name))
    if hasattr(self._optimizer, 'compute_gradients'):  # TF 1.x optimizer.
      gradients, variables = zip(*self._optimizer.compute_gradients(
          loss, self._variables, colocate_gradients_with_ops=True))
    else:  # Keras optimizer interface.
      gradients = self._optimizer.get_gradients(loss, self._variables)
      variables = self._variables
    grad_norm = tf.global_norm(gradients)
    if self._clipping:
      # Reuse the precomputed norm so it is not recomputed by the clip op.
      gradients, _ = tf.clip_by_global_norm(
          gradients, self._clipping, grad_norm)
    optimize = self._optimizer.apply_gradients(zip(gradients, variables))
    return optimize, grad_norm

  def summarize(self, grad_norm):
    """Build a merged summary of the learning rate and gradient norms."""
    summaries = []
    with tf.name_scope('optimizer_{}'.format(self._name)):
      summaries.append(tf.summary.scalar('learning_rate', self._learning_rate))
      summaries.append(tf.summary.scalar('grad_norm', grad_norm))
      if self._clipping:
        clipped = tf.minimum(grad_norm, self._clipping)
        summaries.append(tf.summary.scalar('clipped_gradient_norm', clipped))
      summary = tf.summary.merge(summaries)
    return summary
def filter_variables(include=None, exclude=None):
  """Select global variables whose names match include but not exclude.

  Args:
    include: regex or list of regexes; defaults to matching everything.
    exclude: regex or list of regexes; defaults to excluding nothing.

  Returns:
    List of matching tf variables.

  Raises:
    RuntimeError: if the graph has no variables, an include pattern matches
      nothing, or the filter leaves no variables.
  """
  # Normalize both arguments to tuples of compiled regexes.
  if include is None:
    include = (r'.*',)
  if exclude is None:
    exclude = ()
  if not isinstance(include, (tuple, list)):
    include = (include,)
  if not isinstance(exclude, (tuple, list)):
    exclude = (exclude,)
  include = [re.compile(regex) for regex in include]
  exclude = [re.compile(regex) for regex in exclude]
  variables = tf.global_variables()
  if not variables:
    raise RuntimeError('There are no variables to filter.')
  # Fail loudly on include patterns that match nothing, listing candidates.
  for regex in include:
    if not any(regex.match(variable.name) for variable in variables):
      message = "Regex r'{}' does not match any variables in the graph.\n"
      message += 'All variables:\n'
      message += '\n'.join('- {}'.format(var.name) for var in variables)
      raise RuntimeError(message.format(regex.pattern))
  filtered = [
      variable for variable in variables
      if any(regex.match(variable.name) for regex in include) and
      not any(regex.match(variable.name) for regex in exclude)]
  if not filtered:
    message = 'No variables left after filtering.'
    message += '\nIncludes:\n' + '\n'.join(regex.pattern for regex in include)
    message += '\nExcludes:\n' + '\n'.join(regex.pattern for regex in exclude)
    raise RuntimeError(message)
  return filtered
def encode_gif(images, fps):
  """Encode a sequence of frames into GIF bytes by piping through ffmpeg.

  Args:
    images: sequence of uint8 numpy arrays of shape (height, width, channels)
      with 1 (grayscale) or 3 (RGB) channels.
    fps: frames per second of the output animation.

  Returns:
    The encoded GIF as a byte string.

  Raises:
    IOError: if ffmpeg is not installed or exits with an error.
  """
  from subprocess import Popen, PIPE
  h, w, c = images[0].shape
  cmd = [
      'ffmpeg', '-y',
      '-f', 'rawvideo',
      '-vcodec', 'rawvideo',
      '-r', '%.02f' % fps,
      '-s', '%dx%d' % (w, h),
      '-pix_fmt', {1: 'gray', 3: 'rgb24'}[c],
      '-i', '-',
      '-filter_complex',
      '[0:v]split[x][z];[z]palettegen[y];[x]fifo[x];[x][y]paletteuse',
      '-r', '%.02f' % fps,
      '-f', 'gif',
      '-']
  proc = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE)
  for image in images:
    # Bug fix: ndarray.tostring() was deprecated and removed in NumPy 2.0;
    # tobytes() is the drop-in replacement with identical output.
    proc.stdin.write(image.tobytes())
  out, err = proc.communicate()
  if proc.returncode:
    err = '\n'.join([' '.join(cmd), err.decode('utf8')])
    raise IOError(err)
  del proc
  return out


def py_gif_summary(tag, images, max_outputs, fps):
  """Build a serialized TF Summary proto that embeds animated GIFs.

  Args:
    tag: summary tag (str or bytes).
    images: uint8 array of shape (batch, time, height, width, channels).
    max_outputs: maximum number of batch elements to summarize.
    fps: frames per second for the GIF encoding.

  Returns:
    Serialized Summary proto bytes. Falls back to a PNG of the first frame
    when ffmpeg is unavailable.
  """
  is_bytes = isinstance(tag, bytes)
  if is_bytes:
    tag = tag.decode('utf-8')
  images = np.asarray(images)
  if images.dtype != np.uint8:
    raise ValueError('Tensor must have dtype uint8 for gif summary.')
  if images.ndim != 5:
    raise ValueError('Tensor must be 5-D for gif summary.')
  batch_size, _, height, width, channels = images.shape
  if channels not in (1, 3):
    raise ValueError('Tensors must have 1 or 3 channels for gif summary.')
  summ = tf.Summary()
  num_outputs = min(batch_size, max_outputs)
  for i in range(num_outputs):
    image_summ = tf.Summary.Image()
    image_summ.height = height
    image_summ.width = width
    image_summ.colorspace = channels  # 1: grayscale, 3: RGB
    try:
      image_summ.encoded_image_string = encode_gif(images[i], fps)
    except (IOError, OSError) as e:
      # Bug fix: these messages were passed logging-style to print(), which
      # left the %s placeholder unfilled; format the exception explicitly.
      print(
          'Unable to encode images to a gif string because either ffmpeg is '
          'not installed or ffmpeg returned an error: %s. Falling back to an '
          'image summary of the first frame in the sequence.' % e)
      try:
        from PIL import Image  # pylint: disable=g-import-not-at-top
        import io  # pylint: disable=g-import-not-at-top
        with io.BytesIO() as output:
          Image.fromarray(images[i][0]).save(output, 'PNG')
          image_summ.encoded_image_string = output.getvalue()
      except Exception:
        print('Gif summaries requires ffmpeg or PIL to be installed: %s' % e)
        image_summ.encoded_image_string = (
            ''.encode('utf-8') if is_bytes else '')
    if num_outputs == 1:
      summ_tag = tag
    else:
      summ_tag = '{}/{}'.format(tag, i)
    summ.value.add(tag=summ_tag, image=image_summ)
  summ_str = summ.SerializeToString()
  return summ_str


def gif_summary(name, tensor, max_outputs, fps, collections=None, family=None):
  """Create a TF op that renders a batch of image sequences as GIF summaries.

  Float tensors are assumed to be in [0, 1] and rescaled to uint8.
  """
  tensor = tf.convert_to_tensor(tensor)
  if tensor.dtype in (tf.float32, tf.float64):
    tensor = tf.cast(255.0 * tensor, tf.uint8)
  with summary_op_util.summary_scope(
      name, family, values=[tensor]) as (tag, scope):
    val = tf.py_func(
        py_gif_summary,
        [tag, tensor, max_outputs, fps],
        tf.string,
        stateful=False,
        name=scope)
    summary_op_util.collect(val, collections, [tf.GraphKeys.SUMMARIES])
  return val
def image_strip_summary(name, images, max_length=100, max_batch=10):
  """Summarize a batch of image sequences as horizontal film strips.

  Args:
    name: summary name.
    images: tensor of shape (batch, time, height, width, channels); uint8
      inputs are rescaled to floats in [0, 1].
    max_length: cap on the number of time steps shown per sequence.
    max_batch: cap on the number of sequences (one strip per row).

  Returns:
    An image summary op containing one strip per batch element.
  """
  if max_batch:
    images = images[:max_batch]
  if max_length:
    images = images[:, :max_length]
  if images.dtype == tf.uint8:
    images = tf.cast(images, tf.float32) / 255.0
  length = tf.shape(images)[1]
  width = tf.shape(images)[3]
  channels = images.shape[-1].value
  # Concatenate time steps horizontally and batch elements vertically.
  strip = tf.transpose(images, [0, 2, 1, 3, 4])
  strip = tf.reshape(strip, [1, -1, length * width, channels])
  strip = tf.clip_by_value(strip, 0., 1.)
  return tf.summary.image(name, strip)
def mask(tensor, mask=None, length=None, value=0, debug=False):
  """Replace invalid time steps of a batched tensor by a constant value.

  Exactly one of `mask` (a boolean batch/time mask) or `length` (per-batch
  valid lengths) must be given. The mask is broadcast over the trailing
  dimensions of `tensor`.
  """
  provided = [x for x in (mask, length) if x is not None]
  if len(provided) != 1:
    raise KeyError('Exactly one of mask and length must be provided.')
  with tf.name_scope('mask'):
    if mask is None:
      # Build a (batch, time) mask from the per-example lengths.
      steps = tf.range(tensor.shape[1].value)
      mask = steps[None, :] < length[:, None]
    batch_dims = mask.shape.ndims
    # Append singleton axes until the mask has the tensor's rank, then tile
    # it across the trailing (non-batch) dimensions.
    while tensor.shape.ndims > mask.shape.ndims:
      mask = mask[..., None]
    multiples = [1] * batch_dims + tensor.shape[batch_dims:].as_list()
    mask = tf.tile(mask, multiples)
    masked = tf.where(mask, tensor, value * tf.ones_like(tensor))
    if debug:
      masked = tf.check_numerics(masked, 'masked')
    return masked
class InGraphMetrics:
  """Bridges a `Metrics` logger into a TensorFlow graph.

  Each method returns a TF op (built with tf.py_function) that forwards its
  arguments to the corresponding method of the wrapped `Metrics` object when
  the op executes.
  """

  def __init__(self, metrics):
    # The Python-side Metrics instance that receives the forwarded calls.
    self._metrics = metrics

  def set_tags(self, **tags):
    """Return an op that converts tensor tag values and sets them as tags."""
    keys, values = zip(*sorted(tags.items(), key=lambda x: x[0]))
    def inner(*values):
      parsed = []
      for value in values:
        if value.dtype == tf.string:
          parsed.append(value.numpy().decode('utf-8'))
        elif value.dtype in (tf.float32, tf.float64):
          parsed.append(float(value.numpy()))
        elif value.dtype in (tf.int32, tf.int64):
          parsed.append(int(value.numpy()))
        else:
          raise NotImplementedError(value.dtype)
      return self._metrics.set_tags(**dict(zip(keys, parsed)))
    # String tensors in tf.py_function are only supported on CPU.
    with tf.device('/cpu:0'):
      return tf.py_function(inner, values, [], 'set_tags')

  def reset_tags(self):
    """Return an op that clears all currently set tags."""
    return tf.py_function(
        self._metrics.reset_tags, [], [], 'reset_tags')

  def flush(self):
    """Return an op that flushes buffered records to disk."""
    def inner():
      # Discard the futures returned by flush; the op has no outputs.
      self._metrics.flush()
    return tf.py_function(inner, [], [], 'flush')

  def add_scalar(self, name, value):
    """Return an op that records a scalar tensor under `name`."""
    assert len(value.shape) == 0, (name, value)
    def inner(value):
      self._metrics.add_scalar(name, value.numpy())
    return tf.py_function(inner, [value], [], 'add_scalar_' + name)

  def add_scalars(self, name, value):
    """Return an op that records a vector of scalars under `name`."""
    assert len(value.shape) == 1, (name, value)
    def inner(value):
      self._metrics.add_scalars(name, value.numpy())
    return tf.py_function(inner, [value], [], 'add_scalars_' + name)

  def add_tensor(self, name, value):
    """Return an op that records an arbitrary tensor under `name`."""
    def inner(value):
      self._metrics.add_tensor(name, value.numpy())
    return tf.py_function(inner, [value], [], 'add_tensor_' + name)


class Metrics:
  """Append-only metrics logger writing JSONL catalogues plus data files.

  Records are buffered in memory per metric name and written to
  `directory/<name>/records.jsonl` by flush(). Tensor/image/video payloads
  are stored as sibling files referenced by a 'filename' field in the
  record. If `workers` is given, flush jobs run on a thread pool and errors
  surface on the next flush.
  """

  def __init__(self, directory, workers=None):
    assert workers is None or isinstance(workers, int)
    self._directory = pathlib.Path(directory).expanduser()
    # Buffered records per metric name; drained by flush().
    self._records = collections.defaultdict(collections.deque)
    self._created_directories = set()
    # Pending binary payloads keyed by their generated filename.
    self._values = {}
    # Tags merged into every record until reset_tags().
    self._tags = {}
    self._writers = {
        'npy': np.save,
        # 'png': imageio.imwrite,
        # 'jpg': imageio.imwrite,
        # 'bmp': imageio.imwrite,
        # 'gif': functools.partial(imageio.mimwrite, fps=30),
        # 'mp4': functools.partial(imageio.mimwrite, fps=30),
    }
    if workers:
      self._pool = futures.ThreadPoolExecutor(max_workers=workers)
    else:
      self._pool = None
    self._last_futures = []

  @property
  def names(self):
    """Sorted metric names, both buffered in memory and found on disk."""
    found = set(self._records.keys())
    for catalogue in self._directory.glob('**/records.jsonl'):
      found.add(self._path_to_name(catalogue.parent))
    return sorted(found)

  def set_tags(self, **kwargs):
    """Set tags that are attached to every subsequently added record."""
    reserved = ('value', 'name')
    if any(key in reserved for key in kwargs):
      message = "Reserved keys '{}' and cannot be used for tags"
      raise KeyError(message.format(', '.join(reserved)))
    for key, value in kwargs.items():
      # Round-trip through JSON to guarantee tags are serializable.
      self._tags[json.loads(json.dumps(key))] = json.loads(json.dumps(value))

  def reset_tags(self):
    """Clear all currently set tags."""
    self._tags = {}

  def add_scalar(self, name, value):
    """Buffer one float record for `name`, annotated with current tags."""
    self._validate_name(name)
    record = self._tags.copy()
    record['value'] = float(value)
    self._records[name].append(record)

  def add_scalars(self, name, values):
    """Buffer one float record per element of `values`."""
    self._validate_name(name)
    for value in values:
      record = self._tags.copy()
      record['value'] = float(value)
      self._records[name].append(record)

  def add_tensor(self, name, value, format='npy'):
    """Buffer a tensor payload, written as a .npy file on flush."""
    assert format in ('npy',)
    self._record_file(name, value, 'npy')

  def add_image(self, name, value, format='png'):
    """Buffer an image payload, written as an image file on flush."""
    assert format in ('png', 'jpg', 'bmp')
    self._record_file(name, value, format)

  def add_video(self, name, value, format='gif'):
    """Buffer a video payload, written as a gif/mp4 file on flush."""
    assert format in ('gif', 'mp4')
    self._record_file(name, value, format)

  def _record_file(self, name, value, extension):
    # Shared implementation of the file-backed add_* methods: buffer a
    # record pointing at a freshly generated filename and stash the payload.
    self._validate_name(name)
    record = self._tags.copy()
    record['filename'] = self._random_filename(extension)
    self._values[record['filename']] = value
    self._records[name].append(record)

  def flush(self, blocking=None):
    """Write buffered records and payloads to disk.

    With a worker pool and `blocking` not True, jobs are submitted
    asynchronously and the futures are returned; failures of a previous
    asynchronous flush are re-raised here. Otherwise jobs run inline.
    """
    if blocking is False and not self._pool:
      message = 'Create with workers argument for non-blocking flushing.'
      raise ValueError(message)
    # Surface any error from the previous asynchronous flush.
    for future in self._last_futures or []:
      try:
        future.result()
      except Exception as e:
        message = 'Previous asynchronous flush failed.'
        raise RuntimeError(message) from e
    self._last_futures = None
    jobs = []
    for name in self._records:
      records = list(self._records[name])
      if not records:
        continue
      self._records[name].clear()
      catalogue = self._name_to_path(name) / 'records.jsonl'
      jobs.append((self._append_catalogue, catalogue, records))
      for record in records:
        jobs.append((self._write_file, name, record))
    if self._pool and not blocking:
      submitted = [self._pool.submit(*job) for job in jobs]
      self._last_futures = submitted.copy()
      return submitted
    for job in jobs:
      job[0](*job[1:])

  def query(self, pattern=None, **tags):
    """Yield records whose name matches `pattern` and contain all `tags`."""
    pattern = pattern and re.compile(pattern)
    for name in self.names:
      if pattern and not pattern.search(name):
        continue
      for record in self._read_records(name):
        # Keep only records whose tags are a superset of the query tags.
        if {k: v for k, v in record.items() if k in tags} != tags:
          continue
        record['name'] = name
        yield record

  def _validate_name(self, name):
    # Names become directory paths, so restrict the character set.
    assert isinstance(name, str) and name
    if re.search(r'[^a-z0-9/_-]+', name):
      message = (
          "Invalid metric name '{}'. Names must contain only lower case "
          "letters, digits, dashes, underscores, and forward slashes.")
      raise NameError(message.format(name))

  def _name_to_path(self, name):
    # Forward slashes in names map to nested subdirectories.
    return self._directory.joinpath(*name.split('/'))

  def _path_to_name(self, path):
    return '/'.join(path.relative_to(self._directory).parts)

  def _read_records(self, name):
    # Yield buffered records first, then the on-disk catalogue, resolving
    # 'filename' fields to full paths.
    path = self._name_to_path(name)
    catalogue = path / 'records.jsonl'
    records = self._records[name]
    if catalogue.exists():
      records = itertools.chain(records, self._load_catalogue(catalogue))
    for record in records:
      if 'filename' in record:
        record['filename'] = path / record['filename']
      yield record

  def _write_file(self, name, record):
    # Write the payload referenced by record['filename'], if any.
    if 'filename' not in record:
      return
    format = pathlib.Path(record['filename']).suffix.lstrip('.')
    if format not in self._writers:
      raise TypeError('Trying to write unknown format {}'.format(format))
    value = self._values.pop(record['filename'])
    filename = self._name_to_path(name) / record['filename']
    self._ensure_directory(filename.parent)
    self._writers[format](filename, value)

  def _load_catalogue(self, filename):
    rows = [json.loads(line) for line in filename.open('r')]
    message = 'Metrics files do not contain lists of mappings'
    if not isinstance(rows, list):
      raise TypeError(message)
    if not all(isinstance(row, dict) for row in rows):
      raise TypeError(message)
    return rows

  def _append_catalogue(self, filename, records):
    # Append one JSON object per line, with sorted keys for stable output.
    self._ensure_directory(filename.parent)
    lines = []
    for record in records:
      ordered = collections.OrderedDict(
          sorted(record.items(), key=lambda x: x[0]))
      lines.append(json.dumps(ordered) + '\n')
    with filename.open('a') as f:
      f.write(''.join(lines))

  def _ensure_directory(self, directory):
    if directory in self._created_directories:
      return
    # We first attempt to create the directory and afterwards add it to the
    # set of created directories. This means multiple workers could attempt
    # to create the directory at the same time, which is better than one
    # trying to create a file while another has not yet created the
    # directory.
    directory.mkdir(parents=True, exist_ok=True)
    self._created_directories.add(directory)

  def _random_filename(self, extension):
    # Timestamp plus UUID makes collisions across processes unlikely.
    timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
    identifier = uuid.uuid4().hex
    filename = '{}-{}.{}'.format(timestamp, identifier, extension)
    return filename
27 | return map( 28 | lambda *x: x if len(x) > 1 else x[0], 29 | *structures, 30 | flatten=flatten) 31 | 32 | 33 | def map_(function, *structures, **kwargs): 34 | # Named keyword arguments are not allowed after *args in Python 2. 35 | flatten = kwargs.pop('flatten', False) 36 | assert not kwargs, 'map() got unexpected keyword arguments.' 37 | 38 | def impl(function, *structures): 39 | if len(structures) == 0: 40 | return structures 41 | if all(isinstance(s, (tuple, list)) for s in structures): 42 | if len(set(len(x) for x in structures)) > 1: 43 | raise ValueError('Cannot merge tuples or lists of different length.') 44 | args = tuple((impl(function, *x) for x in _builtin_zip(*structures))) 45 | if hasattr(structures[0], '_fields'): # namedtuple 46 | return type(structures[0])(*args) 47 | else: # tuple, list 48 | return type(structures[0])(args) 49 | if all(isinstance(s, dict) for s in structures): 50 | if len(set(frozenset(x.keys()) for x in structures)) > 1: 51 | raise ValueError('Cannot merge dicts with different keys.') 52 | merged = { 53 | k: impl(function, *(s[k] for s in structures)) 54 | for k in structures[0]} 55 | return type(structures[0])(merged) 56 | return function(*structures) 57 | 58 | result = impl(function, *structures) 59 | if flatten: 60 | result = flatten_(result) 61 | return result 62 | 63 | 64 | def flatten_(structure): 65 | if isinstance(structure, dict): 66 | result = () 67 | for key in sorted(list(structure.keys())): 68 | result += flatten_(structure[key]) 69 | return result 70 | if isinstance(structure, (tuple, list)): 71 | result = () 72 | for element in structure: 73 | result += flatten_(element) 74 | return result 75 | return (structure,) 76 | 77 | 78 | def filter_(predicate, *structures, **kwargs): 79 | # Named keyword arguments are not allowed after *args in Python 2. 80 | flatten = kwargs.pop('flatten', False) 81 | assert not kwargs, 'filter() got unexpected keyword arguments.' 
82 | 83 | def impl(predicate, *structures): 84 | if len(structures) == 0: 85 | return structures 86 | if all(isinstance(s, (tuple, list)) for s in structures): 87 | if len(set(len(x) for x in structures)) > 1: 88 | raise ValueError('Cannot merge tuples or lists of different length.') 89 | # Only wrap in tuples if more than one structure provided. 90 | if len(structures) > 1: 91 | filtered = (impl(predicate, *x) for x in _builtin_zip(*structures)) 92 | else: 93 | filtered = (impl(predicate, x) for x in structures[0]) 94 | # Remove empty containers and construct result structure. 95 | if hasattr(structures[0], '_fields'): # namedtuple 96 | filtered = (x if x != () else None for x in filtered) 97 | return type(structures[0])(*filtered) 98 | else: # tuple, list 99 | filtered = ( 100 | x for x in filtered if not isinstance(x, (tuple, list, dict)) or x) 101 | return type(structures[0])(filtered) 102 | if all(isinstance(s, dict) for s in structures): 103 | if len(set(frozenset(x.keys()) for x in structures)) > 1: 104 | raise ValueError('Cannot merge dicts with different keys.') 105 | # Only wrap in tuples if more than one structure provided. 106 | if len(structures) > 1: 107 | filtered = { 108 | k: impl(predicate, *(s[k] for s in structures)) 109 | for k in structures[0]} 110 | else: 111 | filtered = {k: impl(predicate, v) for k, v in structures[0].items()} 112 | # Remove empty containers and construct result structure. 
113 | filtered = { 114 | k: v for k, v in filtered.items() 115 | if not isinstance(v, (tuple, list, dict)) or v} 116 | return type(structures[0])(filtered) 117 | if len(structures) > 1: 118 | return structures if predicate(*structures) else () 119 | else: 120 | return structures[0] if predicate(structures[0]) else () 121 | 122 | result = impl(predicate, *structures) 123 | if flatten: 124 | result = flatten_(result) 125 | return result 126 | 127 | 128 | zip = zip_ 129 | map = map_ 130 | flatten = flatten_ 131 | filter = filter_ 132 | -------------------------------------------------------------------------------- /dreamer/tools/numpy_episodes.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Dreamer Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import functools 20 | import os 21 | import random 22 | import time 23 | 24 | from scipy.ndimage import interpolation 25 | import numpy as np 26 | import tensorflow as tf 27 | 28 | from dreamer.tools import attr_dict 29 | from dreamer.tools import chunk_sequence 30 | 31 | 32 | def numpy_episodes( 33 | train_dir, test_dir, shape, reader=None, loader=None, num_chunks=None, 34 | preprocess_fn=None, gpu_prefetch=False): 35 | reader = reader or episode_reader 36 | loader = loader or cache_loader 37 | try: 38 | dtypes, shapes = _read_spec(reader, train_dir) 39 | except ZeroDivisionError: 40 | dtypes, shapes = _read_spec(reader, test_dir) 41 | train = tf.data.Dataset.from_generator( 42 | functools.partial(loader, reader, train_dir), 43 | dtypes, shapes) 44 | test = tf.data.Dataset.from_generator( 45 | functools.partial(cache_loader, reader, test_dir, shape[0]), 46 | dtypes, shapes) 47 | chunking = lambda x: tf.data.Dataset.from_tensor_slices( 48 | chunk_sequence.chunk_sequence(x, shape[1], True, num_chunks)) 49 | train = train.flat_map(chunking) 50 | train = train.batch(shape[0], drop_remainder=True) 51 | if preprocess_fn: 52 | train = train.map(preprocess_fn, tf.data.experimental.AUTOTUNE) 53 | if gpu_prefetch: 54 | train = train.apply(tf.data.experimental.copy_to_device('/gpu:0')) 55 | train = train.prefetch(tf.data.experimental.AUTOTUNE) 56 | test = test.flat_map(chunking) 57 | test = test.batch(shape[0], drop_remainder=True) 58 | if preprocess_fn: 59 | test = test.map(preprocess_fn, tf.data.experimental.AUTOTUNE) 60 | if gpu_prefetch: 61 | test = test.apply(tf.data.experimental.copy_to_device('/gpu:0')) 62 | test = test.prefetch(tf.data.experimental.AUTOTUNE) 63 | return attr_dict.AttrDict(train=train, test=test) 64 | 65 | 66 | def cache_loader(reader, directory, every): 67 | cache = {} 68 | while True: 69 | episodes = 
_sample(cache.values(), every) 70 | for episode in _permuted(episodes, every): 71 | yield episode 72 | filenames = tf.gfile.Glob(os.path.join(directory, '*.npz')) 73 | filenames = [filename for filename in filenames if filename not in cache] 74 | for filename in filenames: 75 | cache[filename] = reader(filename) 76 | 77 | 78 | def recent_loader(reader, directory, every): 79 | recent = {} 80 | cache = {} 81 | while True: 82 | episodes = [] 83 | episodes += _sample(recent.values(), every // 2) 84 | episodes += _sample(cache.values(), every // 2) 85 | for episode in _permuted(episodes, every): 86 | yield episode 87 | cache.update(recent) 88 | recent = {} 89 | filenames = tf.gfile.Glob(os.path.join(directory, '*.npz')) 90 | filenames = [filename for filename in filenames if filename not in cache] 91 | for filename in filenames: 92 | recent[filename] = reader(filename) 93 | 94 | 95 | def window_loader(reader, directory, window, every): 96 | cache = {} 97 | while True: 98 | episodes = _sample(cache.values(), every) 99 | for episode in _permuted(episodes, every): 100 | yield episode 101 | filenames = tf.gfile.Glob(os.path.join(directory, '*.npz')) 102 | filenames = sorted(filenames)[-window:] 103 | for filename in filenames: 104 | if filename not in cache: 105 | cache[filename] = reader(filename) 106 | for key in list(cache.keys()): 107 | if key not in filenames: 108 | del cache[key] 109 | 110 | 111 | def reload_loader(reader, directory): 112 | directory = os.path.expanduser(directory) 113 | while True: 114 | filenames = tf.gfile.Glob(os.path.join(directory, '*.npz')) 115 | random.shuffle(filenames) 116 | for filename in filenames: 117 | yield reader(filename) 118 | 119 | 120 | def dummy_loader(reader, directory): 121 | random = np.random.RandomState(seed=0) 122 | dtypes, shapes, length = _read_spec(reader, directory, True, True) 123 | while True: 124 | episode = {} 125 | for key in dtypes: 126 | dtype, shape = dtypes[key], (length,) + shapes[key][1:] 127 | if dtype in 
(np.float32, np.float64): 128 | episode[key] = random.uniform(0, 1, shape).astype(dtype) 129 | elif dtype in (np.int32, np.int64, np.uint8): 130 | episode[key] = random.uniform(0, 255, shape).astype(dtype) 131 | else: 132 | raise NotImplementedError('Unsupported dtype {}.'.format(dtype)) 133 | yield episode 134 | 135 | 136 | def episode_reader( 137 | filename, resize=None, max_length=None, action_noise=None, 138 | clip_rewards=False, pcont_scale=None): 139 | try: 140 | with tf.gfile.Open(filename, 'rb') as file_: 141 | episode = np.load(file_) 142 | except (IOError, ValueError): 143 | # Try again one second later, in case the file was still being written. 144 | time.sleep(1) 145 | with tf.gfile.Open(filename, 'rb') as file_: 146 | episode = np.load(file_) 147 | episode = {key: _convert_type(episode[key]) for key in episode.keys()} 148 | episode['return'] = np.cumsum(episode['reward']) 149 | if 'reward_mask' not in episode: 150 | episode['reward_mask'] = np.ones_like(episode['reward'])[..., None] 151 | if max_length: 152 | episode = {key: value[:max_length] for key, value in episode.items()} 153 | if resize and resize != 1: 154 | factors = (1, resize, resize, 1) 155 | episode['image'] = interpolation.zoom(episode['image'], factors) 156 | if action_noise: 157 | seed = np.fromstring(filename, dtype=np.uint8) 158 | episode['action'] += np.random.RandomState(seed).normal( 159 | 0, action_noise, episode['action'].shape) 160 | if clip_rewards is False: 161 | pass 162 | elif clip_rewards == 'sign': 163 | episode['reward'] = np.sign(episode['reward']) 164 | elif clip_rewards == 'tanh': 165 | episode['reward'] = np.tanh(episode['reward']) 166 | else: 167 | raise NotImplementedError(clip_rewards) 168 | if pcont_scale is not None: 169 | episode['pcont'] *= pcont_scale 170 | return episode 171 | 172 | 173 | def _read_spec( 174 | reader, directory, return_length=False, numpy_types=False): 175 | episodes = reload_loader(reader, directory) 176 | episode = next(episodes) 177 | 
episodes.close() 178 | dtypes = {key: value.dtype for key, value in episode.items()} 179 | if not numpy_types: 180 | dtypes = {key: tf.as_dtype(value) for key, value in dtypes.items()} 181 | shapes = {key: value.shape for key, value in episode.items()} 182 | shapes = {key: (None,) + shape[1:] for key, shape in shapes.items()} 183 | if return_length: 184 | length = len(episode[list(shapes.keys())[0]]) 185 | return dtypes, shapes, length 186 | else: 187 | return dtypes, shapes 188 | 189 | 190 | def _convert_type(array): 191 | if array.dtype == np.float64: 192 | return array.astype(np.float32) 193 | if array.dtype == np.int64: 194 | return array.astype(np.int32) 195 | return array 196 | 197 | 198 | def _sample(sequence, amount): 199 | sequence = list(sequence) 200 | amount = min(amount, len(sequence)) 201 | return random.sample(sequence, amount) 202 | 203 | 204 | def _permuted(sequence, amount): 205 | sequence = list(sequence) 206 | if not sequence: 207 | return 208 | index = 0 209 | while True: 210 | for element in np.random.permutation(sequence): 211 | if index >= amount: 212 | return 213 | yield element 214 | index += 1 215 | -------------------------------------------------------------------------------- /dreamer/tools/overshooting.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Dreamer Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import functools 20 | 21 | import numpy as np 22 | import tensorflow as tf 23 | 24 | from dreamer.tools import nested 25 | from dreamer.tools import shape 26 | 27 | 28 | def overshooting( 29 | cell, target, embedded, prev_action, length, amount, posterior=None, 30 | ignore_input=False): 31 | # Closed loop unroll to get posterior states, which are the starting points 32 | # for open loop unrolls. We don't need the last time step, since we have no 33 | # targets for unrolls from it. 34 | if posterior is None: 35 | use_obs = tf.ones(tf.shape( 36 | nested.flatten(embedded)[0][:, :, :1])[:3], tf.bool) 37 | use_obs = tf.cond( 38 | tf.convert_to_tensor(ignore_input), 39 | lambda: tf.zeros_like(use_obs, tf.bool), 40 | lambda: use_obs) 41 | (_, posterior), _ = tf.nn.dynamic_rnn( 42 | cell, (embedded, prev_action, use_obs), length, dtype=tf.float32, 43 | swap_memory=True) 44 | 45 | # Arrange inputs for every iteration in the open loop unroll. Every loop 46 | # iteration below corresponds to one row in the docstring illustration. 47 | max_length = shape.shape(nested.flatten(embedded)[0])[1] 48 | first_output = { 49 | # 'observ': embedded, 50 | 'prev_action': prev_action, 51 | 'posterior': posterior, 52 | 'target': target, 53 | 'mask': tf.sequence_mask(length, max_length, tf.int32), 54 | } 55 | 56 | progress_fn = lambda tensor: tf.concat([tensor[:, 1:], 0 * tensor[:, :1]], 1) 57 | other_outputs = tf.scan( 58 | lambda past_output, _: nested.map(progress_fn, past_output), 59 | tf.range(amount), first_output) 60 | sequences = nested.map( 61 | lambda lhs, rhs: tf.concat([lhs[None], rhs], 0), 62 | first_output, other_outputs) 63 | 64 | # Merge batch and time dimensions of steps to compute unrolls from every 65 | # time step as one batch. The time dimension becomes the number of 66 | # overshooting distances. 
67 | sequences = nested.map( 68 | lambda tensor: _merge_dims(tensor, [1, 2]), 69 | sequences) 70 | sequences = nested.map( 71 | lambda tensor: tf.transpose( 72 | tensor, [1, 0] + list(range(2, tensor.shape.ndims))), 73 | sequences) 74 | merged_length = tf.reduce_sum(sequences['mask'], 1) 75 | 76 | # Mask out padding frames; unnecessary if the input is already masked. 77 | sequences = nested.map( 78 | lambda tensor: tensor * tf.cast( 79 | _pad_dims(sequences['mask'], tensor.shape.ndims), 80 | tensor.dtype), 81 | sequences) 82 | 83 | # Compute open loop rollouts. 84 | use_obs = tf.zeros(tf.shape(sequences['mask']), tf.bool)[..., None] 85 | embed_size = nested.flatten(embedded)[0].shape[2].value 86 | obs = tf.zeros(shape.shape(sequences['mask']) + [embed_size]) 87 | prev_state = nested.map( 88 | lambda tensor: tf.concat([0 * tensor[:, :1], tensor[:, :-1]], 1), 89 | posterior) 90 | prev_state = nested.map( 91 | lambda tensor: _merge_dims(tensor, [0, 1]), prev_state) 92 | (priors, _), _ = tf.nn.dynamic_rnn( 93 | cell, (obs, sequences['prev_action'], use_obs), 94 | merged_length, 95 | prev_state) 96 | 97 | # Restore batch dimension. 
98 | target, prior, posterior, mask = nested.map( 99 | functools.partial(_restore_batch_dim, batch_size=shape.shape(length)[0]), 100 | (sequences['target'], priors, sequences['posterior'], sequences['mask'])) 101 | 102 | mask = tf.cast(mask, tf.bool) 103 | return target, prior, posterior, mask 104 | 105 | 106 | def _merge_dims(tensor, dims): 107 | if isinstance(tensor, (list, tuple, dict)): 108 | return nested.map(tensor, lambda x: _merge_dims(x, dims)) 109 | tensor = tf.convert_to_tensor(tensor) 110 | if (np.array(dims) - min(dims) != np.arange(len(dims))).all(): 111 | raise ValueError('Dimensions to merge must all follow each other.') 112 | start, end = dims[0], dims[-1] 113 | output = tf.reshape(tensor, tf.concat([ 114 | tf.shape(tensor)[:start], 115 | [tf.reduce_prod(tf.shape(tensor)[start: end + 1])], 116 | tf.shape(tensor)[end + 1:]], axis=0)) 117 | merged = tensor.shape[start: end + 1].as_list() 118 | output.set_shape( 119 | tensor.shape[:start].as_list() + 120 | [None if None in merged else np.prod(merged)] + 121 | tensor.shape[end + 1:].as_list()) 122 | return output 123 | 124 | 125 | def _pad_dims(tensor, rank): 126 | for _ in range(rank - tensor.shape.ndims): 127 | tensor = tensor[..., None] 128 | return tensor 129 | 130 | 131 | def _restore_batch_dim(tensor, batch_size): 132 | initial = shape.shape(tensor) 133 | desired = [batch_size, initial[0] // batch_size] + initial[1:] 134 | return tf.reshape(tensor, desired) 135 | -------------------------------------------------------------------------------- /dreamer/tools/preprocess.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Dreamer Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import tensorflow as tf 20 | 21 | 22 | def preprocess(observ, bits): 23 | if 'image' in observ: 24 | bins = 2 ** bits 25 | image = tf.cast(observ['image'], tf.float32) 26 | if bits < 8: 27 | image = tf.floor(image / 2 ** (8 - bits)) 28 | image = image / bins 29 | image = image + tf.random_uniform(tf.shape(image), 0, 1.0 / bins) 30 | image = image - 0.5 31 | observ['image'] = image 32 | return observ 33 | 34 | 35 | def postprocess(image, bits, dtype=tf.float32): 36 | bins = 2 ** bits 37 | if dtype == tf.float32: 38 | image = tf.floor(bins * (image + 0.5)) / bins 39 | elif dtype == tf.uint8: 40 | image = image + 0.5 41 | image = tf.floor(bins * image) 42 | image = image * (256.0 / bins) 43 | image = tf.cast(tf.clip_by_value(image, 0, 255), tf.uint8) 44 | else: 45 | raise NotImplementedError(dtype) 46 | return image 47 | -------------------------------------------------------------------------------- /dreamer/tools/reshape_as.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Dreamer Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import tensorflow as tf 20 | 21 | from dreamer.tools import nested 22 | 23 | 24 | def reshape_as(tensor, reference): 25 | if isinstance(tensor, (list, tuple, dict)): 26 | return nested.map(tensor, lambda x: reshape_as(x, reference)) 27 | tensor = tf.convert_to_tensor(tensor) 28 | reference = tf.convert_to_tensor(reference) 29 | statics = reference.shape.as_list() 30 | dynamics = tf.shape(reference) 31 | shape = [ 32 | static if static is not None else dynamics[index] 33 | for index, static in enumerate(statics)] 34 | return tf.reshape(tensor, shape) 35 | -------------------------------------------------------------------------------- /dreamer/tools/schedule.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Dreamer Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import numpy as np 20 | import tensorflow as tf 21 | 22 | 23 | def binary(step, batch_size, after, every, until): 24 | # https://www.desmos.com/calculator/csbhr5cjzz 25 | offset_step = step - after 26 | phase = tf.less(offset_step % every, batch_size) 27 | active = tf.greater_equal(step, after) 28 | if until > 0: 29 | active = tf.logical_and(active, tf.less(step, until)) 30 | result = tf.logical_and(phase, active) 31 | result.set_shape(tf.TensorShape([])) 32 | return result 33 | 34 | 35 | def linear(step, ramp, min=None, max=None): 36 | # https://www.desmos.com/calculator/nrumhgvxql 37 | if ramp == 0: 38 | result = tf.constant(1, tf.float32) 39 | if ramp > 0: 40 | result = tf.minimum( 41 | tf.cast(step, tf.float32) / tf.cast(ramp, tf.float32), 1) 42 | if ramp < 0: 43 | result = 1 - linear(step, abs(ramp)) 44 | if min is not None and max is not None: 45 | assert min <= max 46 | if min is not None: 47 | min = float(min) 48 | assert 0 <= min <= 1 49 | result = tf.maximum(min, result) 50 | if max is not None: 51 | max = float(max) 52 | assert 0 <= min <= 1 53 | result = tf.minimum(result, max) 54 | result.set_shape(tf.TensorShape([])) 55 | return result 56 | 57 | 58 | def step(step, distance): 59 | # https://www.desmos.com/calculator/sh1zsjtsqg 60 | if distance == 0: 61 | result = tf.constant(1, tf.float32) 62 | if distance > 0: 63 | result = tf.cast(tf.less(step, distance), tf.float32) 64 | if distance < 0: 65 | result = 1 - tf.cast(tf.less(step, abs(distance)), tf.float32) 66 | result.set_shape(tf.TensorShape([])) 67 | return result 68 | 69 | 70 | def exponential(step, distance, target, min=None, max=None): 71 | # https://www.desmos.com/calculator/egt24luequ 72 | target = tf.constant(target, tf.float32) 73 | step = tf.cast(step, tf.float32) 74 | if distance == 0: 75 | result = tf.constant(1, tf.float32) 76 | elif distance > 0: 77 
| distance = tf.constant(abs(distance), tf.float32) 78 | result = tf.exp(tf.log(target) / distance) ** step 79 | elif distance < 0: 80 | distance = tf.constant(abs(distance), tf.float32) 81 | result = 1 - tf.exp(tf.log(1 - target) / distance) ** step 82 | if min is not None and max is not None: 83 | assert min <= max 84 | if min is not None: 85 | min = float(min) 86 | assert 0 <= min <= 1 87 | result = tf.maximum(min, result) 88 | if max is not None: 89 | max = float(max) 90 | assert 0 <= min <= 1 91 | result = tf.minimum(result, max) 92 | result.set_shape(tf.TensorShape([])) 93 | return result 94 | 95 | 96 | def linear_reset(step, ramp, after, every): 97 | # https://www.desmos.com/calculator/motbnqhacw 98 | assert every > ramp, (every, ramp) # Would never reach max value. 99 | assert not (every != np.inf and after == np.inf), (every, after) 100 | step, ramp, after, every = [ 101 | tf.cast(x, tf.float32) for x in (step, ramp, after, every)] 102 | before = tf.cast(tf.less(step, after), tf.float32) * step 103 | after_mask = tf.cast(tf.greater_equal(step, after), tf.float32) 104 | after = after_mask * ((step - after) % every) 105 | result = tf.minimum((before + after) / ramp, 1) 106 | result.set_shape(tf.TensorShape([])) 107 | return result 108 | -------------------------------------------------------------------------------- /dreamer/tools/shape.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Dreamer Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import tensorflow as tf 20 | 21 | 22 | def shape(tensor): 23 | static = tensor.get_shape().as_list() 24 | dynamic = tf.unstack(tf.shape(tensor)) 25 | assert len(static) == len(dynamic) 26 | combined = [d if s is None else s for s, d in zip(static, dynamic)] 27 | return combined 28 | -------------------------------------------------------------------------------- /dreamer/tools/streaming_mean.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Dreamer Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import tensorflow as tf 20 | 21 | 22 | class StreamingMean(object): 23 | 24 | def __init__(self, shape, dtype, name): 25 | self._dtype = dtype 26 | with tf.variable_scope(name): 27 | self._sum = tf.get_variable( 28 | 'sum', shape, dtype, 29 | tf.constant_initializer(0), 30 | trainable=False) 31 | self._count = tf.get_variable( 32 | 'count', (), tf.int32, 33 | tf.constant_initializer(0), 34 | trainable=False) 35 | 36 | @property 37 | def value(self): 38 | return self._sum / tf.cast(self._count, self._dtype) 39 | 40 | @property 41 | def count(self): 42 | return self._count 43 | 44 | def submit(self, value): 45 | value = tf.convert_to_tensor(value) 46 | # Add a batch dimension if necessary. 47 | if value.shape.ndims == self._sum.shape.ndims: 48 | value = value[None, ...] 49 | if str(value.shape[1:]) != str(self._sum.shape): 50 | message = 'Value shape ({}) does not fit tracked tensor ({}).' 51 | raise ValueError(message.format(value.shape[1:], self._sum.shape)) 52 | def assign(): 53 | return tf.group( 54 | self._sum.assign_add(tf.reduce_sum(value, 0)), 55 | self._count.assign_add(tf.shape(value)[0])) 56 | not_empty = tf.cast(tf.reduce_prod(tf.shape(value)), tf.bool) 57 | return tf.cond(not_empty, assign, tf.no_op) 58 | 59 | def clear(self): 60 | value = self._sum / tf.cast(self._count, self._dtype) 61 | with tf.control_dependencies([value]): 62 | reset_value = self._sum.assign(tf.zeros_like(self._sum)) 63 | reset_count = self._count.assign(0) 64 | with tf.control_dependencies([reset_value, reset_count]): 65 | return tf.identity(value) 66 | -------------------------------------------------------------------------------- /dreamer/tools/summary.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Dreamer Authors. All rights reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import matplotlib 20 | matplotlib.use('Agg') 21 | import matplotlib.pyplot as plt 22 | import numpy as np 23 | import tensorflow as tf 24 | 25 | from dreamer.tools import count_dataset 26 | from dreamer.tools import gif_summary 27 | from dreamer.tools import image_strip_summary 28 | from dreamer.tools import shape as shapelib 29 | 30 | 31 | def plot_summary(titles, lines, labels, name): 32 | # Use control dependencies to ensure that only one plot summary is executed 33 | # at a time, since matplotlib is not thread safe. 
34 | def body_fn(lines): 35 | fig, axes = plt.subplots( 36 | nrows=len(titles), ncols=1, sharex=True, sharey=False, 37 | squeeze=False, figsize=(6, 3 * len(lines))) 38 | axes = axes[:, 0] 39 | for index, ax in enumerate(axes): 40 | ax.set_title(titles[index]) 41 | for line, label in zip(lines[index], labels[index]): 42 | ax.plot(line, label=label) 43 | if any(labels[index]): 44 | ax.legend(frameon=False) 45 | fig.tight_layout() 46 | fig.canvas.draw() 47 | image = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 48 | image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 49 | plt.close(fig) 50 | return image 51 | image = tf.py_func(body_fn, (lines,), tf.uint8) 52 | image = image[None] 53 | summary = tf.summary.image(name, image) 54 | return summary 55 | 56 | 57 | def data_summaries( 58 | data, postprocess_fn, histograms=False, max_batch=6, name='data'): 59 | summaries = [] 60 | with tf.variable_scope(name): 61 | if histograms: 62 | for key, value in data.items(): 63 | if key in ('image',): 64 | continue 65 | summaries.append(tf.summary.histogram(key, data[key])) 66 | image = data['image'][:max_batch] 67 | if postprocess_fn: 68 | image = postprocess_fn(image) 69 | summaries.append(image_strip_summary.image_strip_summary('image', image)) 70 | return summaries 71 | 72 | 73 | def dataset_summaries(directory, name='dataset'): 74 | summaries = [] 75 | with tf.variable_scope(name): 76 | episodes = count_dataset.count_dataset(directory) 77 | summaries.append(tf.summary.scalar('episodes', episodes)) 78 | return summaries 79 | 80 | 81 | def state_summaries( 82 | cell, prior, posterior, histograms=False, name='state'): 83 | summaries = [] 84 | divergence = cell.divergence_from_states(posterior, prior) 85 | prior = cell.dist_from_state(prior) 86 | posterior = cell.dist_from_state(posterior) 87 | prior_entropy = prior.entropy() 88 | posterior_entropy = posterior.entropy() 89 | nan_to_num = lambda x: tf.where(tf.is_nan(x), tf.zeros_like(x), x) 90 | with 
tf.variable_scope(name): 91 | if histograms: 92 | summaries.append(tf.summary.histogram( 93 | 'prior_entropy_hist', nan_to_num(prior_entropy))) 94 | summaries.append(tf.summary.scalar( 95 | 'prior_entropy', tf.reduce_mean(prior_entropy))) 96 | summaries.append(tf.summary.scalar( 97 | 'prior_std', tf.reduce_mean(prior.stddev()))) 98 | if histograms: 99 | summaries.append(tf.summary.histogram( 100 | 'posterior_entropy_hist', nan_to_num(posterior_entropy))) 101 | summaries.append(tf.summary.scalar( 102 | 'posterior_entropy', tf.reduce_mean(posterior_entropy))) 103 | summaries.append(tf.summary.scalar( 104 | 'posterior_std', tf.reduce_mean(posterior.stddev()))) 105 | summaries.append(tf.summary.scalar( 106 | 'divergence', tf.reduce_mean(divergence))) 107 | return summaries 108 | 109 | 110 | def dist_summaries(dists, obs, name='dist_summaries'): 111 | summaries = [] 112 | with tf.variable_scope(name): 113 | for name, dist in dists.items(): 114 | mode = tf.cast(dist.mode(), tf.float32) 115 | mode_mean, mode_var = tf.nn.moments(mode, list(range(mode.shape.ndims))) 116 | mode_std = tf.sqrt(mode_var) 117 | summaries.append(tf.summary.scalar(name + '_mode_mean', mode_mean)) 118 | summaries.append(tf.summary.scalar(name + '_mode_std', mode_std)) 119 | std = dist.stddev() 120 | std_mean, std_var = tf.nn.moments(std, list(range(std.shape.ndims))) 121 | std_std = tf.sqrt(std_var) 122 | summaries.append(tf.summary.scalar(name + '_std_mean', std_mean)) 123 | summaries.append(tf.summary.scalar(name + '_std_std', std_std)) 124 | if hasattr(dist, 'distribution') and hasattr( 125 | dist.distribution, 'distribution'): 126 | inner = dist.distribution.distribution 127 | inner_std = tf.reduce_mean(inner.stddev()) 128 | summaries.append(tf.summary.scalar(name + '_inner_std', inner_std)) 129 | if name in obs: 130 | log_prob = tf.reduce_mean(dist.log_prob(obs[name])) 131 | summaries.append(tf.summary.scalar(name + '_log_prob', log_prob)) 132 | abs_error = tf.reduce_mean(tf.abs(mode - 
obs[name])) 133 | summaries.append(tf.summary.scalar(name + '_abs_error', abs_error)) 134 | return summaries 135 | 136 | 137 | def image_summaries(dist, target, name='image', max_batch=6): 138 | summaries = [] 139 | with tf.variable_scope(name): 140 | image = dist.mode()[:max_batch] 141 | target = target[:max_batch] 142 | error = ((image - target) + 1) / 2 143 | # empty_frame = 0 * target[:max_batch, :1] 144 | # change = tf.concat([empty_frame, image[:, 1:] - image[:, :-1]], 1) 145 | # change = (change + 1) / 2 146 | summaries.append(image_strip_summary.image_strip_summary( 147 | 'prediction', image)) 148 | # summaries.append(image_strip_summary.image_strip_summary( 149 | # 'change', change)) 150 | # summaries.append(image_strip_summary.image_strip_summary( 151 | # 'error', error)) 152 | # Concat prediction and target vertically. 153 | frames = tf.concat([target, image, error], 2) 154 | # Stack batch entries horizontally. 155 | frames = tf.transpose(frames, [1, 2, 0, 3, 4]) 156 | s = shapelib.shape(frames) 157 | frames = tf.reshape(frames, [s[0], s[1], s[2] * s[3], s[4]]) 158 | summaries.append(gif_summary.gif_summary( 159 | 'gif', frames[None], max_outputs=1, fps=20)) 160 | return summaries 161 | 162 | 163 | def objective_summaries(objectives, name='objectives'): 164 | summaries = [] 165 | with tf.variable_scope(name): 166 | for objective in objectives: 167 | summaries.append(tf.summary.scalar(objective.name, objective.value)) 168 | return summaries 169 | 170 | 171 | def prediction_summaries(dists, data, state, name='state'): 172 | summaries = [] 173 | with tf.variable_scope(name): 174 | # Predictions. 175 | log_probs = {} 176 | for key, dist in dists.items(): 177 | if key in ('image',): 178 | continue 179 | if key not in data: 180 | continue 181 | # We only look at the first example in the batch. 
182 | log_prob = dist.log_prob(data[key])[0] 183 | prediction = dist.mode()[0] 184 | truth = data[key][0] 185 | plot_name = key 186 | # Ensure that there is a feature dimension. 187 | if prediction.shape.ndims == 1: 188 | prediction = prediction[:, None] 189 | truth = truth[:, None] 190 | prediction = tf.unstack(tf.transpose(prediction, (1, 0))) 191 | truth = tf.unstack(tf.transpose(truth, (1, 0))) 192 | lines = list(zip(prediction, truth)) 193 | titles = ['{} {}'.format(key.title(), i) for i in range(len(lines))] 194 | labels = [['Prediction', 'Truth']] * len(lines) 195 | plot_name = '{}_trajectory'.format(key) 196 | # The control dependencies are needed because rendering in matplotlib 197 | # uses global state, so rendering two plots in parallel interferes. 198 | with tf.control_dependencies(summaries): 199 | summaries.append(plot_summary(titles, lines, labels, plot_name)) 200 | log_probs[key] = log_prob 201 | log_probs = sorted(log_probs.items(), key=lambda x: x[0]) 202 | titles, lines = zip(*log_probs) 203 | titles = [title.title() for title in titles] 204 | lines = [[line] for line in lines] 205 | labels = [[None]] * len(titles) 206 | plot_name = 'logprobs' 207 | with tf.control_dependencies(summaries): 208 | summaries.append(plot_summary(titles, lines, labels, plot_name)) 209 | return summaries 210 | -------------------------------------------------------------------------------- /dreamer/tools/target_network.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Dreamer Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from dreamer.tools import schedule as schedule_lib
from dreamer.tools import copy_weights


def track_network(
    trainer, batch_size, source_pattern, target_pattern, every, amount):
  """Build an op that keeps a target network tracking a source network.

  On the very first global step the target variables are hard-copied from the
  source; afterwards, during training phases, they are soft-updated by
  `amount` on the schedule defined by `every`.

  Args:
    trainer: Trainer object providing global_step, step, and phase tensors.
    batch_size: Number of steps consumed per training iteration.
    source_pattern: Regex selecting the source variables.
    target_pattern: Regex selecting the target variables.
    every: Interval in steps between soft updates.
    amount: Interpolation factor for each soft update.

  Returns:
    Op performing the conditional copy.
  """
  def hard_copy():
    # Full overwrite to initialize the target at the start of training.
    return copy_weights.soft_copy_weights(
        source_pattern, target_pattern, 1.0)

  def soft_copy():
    # Partial interpolation toward the source variables.
    return copy_weights.soft_copy_weights(
        source_pattern, target_pattern, amount)

  init_op = tf.cond(tf.equal(trainer.global_step, 0), hard_copy, tf.no_op)
  schedule = schedule_lib.binary(trainer.step, batch_size, 0, every, -1)
  with tf.control_dependencies([init_op]):
    should_update = tf.logical_and(tf.equal(trainer.phase, 'train'), schedule)
    return tf.cond(should_update, soft_copy, tf.no_op)
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from dreamer.tools import nested
from dreamer.tools import shape


def closed_loop(cell, embedded, prev_action, debug=False):
  """Filter the sequence with observations available at every time step."""
  obs_mask = tf.ones(tf.shape(embedded[:, :, :1])[:3], tf.bool)
  (prior, posterior), _ = tf.nn.dynamic_rnn(
      cell, (embedded, prev_action, obs_mask), dtype=tf.float32)
  if debug:
    # Assert the state sequence is as long as the input sequence.
    check = tf.assert_equal(
        tf.shape(nested.flatten(posterior)[0])[1], tf.shape(embedded)[1])
    with tf.control_dependencies([check]):
      prior = nested.map(tf.identity, prior)
      posterior = nested.map(tf.identity, posterior)
  return prior, posterior


def open_loop(cell, embedded, prev_action, context=1, debug=False):
  """Condition on `context` steps, then predict without observations."""
  obs_mask = tf.ones(tf.shape(embedded[:, :context, :1])[:3], tf.bool)
  (_, closed_state), last_state = tf.nn.dynamic_rnn(
      cell, (embedded[:, :context], prev_action[:, :context], obs_mask),
      dtype=tf.float32)
  # Zero the embeddings and mask so no observation information leaks into
  # the open-loop portion of the rollout.
  unobs_mask = tf.zeros(tf.shape(embedded[:, context:, :1])[:3], tf.bool)
  (_, open_state), _ = tf.nn.dynamic_rnn(
      cell, (0 * embedded[:, context:], prev_action[:, context:], unobs_mask),
      initial_state=last_state)
  state = nested.map(
      lambda x, y: tf.concat([x, y], 1),
      closed_state, open_state)
  if debug:
    check = tf.assert_equal(
        tf.shape(nested.flatten(state)[0])[1], tf.shape(embedded)[1])
    with tf.control_dependencies([check]):
      state = nested.map(tf.identity, state)
  return state


def planned(
    cell, objective_fn, embedded, prev_action, planner, context=1, length=20,
    amount=1000, debug=False):
  """Condition on `context` steps, then roll out actions from the planner."""
  obs_mask = tf.ones(tf.shape(embedded[:, :context, :1])[:3], tf.bool)
  (_, closed_state), last_state = tf.nn.dynamic_rnn(
      cell, (embedded[:, :context], prev_action[:, :context], obs_mask),
      dtype=tf.float32)
  _, plan_state, return_ = planner(
      cell, objective_fn, last_state,
      obs_shape=shape.shape(embedded)[2:],
      action_shape=shape.shape(prev_action)[2:],
      horizon=length, amount=amount)
  state = nested.map(
      lambda x, y: tf.concat([x, y], 1),
      closed_state, plan_state)
  if debug:
    check = tf.assert_equal(
        tf.shape(nested.flatten(state)[0])[1], context + length)
    with tf.control_dependencies([check]):
      state = nested.map(tf.identity, state)
      return_ = tf.identity(return_)
  return state, return_


# dreamer/training/__init__.py
# Copyright 2019 The Dreamer Authors. All rights reserved.
# Licensed under the Apache License, Version 2.0; see the LICENSE file.
from . import utility
from . import utility
from .define_model import define_model
from .define_summaries import define_summaries
from .running import Experiment
from .trainer import Trainer


# dreamer/training/define_model.py
# Copyright 2019 The Dreamer Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import datetime
import functools

import tensorflow as tf

from dreamer import tools
from dreamer.training import define_summaries
from dreamer.training import utility


def define_model(logdir, metrics, data, trainer, config):
  """Build the full training graph: model, objectives, and data collection.

  Args:
    logdir: Experiment log directory (unused here directly; kept in the
        locals() capture so downstream code can access it via `graph`).
    metrics: Metrics writer used for tagging and flushing logged values.
    data: Dict of batched sequence tensors (must contain 'action').
    trainer: Trainer providing step, global_step, phase, and log tensors.
    config: Experiment configuration (cell, encoder, heads, schedules, ...).

  Returns:
    Tuple of (score tensor, merged summary tensor, list of cleanup fns).
  """
  print('Build TensorFlow compute graph.')
  dependencies = []
  cleanups = []
  step = trainer.step
  global_step = trainer.global_step
  phase = trainer.phase
  # Computed via py_func so each session step gets a fresh wall-clock tag.
  timestamp = tf.py_func(
      lambda: datetime.datetime.utcnow().strftime('%Y%m%dT%H%M%S'),
      [], tf.string)
  dependencies.append(metrics.set_tags(
      global_step=global_step, step=step, phase=phase,
      time=timestamp))

  # Instantiate network blocks. Note, this initialization would be expensive
  # when using tf.function since it would run at every step.
  try:
    cell = config.cell()
  except TypeError:
    # Some cell constructors require the action size up front.
    cell = config.cell(action_size=data['action'].shape[-1].value)
  kwargs = dict(create_scope_now_=True)
  encoder = tf.make_template('encoder', config.encoder, **kwargs)
  heads = tools.AttrDict(_unlocked=True)
  # Dummy features from a zero state, used only to build head variables.
  raw_dummy_features = cell.features_from_state(
      cell.zero_state(1, tf.float32))[:, None]
  for key, head in config.heads.items():
    name = 'head_{}'.format(key)
    kwargs = dict(create_scope_now_=True)
    if key in data:
      kwargs['data_shape'] = data[key].shape[2:].as_list()
    if key == 'action_target':
      # Target head predicts actions, so reuse the action shape.
      kwargs['data_shape'] = data['action'].shape[2:].as_list()
    if key == 'cpc':
      # CPC head operates on embeddings rather than RNN features.
      kwargs['data_shape'] = [cell.feature_size]
      dummy_features = encoder(data)[:1, :1]
    else:
      dummy_features = raw_dummy_features
    heads[key] = tf.make_template(name, head, **kwargs)
    heads[key](dummy_features)  # Initialize weights.

  # Update target networks.
  # Slow-moving copies used as bootstrap targets for value/action heads.
  if 'value_target' in heads:
    dependencies.append(tools.track_network(
        trainer, config.batch_shape[0],
        r'.*/head_value/.*', r'.*/head_value_target/.*',
        config.value_target_period, config.value_target_update))
  if 'value_target_2' in heads:
    dependencies.append(tools.track_network(
        trainer, config.batch_shape[0],
        r'.*/head_value/.*', r'.*/head_value_target_2/.*',
        config.value_target_period, config.value_target_update))
  if 'action_target' in heads:
    dependencies.append(tools.track_network(
        trainer, config.batch_shape[0],
        r'.*/head_action/.*', r'.*/head_action_target/.*',
        config.action_target_period, config.action_target_update))

  # Apply and optimize model.
  embedded = encoder(data)
  with tf.control_dependencies(dependencies):
    embedded = tf.identity(embedded)
  # Capture all locals defined so far; downstream code accesses them as
  # graph.cell, graph.embedded, graph.heads, etc. Renaming any local above
  # would silently change this interface.
  graph = tools.AttrDict(locals())
  prior, posterior = tools.unroll.closed_loop(
      cell, embedded, data['action'], config.debug)
  objectives = utility.compute_objectives(
      posterior, prior, data, graph, config)
  summaries, grad_norms = utility.apply_optimizers(
      objectives, trainer, config)
  dependencies += summaries

  # Active data collection.
  with tf.variable_scope('collection'):
    with tf.control_dependencies(dependencies):  # Make sure to train first.
      for name, params in config.train_collects.items():
        schedule = tools.schedule.binary(
            step, config.batch_shape[0],
            params.steps_after, params.steps_every, params.steps_until)
        summary, _ = tf.cond(
            tf.logical_and(tf.equal(trainer.phase, 'train'), schedule),
            functools.partial(
                utility.simulate, metrics, config, params, graph, cleanups,
                gif_summary=False, name=name),
            lambda: (tf.constant(''), tf.constant(0.0)),
            name='should_collect_' + name)
        summaries.append(summary)
        dependencies.append(summary)

  # Compute summaries.
  # Re-capture locals so the summary graph also sees objectives, summaries,
  # and the collection ops defined above.
  graph = tools.AttrDict(locals())
  summary, score = tf.cond(
      trainer.log,
      lambda: define_summaries.define_summaries(graph, config, cleanups),
      lambda: (tf.constant(''), tf.zeros((0,), tf.float32)),
      name='summaries')
  summaries = tf.summary.merge([summaries, summary])
  dependencies.append(utility.print_metrics(
      {ob.name: ob.value for ob in objectives},
      step, config.print_metrics_every, 2, 'objectives'))
  dependencies.append(utility.print_metrics(
      grad_norms, step, config.print_metrics_every, 2, 'grad_norms'))
  dependencies.append(tf.cond(trainer.log, metrics.flush, tf.no_op))
  with tf.control_dependencies(dependencies):
    score = tf.identity(score)
  return score, summaries, cleanups


# dreamer/training/define_summaries.py
# Copyright 2019 The Dreamer Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow_probability import distributions as tfd

from dreamer import tools
from dreamer.training import utility
from dreamer.tools import summary


def define_summaries(graph, config, cleanups):
  """Build the expensive summary subgraph evaluated only on logging steps.

  Args:
    graph: AttrDict of tensors captured from define_model's locals()
        (cell, heads, embedded, data, step, global_step, phase, metrics,
        objectives, ...).
    config: Experiment configuration.
    cleanups: List collecting cleanup callables created by simulations.

  Returns:
    Tuple of (merged summary tensor, mean simulation return with a leading
    singleton dimension).
  """
  summaries = []
  # plot_summaries = []  # Control dependencies for non thread-safe matplot.
  heads = graph.heads.copy(_unlocked=True)
  # Variables tracking the last summary time/step to report throughput.
  last_time = tf.Variable(lambda: tf.timestamp(), trainable=False)
  last_step = tf.Variable(lambda: 0.0, trainable=False, dtype=tf.float64)

  def transform(dist):
    # Clip the decoded image mean into valid range and wrap it in a
    # unit-variance normal so log probs stay comparable in summaries.
    mean = config.postprocess_fn(dist.mean())
    mean = tf.clip_by_value(mean, 0.0, 1.0)
    return tfd.Independent(tfd.Normal(mean, 1.0), len(dist.event_shape))
  heads['image'] = lambda features: transform(graph.heads['image'](features))
  heads.pop('cpc', None)  # Not applied to RNN states.

  with tf.variable_scope('general'):
    summaries += summary.data_summaries(graph.data, config.postprocess_fn)
    summaries += summary.dataset_summaries(config.train_dir, 'train_dataset')
    summaries += summary.dataset_summaries(config.test_dir, 'test_dataset')
    summaries += summary.objective_summaries(graph.objectives)
    summaries.append(tf.summary.scalar('step', graph.step))
    # Throughput since the last summary evaluation.
    new_time, new_step = tf.timestamp(), tf.cast(graph.global_step, tf.float64)
    delta_time, delta_step = new_time - last_time, new_step - last_step
    with tf.control_dependencies([delta_time, delta_step]):
      assign_ops = [last_time.assign(new_time), last_step.assign(new_step)]
      with tf.control_dependencies(assign_ops):
        summaries.append(tf.summary.scalar(
            'steps_per_second', delta_step / delta_time))
        summaries.append(tf.summary.scalar(
            'seconds_per_step', delta_time / delta_step))

  with tf.variable_scope('closedloop'):
    prior, posterior = tools.unroll.closed_loop(
        graph.cell, graph.embedded, graph.data['action'], config.debug)
    summaries += summary.state_summaries(graph.cell, prior, posterior)
    # with tf.variable_scope('prior'):
    #   prior_features = graph.cell.features_from_state(prior)
    #   prior_dists = {
    #       name: head(prior_features)
    #       for name, head in heads.items()}
    #   summaries += summary.dist_summaries(prior_dists, graph.data)
    #   summaries += summary.image_summaries(
    #       prior_dists['image'], config.postprocess_fn(graph.data['image']))
    # with tf.variable_scope('posterior'):
    #   posterior_features = graph.cell.features_from_state(posterior)
    #   posterior_dists = {
    #       name: head(posterior_features)
    #       for name, head in heads.items()}
    #   summaries += summary.dist_summaries(
    #       posterior_dists, graph.data)
    #   summaries += summary.image_summaries(
    #       posterior_dists['image'],
    #       config.postprocess_fn(graph.data['image']))

  with tf.variable_scope('openloop'):
    state = tools.unroll.open_loop(
        graph.cell, graph.embedded, graph.data['action'],
        config.open_loop_context, config.debug)
    state_features = graph.cell.features_from_state(state)
    state_dists = {name: head(state_features) for name, head in heads.items()}
    if 'cpc' in graph.heads:
      # CPC head operates on embeddings rather than unrolled state features.
      state_dists['cpc'] = graph.heads.cpc(graph.embedded)
    summaries += summary.dist_summaries(state_dists, graph.data)
    summaries += summary.image_summaries(
        state_dists['image'], config.postprocess_fn(graph.data['image']))
    summaries += summary.state_summaries(graph.cell, state, posterior)
    # with tf.control_dependencies(plot_summaries):
    #   plot_summary = summary.prediction_summaries(
    #       state_dists, graph.data, state)
    #   plot_summaries += plot_summary
    #   summaries += plot_summary

  with tf.variable_scope('simulation'):
    sim_returns = []
    for name, params in config.test_collects.items():
      # These are expensive and equivalent for train and test phases, so only
      # do one of them.
      sim_summary, sim_return = tf.cond(
          tf.equal(graph.phase, 'test'),
          lambda: utility.simulate(
              graph.metrics, config, params, graph, cleanups,
              gif_summary=True, name=name),
          lambda: ('', 0.0),
          name='should_simulate_' + params.task.name)
      summaries.append(sim_summary)
      sim_returns.append(sim_return)

  summaries = tf.summary.merge(summaries)
  score = tf.reduce_mean(sim_returns)[None]
  return summaries, score
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import datetime 20 | import itertools 21 | import os 22 | import pickle 23 | import six 24 | import sys 25 | import threading 26 | import time 27 | import traceback 28 | import uuid 29 | 30 | import numpy as np 31 | import tensorflow as tf 32 | 33 | 34 | class StopExperiment(Exception): 35 | pass 36 | 37 | 38 | class WorkerConflict(Exception): 39 | pass 40 | 41 | 42 | class SkipRun(Exception): 43 | pass 44 | 45 | 46 | class Experiment(object): 47 | 48 | def __init__( 49 | self, basedir, process_fn, start_fn=None, resume_fn=None, 50 | num_runs=None, worker_name=None, ping_every=30, resume_runs=True): 51 | self._basedir = basedir 52 | self._process_fn = process_fn 53 | self._start_fn = start_fn 54 | self._resume_fn = resume_fn 55 | self._num_runs = num_runs 56 | self._worker_name = worker_name or str(uuid.uuid4()) 57 | self._ping_every = ping_every 58 | self._ping_stale = ping_every and 2 * ping_every 59 | self._resume_runs = resume_runs 60 | 61 | def __iter__(self): 62 | for current_run in self._generate_run_numbers(): 63 | logdir = self._basedir and os.path.join( 64 | self._basedir, '{:05}'.format(current_run)) 65 | try: 66 | run = Run( 67 | logdir, self._process_fn, self._start_fn, self._resume_fn, 68 | self._worker_name, self._ping_every, self._ping_stale, 69 | self._resume_runs) 70 | yield run 71 | except SkipRun: 72 | continue 73 | except StopExperiment: 74 | print('Stopping.') 75 | 
break 76 | print('All runs completed.') 77 | 78 | def _generate_run_numbers(self): 79 | if self._num_runs: 80 | # Don't wait initially and see if there are runs that are already stale. 81 | runs = np.random.permutation(range(self._num_runs)) 82 | for run in runs: 83 | yield run + 1 84 | # At the end, wait for all dead runs to become stale, and pick them up. 85 | # This is necessary for complete runs of workers that died very recently. 86 | if self._ping_stale: 87 | time.sleep(self._ping_stale) 88 | for run in runs: 89 | yield run + 1 90 | else: 91 | # For infinite runs, we want to always finish started jobs first. 92 | # Therefore, we need to wait for them to become stale in the beginning. 93 | if self._ping_stale: 94 | time.sleep(self._ping_stale) 95 | for run in itertools.count(): 96 | yield run + 1 97 | 98 | 99 | class Run(object): 100 | 101 | def __init__( 102 | self, logdir, process_fn, start_fn, resume_fn, worker_name, 103 | ping_every=30, ping_stale=60, reuse_if_exists=True): 104 | self._logdir = os.path.expanduser(logdir) 105 | self._process_fn = process_fn 106 | self._worker_name = worker_name 107 | self._ping_every = ping_every 108 | self._ping_stale = ping_stale 109 | self._logger = self._create_logger() 110 | try: 111 | if self._should_start(): 112 | self._claim() 113 | self._logger.info('Start.') 114 | self._init_fn = start_fn 115 | elif reuse_if_exists and self._should_resume(): 116 | self._claim() 117 | self._logger.info('Resume.') 118 | self._init_fn = resume_fn 119 | else: 120 | raise SkipRun 121 | except WorkerConflict: 122 | self._logger.info('Leave to other worker.') 123 | raise SkipRun 124 | self._thread = None 125 | self._running = [True] 126 | self._thread = threading.Thread(target=self._store_ping_thread) 127 | self._thread.daemon = True # Terminate with main thread. 
128 | self._thread.start() 129 | 130 | def __iter__(self): 131 | try: 132 | args = self._init_fn and self._init_fn(self._logdir) 133 | if args is None: 134 | args = () 135 | if not isinstance(args, tuple): 136 | args = (args,) 137 | for value in self._process_fn(self._logdir, *args): 138 | if not self._running[0]: 139 | break 140 | yield value 141 | self._logger.info('Done.') 142 | self._store_done() 143 | except WorkerConflict: 144 | self._logging.warn('Unexpected takeover.') 145 | raise SkipRun 146 | except Exception as e: 147 | exc_info = sys.exc_info() 148 | self._handle_exception(e) 149 | six.reraise(*exc_info) 150 | finally: 151 | self._running[0] = False 152 | self._thread and self._thread.join() 153 | 154 | def _should_start(self): 155 | if not self._logdir: 156 | return True 157 | if tf.gfile.Exists(os.path.join(self._logdir, 'PING')): 158 | return False 159 | if tf.gfile.Exists(os.path.join(self._logdir, 'DONE')): 160 | return False 161 | return True 162 | 163 | def _should_resume(self): 164 | if not self._logdir: 165 | return False 166 | if tf.gfile.Exists(os.path.join(self._logdir, 'DONE')): 167 | # self._logger.debug('Already done.') 168 | return False 169 | if not tf.gfile.Exists(os.path.join(self._logdir, 'PING')): 170 | # self._logger.debug('Not started yet.') 171 | return False 172 | last_worker, last_ping = self._read_ping() 173 | if last_worker != self._worker_name and last_ping < self._ping_stale: 174 | # self._logger.debug('Already in progress.') 175 | return False 176 | return True 177 | 178 | def _claim(self): 179 | if not self._logdir: 180 | return False 181 | self._store_ping(overwrite=True) 182 | if self._ping_every: 183 | time.sleep(self._ping_every) 184 | if self._read_ping()[0] != self._worker_name: 185 | raise WorkerConflict 186 | self._store_ping() 187 | 188 | def _store_done(self): 189 | if not self._logdir: 190 | return 191 | with tf.gfile.Open(os.path.join(self._logdir, 'DONE'), 'w') as file_: 192 | file_.write('\n') 193 | 194 | 
def _store_fail(self, message): 195 | if not self._logdir: 196 | return 197 | with tf.gfile.Open(os.path.join(self._logdir, 'FAIL'), 'w') as file_: 198 | file_.write(message + '\n') 199 | 200 | def _read_ping(self): 201 | if not tf.gfile.Exists(os.path.join(self._logdir, 'PING')): 202 | return None, None 203 | try: 204 | with tf.gfile.Open(os.path.join(self._logdir, 'PING'), 'rb') as file_: 205 | last_worker, last_ping = pickle.load(file_) 206 | duration = (datetime.datetime.utcnow() - last_ping).total_seconds() 207 | return last_worker, duration 208 | except (EOFError, IOError, tf.errors.NotFoundError): 209 | raise WorkerConflict 210 | 211 | def _store_ping(self, overwrite=False): 212 | if not self._logdir: 213 | return 214 | try: 215 | last_worker, _ = self._read_ping() 216 | if last_worker is None: 217 | self._logger.info("Create directory '{}'.".format(self._logdir)) 218 | tf.gfile.MakeDirs(self._logdir) 219 | elif last_worker != self._worker_name and not overwrite: 220 | raise WorkerConflict 221 | # self._logger.debug('Store ping.') 222 | with tf.gfile.Open(os.path.join(self._logdir, 'PING'), 'wb') as file_: 223 | pickle.dump((self._worker_name, datetime.datetime.utcnow()), file_) 224 | except (EOFError, IOError, tf.errors.NotFoundError): 225 | raise WorkerConflict 226 | 227 | def _store_ping_thread(self): 228 | if not self._ping_every: 229 | return 230 | try: 231 | last_write = time.time() 232 | self._store_ping(self._logdir) 233 | while self._running[0]: 234 | if time.time() >= last_write + self._ping_every: 235 | last_write = time.time() 236 | self._store_ping(self._logdir) 237 | # Only wait short times to quickly react to abort. 
238 | time.sleep(0.01) 239 | except WorkerConflict: 240 | self._running[0] = False 241 | 242 | def _handle_exception(self, exception): 243 | message = ''.join(traceback.format_exception(*sys.exc_info())) 244 | self._logger.warning('Exception:\n{}'.format(message)) 245 | self._logger.warning('Failed.') 246 | try: 247 | self._store_done() 248 | self._store_fail(message) 249 | except Exception: 250 | message = ''.join(traceback.format_exception(*sys.exc_info())) 251 | template = 'Exception in exception handler:\n{}' 252 | self._logger.warning(template.format(message)) 253 | 254 | def _create_logger(self): 255 | run_name = self._logdir and os.path.basename(self._logdir) 256 | methods = {} 257 | for name in 'debug info warning'.split(): 258 | methods[name] = lambda unused_self, message: print( 259 | 'Worker {} run {}: {}'.format(self._worker_name, run_name, message)) 260 | return type('PrefixedLogger', (object,), methods)() 261 | -------------------------------------------------------------------------------- /dreamer/training/trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Dreamer Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import collections 20 | import os 21 | 22 | import tensorflow as tf 23 | 24 | from dreamer import tools 25 | 26 | 27 | _Phase = collections.namedtuple( 28 | 'Phase', 29 | 'name, writer, op, batch_size, steps, feed, report_every, log_every,' 30 | 'checkpoint_every, restore_every') 31 | 32 | 33 | class Trainer(object): 34 | 35 | def __init__(self, logdir, config=None): 36 | self._logdir = logdir 37 | self._global_step = tf.train.get_or_create_global_step() 38 | self._step = tf.placeholder(tf.int32, name='step') 39 | self._phase = tf.placeholder(tf.string, name='phase') 40 | self._log = tf.placeholder(tf.bool, name='log') 41 | self._report = tf.placeholder(tf.bool, name='report') 42 | self._reset = tf.placeholder(tf.bool, name='reset') 43 | self._phases = [] 44 | # Checkpointing. 45 | self._loaders = [] 46 | self._savers = [] 47 | self._logdirs = [] 48 | self._checkpoints = [] 49 | self._config = config or tools.AttrDict() 50 | 51 | @property 52 | def global_step(self): 53 | return self._global_step 54 | 55 | @property 56 | def step(self): 57 | return self._step 58 | 59 | @property 60 | def phase(self): 61 | return self._phase 62 | 63 | @property 64 | def log(self): 65 | return self._log 66 | 67 | @property 68 | def reset(self): 69 | return self._reset 70 | 71 | def add_saver( 72 | self, include=r'.*', exclude=r'.^', logdir=None, load=True, save=True, 73 | checkpoint=None): 74 | variables = tools.filter_variables(include, exclude) 75 | saver = tf.train.Saver(variables, max_to_keep=1) 76 | if load: 77 | self._loaders.append(saver) 78 | if save: 79 | self._savers.append(saver) 80 | self._logdirs.append(logdir or self._logdir) 81 | if checkpoint is None and self._config.checkpoint_to_load: 82 | self._checkpoints.append( 83 | os.path.join(self._logdirs[-1], self._config.checkpoint_to_load)) 84 | else: 85 | 
self._checkpoints.append(checkpoint) 86 | 87 | def add_phase( 88 | self, name, steps, score, summary, batch_size=1, 89 | report_every=None, log_every=None, checkpoint_every=None, 90 | restore_every=None, feed=None): 91 | score = tf.convert_to_tensor(score, tf.float32) 92 | summary = tf.convert_to_tensor(summary, tf.string) 93 | feed = feed or {} 94 | if not score.shape.ndims: 95 | score = score[None] 96 | writer = self._logdir and tf.summary.FileWriter( 97 | os.path.join(self._logdir, name), 98 | tf.get_default_graph(), flush_secs=30) 99 | op = self._define_step(name, batch_size, score, summary) 100 | self._phases.append(_Phase( 101 | name, writer, op, batch_size, int(steps), feed, report_every, 102 | log_every, checkpoint_every, restore_every)) 103 | 104 | def run(self, max_step=None, sess=None, unused_saver=None): 105 | for _ in self.iterate(max_step, sess): 106 | pass 107 | 108 | def iterate(self, max_step=None, sess=None): 109 | sess = sess or self._create_session() 110 | with sess: 111 | self._initialize_variables( 112 | sess, self._loaders, self._logdirs, self._checkpoints) 113 | sess.graph.finalize() 114 | while True: 115 | global_step = sess.run(self._global_step) 116 | if max_step and global_step >= max_step: 117 | break 118 | phase, epoch, steps_in = self._find_current_phase(global_step) 119 | phase_step = epoch * phase.steps + steps_in 120 | if steps_in % phase.steps < phase.batch_size: 121 | message = '\n' + ('-' * 50) + '\n' 122 | message += 'Epoch {} phase {} (phase step {}, global step {}).' 123 | print(message.format(epoch + 1, phase.name, phase_step, global_step)) 124 | # Populate book keeping tensors. 
125 | phase.feed[self._step] = phase_step 126 | phase.feed[self._phase] = phase.name 127 | phase.feed[self._reset] = (steps_in < phase.batch_size) 128 | phase.feed[self._log] = phase.writer and self._is_every_steps( 129 | phase_step, phase.batch_size, phase.log_every) 130 | phase.feed[self._report] = self._is_every_steps( 131 | phase_step, phase.batch_size, phase.report_every) 132 | summary, mean_score, global_step = sess.run(phase.op, phase.feed) 133 | if self._is_every_steps( 134 | phase_step, phase.batch_size, phase.checkpoint_every): 135 | for saver in self._savers: 136 | self._store_checkpoint(sess, saver, global_step) 137 | if self._is_every_steps( 138 | phase_step, phase.batch_size, phase.report_every): 139 | print('Score {}.'.format(mean_score)) 140 | yield mean_score 141 | if summary and phase.writer: 142 | # We want smaller phases to catch up at the beginnig of each epoch so 143 | # that their graphs are aligned. 144 | longest_phase = max(phase_.steps for phase_ in self._phases) 145 | summary_step = epoch * longest_phase + steps_in 146 | phase.writer.add_summary(summary, summary_step) 147 | if self._is_every_steps( 148 | phase_step, phase.batch_size, phase.restore_every): 149 | self._initialize_variables( 150 | sess, self._loaders, self._logdirs, self._checkpoints) 151 | 152 | def _is_every_steps(self, phase_step, batch, every): 153 | if not every: 154 | return False 155 | covered_steps = range(phase_step, phase_step + batch) 156 | return any((step + 1) % every == 0 for step in covered_steps) 157 | 158 | def _find_current_phase(self, global_step): 159 | epoch_size = sum(phase.steps for phase in self._phases) 160 | epoch = int(global_step // epoch_size) 161 | steps_in = global_step % epoch_size 162 | for phase in self._phases: 163 | if steps_in < phase.steps: 164 | return phase, epoch, steps_in 165 | steps_in -= phase.steps 166 | 167 | def _define_step(self, name, batch_size, score, summary): 168 | with tf.variable_scope('phase_{}'.format(name)): 169 | 
score_mean = tools.StreamingMean((), tf.float32, 'score_mean') 170 | score.set_shape((None,)) 171 | with tf.control_dependencies([score, summary]): 172 | submit_score = score_mean.submit(score) 173 | with tf.control_dependencies([submit_score]): 174 | mean_score = tf.cond(self._report, score_mean.clear, float) 175 | summary = tf.cond( 176 | self._report, 177 | lambda: tf.summary.merge([summary, tf.summary.scalar( 178 | name + '/score', mean_score, family='trainer')]), 179 | lambda: summary) 180 | next_step = self._global_step.assign_add(batch_size) 181 | with tf.control_dependencies([summary, mean_score, next_step]): 182 | return ( 183 | tf.identity(summary), 184 | tf.identity(mean_score), 185 | tf.identity(next_step)) 186 | 187 | def _create_session(self): 188 | config = tf.ConfigProto() 189 | config.gpu_options.allow_growth = True 190 | try: 191 | return tf.Session('local', config=config) 192 | except tf.errors.NotFoundError: 193 | return tf.Session(config=config) 194 | 195 | def _initialize_variables(self, sess, savers, logdirs, checkpoints): 196 | sess.run(tf.group( 197 | tf.local_variables_initializer(), 198 | tf.global_variables_initializer())) 199 | assert len(savers) == len(logdirs) == len(checkpoints) 200 | for i, (saver, logdir, checkpoint) in enumerate( 201 | zip(savers, logdirs, checkpoints)): 202 | logdir = os.path.expanduser(logdir) 203 | state = tf.train.get_checkpoint_state(logdir) 204 | if checkpoint: 205 | checkpoint = os.path.join(logdir, checkpoint) 206 | if not checkpoint and state and state.model_checkpoint_path: 207 | checkpoint = state.model_checkpoint_path 208 | if checkpoint: 209 | saver.restore(sess, checkpoint) 210 | 211 | def _store_checkpoint(self, sess, saver, global_step): 212 | if not self._logdir or not saver: 213 | return 214 | tf.gfile.MakeDirs(self._logdir) 215 | filename = os.path.join(self._logdir, 'model.ckpt') 216 | saver.save(sess, filename, global_step) 217 | 
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Dreamer Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import setuptools 20 | 21 | 22 | setuptools.setup( 23 | name='dreamer', 24 | version='1.0.0', 25 | description=( 26 | 'Dream to Explore: Learning Behaviors by Latent Imagination'), 27 | license='Apache 2.0', 28 | url='http://github.com/google-research/dreamer', 29 | install_requires=[ 30 | 'dm_control', 31 | 'gym', 32 | 'matplotlib', 33 | 'ruamel.yaml', 34 | 'scikit-image', 35 | 'scipy', 36 | 'tensorflow-gpu==1.13.1', 37 | 'tensorflow_probability==0.6.0', 38 | ], 39 | packages=setuptools.find_packages(), 40 | classifiers=[ 41 | 'Programming Language :: Python :: 3', 42 | 'License :: OSI Approved :: Apache Software License', 43 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 44 | 'Intended Audience :: Science/Research', 45 | ], 46 | ) 47 | -------------------------------------------------------------------------------- /tests/test_dreamer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Dreamer Authors. All rights reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import pathlib 20 | import sys 21 | 22 | sys.path.append(str(pathlib.Path(__file__).parent.parent)) 23 | 24 | import tensorflow as tf 25 | from dreamer import tools 26 | from dreamer.scripts import train 27 | 28 | 29 | class DreamerTest(tf.test.TestCase): 30 | 31 | def test_dreamer(self): 32 | args = tools.AttrDict( 33 | logdir=self.get_temp_dir(), 34 | num_runs=1, 35 | params=tools.AttrDict( 36 | defaults=['dreamer', 'debug'], 37 | tasks=['dummy'], 38 | isolate_envs='none', 39 | max_steps=30, 40 | train_planner='policy_sample', 41 | test_planner='policy_mode', 42 | planner_objective='reward_value', 43 | action_head=True, 44 | value_head=True, 45 | imagination_horizon=3), 46 | ping_every=0, 47 | resume_runs=False) 48 | train.main(args) 49 | 50 | def test_dreamer_discrete(self): 51 | args = tools.AttrDict( 52 | logdir=self.get_temp_dir(), 53 | num_runs=1, 54 | params=tools.AttrDict( 55 | defaults=['dreamer', 'debug'], 56 | tasks=['dummy'], 57 | isolate_envs='none', 58 | max_steps=30, 59 | train_planner='policy_sample', 60 | test_planner='policy_mode', 61 | planner_objective='reward_value', 62 | action_head=True, 63 | value_head=True, 64 | imagination_horizon=3, 65 | action_head_dist='onehot_score', 66 | action_noise_type='epsilon_greedy'), 67 | 
ping_every=0, 68 | resume_runs=False) 69 | train.main(args) 70 | 71 | def test_dreamer_target(self): 72 | args = tools.AttrDict( 73 | logdir=self.get_temp_dir(), 74 | num_runs=1, 75 | params=tools.AttrDict( 76 | defaults=['dreamer', 'debug'], 77 | tasks=['dummy'], 78 | isolate_envs='none', 79 | max_steps=30, 80 | train_planner='policy_sample', 81 | test_planner='policy_mode', 82 | planner_objective='reward_value', 83 | action_head=True, 84 | value_head=True, 85 | value_target_head=True, 86 | imagination_horizon=3), 87 | ping_every=0, 88 | resume_runs=False) 89 | train.main(args) 90 | 91 | def test_no_value(self): 92 | args = tools.AttrDict( 93 | logdir=self.get_temp_dir(), 94 | num_runs=1, 95 | params=tools.AttrDict( 96 | defaults=['actor', 'debug'], 97 | tasks=['dummy'], 98 | isolate_envs='none', 99 | max_steps=30, 100 | imagination_horizon=3), 101 | ping_every=0, 102 | resume_runs=False) 103 | train.main(args) 104 | 105 | def test_planet(self): 106 | args = tools.AttrDict( 107 | logdir=self.get_temp_dir(), 108 | num_runs=1, 109 | params=tools.AttrDict( 110 | defaults=['planet', 'debug'], 111 | tasks=['dummy'], 112 | isolate_envs='none', 113 | max_steps=30, 114 | planner_horizon=3), 115 | ping_every=0, 116 | resume_runs=False) 117 | train.main(args) 118 | 119 | 120 | if __name__ == '__main__': 121 | tf.test.main() 122 | -------------------------------------------------------------------------------- /tests/test_isolate_envs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Dreamer Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import pathlib 20 | import sys 21 | 22 | sys.path.append(str(pathlib.Path(__file__).parent.parent)) 23 | 24 | import tensorflow as tf 25 | from dreamer import tools 26 | from dreamer.scripts import train 27 | 28 | 29 | class IsolateEnvsTest(tf.test.TestCase): 30 | 31 | def test_dummy_thread(self): 32 | args = tools.AttrDict( 33 | logdir=self.get_temp_dir(), 34 | num_runs=1, 35 | params=tools.AttrDict( 36 | defaults=['dreamer', 'debug'], 37 | tasks=['dummy'], 38 | isolate_envs='thread', 39 | max_steps=30), 40 | ping_every=0, 41 | resume_runs=False) 42 | train.main(args) 43 | 44 | def test_dm_control_thread(self): 45 | args = tools.AttrDict( 46 | logdir=self.get_temp_dir(), 47 | num_runs=1, 48 | params=tools.AttrDict( 49 | defaults=['dreamer', 'debug'], 50 | tasks=['cup_catch'], 51 | isolate_envs='thread', 52 | max_steps=30), 53 | ping_every=0, 54 | resume_runs=False) 55 | train.main(args) 56 | 57 | def test_atari_thread(self): 58 | args = tools.AttrDict( 59 | logdir=self.get_temp_dir(), 60 | num_runs=1, 61 | params=tools.AttrDict( 62 | defaults=['dreamer', 'debug'], 63 | tasks=['atari_pong'], 64 | isolate_envs='thread', 65 | action_head_dist='onehot_score', 66 | action_noise_type='epsilon_greedy', 67 | max_steps=30), 68 | ping_every=0, 69 | resume_runs=False) 70 | train.main(args) 71 | 72 | # def test_dmlab_thread(self): 73 | # args = tools.AttrDict( 74 | # logdir=self.get_temp_dir(), 75 | # 
num_runs=1, 76 | # params=tools.AttrDict( 77 | # defaults=['dreamer', 'debug'], 78 | # tasks=['dmlab_collect'], 79 | # isolate_envs='thread', 80 | # action_head_dist='onehot_score', 81 | # action_noise_type='epsilon_greedy', 82 | # max_steps=30), 83 | # ping_every=0, 84 | # resume_runs=False) 85 | # train.main(args) 86 | 87 | 88 | if __name__ == '__main__': 89 | tf.test.main() 90 | -------------------------------------------------------------------------------- /tests/test_nested.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Dreamer Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import collections 20 | import pathlib 21 | import sys 22 | 23 | sys.path.append(str(pathlib.Path(__file__).parent.parent)) 24 | 25 | import tensorflow as tf 26 | from dreamer.tools import nested 27 | 28 | 29 | class ZipTest(tf.test.TestCase): 30 | 31 | def test_scalar(self): 32 | self.assertEqual(42, nested.zip(42)) 33 | self.assertEqual((13, 42), nested.zip(13, 42)) 34 | 35 | def test_empty(self): 36 | self.assertEqual({}, nested.zip({}, {})) 37 | 38 | def test_base_case(self): 39 | self.assertEqual((1, 2, 3), nested.zip(1, 2, 3)) 40 | 41 | def test_shallow_list(self): 42 | a = [1, 2, 3] 43 | b = [4, 5, 6] 44 | c = [7, 8, 9] 45 | result = nested.zip(a, b, c) 46 | self.assertEqual([(1, 4, 7), (2, 5, 8), (3, 6, 9)], result) 47 | 48 | def test_shallow_tuple(self): 49 | a = (1, 2, 3) 50 | b = (4, 5, 6) 51 | c = (7, 8, 9) 52 | result = nested.zip(a, b, c) 53 | self.assertEqual(((1, 4, 7), (2, 5, 8), (3, 6, 9)), result) 54 | 55 | def test_shallow_dict(self): 56 | a = {'a': 1, 'b': 2, 'c': 3} 57 | b = {'a': 4, 'b': 5, 'c': 6} 58 | c = {'a': 7, 'b': 8, 'c': 9} 59 | result = nested.zip(a, b, c) 60 | self.assertEqual({'a': (1, 4, 7), 'b': (2, 5, 8), 'c': (3, 6, 9)}, result) 61 | 62 | def test_single(self): 63 | a = [[1, 2], 3] 64 | result = nested.zip(a) 65 | self.assertEqual(a, result) 66 | 67 | def test_mixed_structures(self): 68 | a = [(1, 2), 3, {'foo': [4]}] 69 | b = [(5, 6), 7, {'foo': [8]}] 70 | result = nested.zip(a, b) 71 | self.assertEqual([((1, 5), (2, 6)), (3, 7), {'foo': [(4, 8)]}], result) 72 | 73 | def test_different_types(self): 74 | a = [1, 2, 3] 75 | b = 'a b c'.split() 76 | result = nested.zip(a, b) 77 | self.assertEqual([(1, 'a'), (2, 'b'), (3, 'c')], result) 78 | 79 | def test_use_type_of_first(self): 80 | a = (1, 2, 3) 81 | b = [4, 5, 6] 82 | c = [7, 8, 9] 83 | result = nested.zip(a, b, c) 84 | 
self.assertEqual(((1, 4, 7), (2, 5, 8), (3, 6, 9)), result) 85 | 86 | def test_namedtuple(self): 87 | Foo = collections.namedtuple('Foo', 'value') 88 | foo, bar = Foo(42), Foo(13) 89 | self.assertEqual(Foo((42, 13)), nested.zip(foo, bar)) 90 | 91 | 92 | class MapTest(tf.test.TestCase): 93 | 94 | def test_scalar(self): 95 | self.assertEqual(42, nested.map(lambda x: x, 42)) 96 | 97 | def test_empty(self): 98 | self.assertEqual({}, nested.map(lambda x: x, {})) 99 | 100 | def test_shallow_list(self): 101 | self.assertEqual([2, 4, 6], nested.map(lambda x: 2 * x, [1, 2, 3])) 102 | 103 | def test_shallow_dict(self): 104 | data = {'a': 1, 'b': 2, 'c': 3, 'd': 4} 105 | self.assertEqual(data, nested.map(lambda x: x, data)) 106 | 107 | def test_mixed_structure(self): 108 | structure = [(1, 2), 3, {'foo': [4]}] 109 | result = nested.map(lambda x: 2 * x, structure) 110 | self.assertEqual([(2, 4), 6, {'foo': [8]}], result) 111 | 112 | def test_mixed_types(self): 113 | self.assertEqual([14, 'foofoo'], nested.map(lambda x: x * 2, [7, 'foo'])) 114 | 115 | def test_multiple_lists(self): 116 | a = [1, 2, 3] 117 | b = [4, 5, 6] 118 | c = [7, 8, 9] 119 | result = nested.map(lambda x, y, z: x + y + z, a, b, c) 120 | self.assertEqual([12, 15, 18], result) 121 | 122 | def test_namedtuple(self): 123 | Foo = collections.namedtuple('Foo', 'value') 124 | foo, bar = [Foo(42)], [Foo(13)] 125 | function = nested.map(lambda x, y: (y, x), foo, bar) 126 | self.assertEqual([Foo((13, 42))], function) 127 | function = nested.map(lambda x, y: x + y, foo, bar) 128 | self.assertEqual([Foo(55)], function) 129 | 130 | 131 | class FlattenTest(tf.test.TestCase): 132 | 133 | def test_scalar(self): 134 | self.assertEqual((42,), nested.flatten(42)) 135 | 136 | def test_empty(self): 137 | self.assertEqual((), nested.flatten({})) 138 | 139 | def test_base_case(self): 140 | self.assertEqual((1,), nested.flatten(1)) 141 | 142 | def test_convert_type(self): 143 | self.assertEqual((1, 2, 3), nested.flatten([1, 2, 
3])) 144 | 145 | def test_mixed_structure(self): 146 | self.assertEqual((1, 2, 3, 4), nested.flatten([(1, 2), 3, {'foo': [4]}])) 147 | 148 | def test_value_ordering(self): 149 | self.assertEqual((1, 2, 3), nested.flatten({'a': 1, 'b': 2, 'c': 3})) 150 | 151 | 152 | class FilterTest(tf.test.TestCase): 153 | 154 | def test_empty(self): 155 | self.assertEqual({}, nested.filter(lambda x: True, {})) 156 | self.assertEqual({}, nested.filter(lambda x: False, {})) 157 | 158 | def test_base_case(self): 159 | self.assertEqual((), nested.filter(lambda x: False, 1)) 160 | 161 | def test_single_dict(self): 162 | predicate = lambda x: x % 2 == 0 163 | data = {'a': 1, 'b': 2, 'c': 3, 'd': 4} 164 | self.assertEqual({'b': 2, 'd': 4}, nested.filter(predicate, data)) 165 | 166 | def test_multiple_lists(self): 167 | a = [1, 2, 3] 168 | b = [4, 5, 6] 169 | c = [7, 8, 9] 170 | predicate = lambda *args: any(x % 4 == 0 for x in args) 171 | result = nested.filter(predicate, a, b, c) 172 | self.assertEqual([(1, 4, 7), (2, 5, 8)], result) 173 | 174 | def test_multiple_dicts(self): 175 | a = {'a': 1, 'b': 2, 'c': 3} 176 | b = {'a': 4, 'b': 5, 'c': 6} 177 | c = {'a': 7, 'b': 8, 'c': 9} 178 | predicate = lambda *args: any(x % 4 == 0 for x in args) 179 | result = nested.filter(predicate, a, b, c) 180 | self.assertEqual({'a': (1, 4, 7), 'b': (2, 5, 8)}, result) 181 | 182 | def test_mixed_structure(self): 183 | predicate = lambda x: x % 2 == 0 184 | data = [(1, 2), 3, {'foo': [4]}] 185 | self.assertEqual([(2,), {'foo': [4]}], nested.filter(predicate, data)) 186 | 187 | def test_remove_empty_containers(self): 188 | data = [(1, 2, 3), 4, {'foo': [5, 6], 'bar': 7}] 189 | self.assertEqual([], nested.filter(lambda x: False, data)) 190 | 191 | def test_namedtuple(self): 192 | Foo = collections.namedtuple('Foo', 'value1, value2') 193 | self.assertEqual(Foo(1, None), nested.filter(lambda x: x == 1, Foo(1, 2))) 194 | 195 | def test_namedtuple_multiple(self): 196 | Foo = collections.namedtuple('Foo', 
'value1, value2') 197 | foo = Foo(1, 2) 198 | bar = Foo(2, 3) 199 | result = nested.filter(lambda x, y: x + y > 3, foo, bar) 200 | self.assertEqual(Foo(None, (2, 3)), result) 201 | 202 | def test_namedtuple_nested(self): 203 | Foo = collections.namedtuple('Foo', 'value1, value2') 204 | foo = Foo(1, [1, 2, 3]) 205 | self.assertEqual(Foo(None, [2, 3]), nested.filter(lambda x: x > 1, foo)) 206 | -------------------------------------------------------------------------------- /tests/test_overshooting.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Dreamer Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import pathlib 20 | import sys 21 | 22 | sys.path.append(str(pathlib.Path(__file__).parent.parent)) 23 | 24 | import tensorflow as tf 25 | from dreamer import models 26 | from dreamer.tools import overshooting 27 | 28 | 29 | class _MockCell(models.Base): 30 | 31 | def __init__(self, obs_size): 32 | self._obs_size = obs_size 33 | super(_MockCell, self).__init__( 34 | tf.make_template('transition', self._transition), 35 | tf.make_template('posterior', self._posterior)) 36 | 37 | @property 38 | def state_size(self): 39 | return {'obs': self._obs_size} 40 | 41 | def _transition(self, prev_state, prev_action, zero_obs): 42 | if isinstance(prev_action, (tuple, list)): 43 | prev_action = prev_action[0] 44 | return {'obs': prev_state['obs'] + prev_action} 45 | 46 | def _posterior(self, prev_state, prev_action, obs): 47 | if isinstance(obs, (tuple, list)): 48 | obs = obs[0] 49 | return {'obs': obs} 50 | 51 | 52 | class OvershootingTest(tf.test.TestCase): 53 | 54 | def test_example(self): 55 | obs = tf.constant([ 56 | [10, 20, 30, 40, 50, 60], 57 | [70, 80, 0, 0, 0, 0], 58 | ], dtype=tf.float32)[:, :, None] 59 | prev_action = tf.constant([ 60 | [0.0, 0.1, 0.2, 0.3, 0.4, 0.5], 61 | [9.0, 0.7, 0, 0, 0, 0], 62 | ], dtype=tf.float32)[:, :, None] 63 | length = tf.constant([6, 2], dtype=tf.int32) 64 | cell = _MockCell(1) 65 | _, prior, posterior, mask = overshooting( 66 | cell, obs, obs, prev_action, length, 3) 67 | prior = tf.squeeze(prior['obs'], 3) 68 | posterior = tf.squeeze(posterior['obs'], 3) 69 | mask = tf.to_int32(mask) 70 | with self.test_session(): 71 | # Each column corresponds to a different state step, and each row 72 | # corresponds to a different overshooting distance from there. 
73 | self.assertAllEqual([ 74 | [1, 1, 1, 1, 1, 1], 75 | [1, 1, 1, 1, 1, 0], 76 | [1, 1, 1, 1, 0, 0], 77 | [1, 1, 1, 0, 0, 0], 78 | ], mask.eval()[0].T) 79 | self.assertAllEqual([ 80 | [1, 1, 0, 0, 0, 0], 81 | [1, 0, 0, 0, 0, 0], 82 | [0, 0, 0, 0, 0, 0], 83 | [0, 0, 0, 0, 0, 0], 84 | ], mask.eval()[1].T) 85 | self.assertAllClose([ 86 | [0.0, 10.1, 20.2, 30.3, 40.4, 50.5], 87 | [0.1, 10.3, 20.5, 30.7, 40.9, 0], 88 | [0.3, 10.6, 20.9, 31.2, 0, 0], 89 | [0.6, 11.0, 21.4, 0, 0, 0], 90 | ], prior.eval()[0].T) 91 | self.assertAllClose([ 92 | [10, 20, 30, 40, 50, 60], 93 | [20, 30, 40, 50, 60, 0], 94 | [30, 40, 50, 60, 0, 0], 95 | [40, 50, 60, 0, 0, 0], 96 | ], posterior.eval()[0].T) 97 | self.assertAllClose([ 98 | [9.0, 70.7, 0, 0, 0, 0], 99 | [9.7, 0, 0, 0, 0, 0], 100 | [0, 0, 0, 0, 0, 0], 101 | [0, 0, 0, 0, 0, 0], 102 | ], prior.eval()[1].T) 103 | self.assertAllClose([ 104 | [70, 80, 0, 0, 0, 0], 105 | [80, 0, 0, 0, 0, 0], 106 | [0, 0, 0, 0, 0, 0], 107 | [0, 0, 0, 0, 0, 0], 108 | ], posterior.eval()[1].T) 109 | 110 | def test_nested(self): 111 | obs = (tf.ones((3, 50, 1)), tf.ones((3, 50, 2)), tf.ones((3, 50, 3))) 112 | prev_action = (tf.ones((3, 50, 1)), tf.ones((3, 50, 2))) 113 | length = tf.constant([49, 50, 3], dtype=tf.int32) 114 | cell = _MockCell(1) 115 | overshooting(cell, obs, obs, prev_action, length, 3) 116 | 117 | 118 | if __name__ == '__main__': 119 | tf.test.main() 120 | -------------------------------------------------------------------------------- /tests/test_running.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Dreamer Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | import pathlib 21 | import pickle 22 | import sys 23 | import threading 24 | import time 25 | 26 | sys.path.append(str(pathlib.Path(__file__).parent.parent)) 27 | 28 | import numpy as np 29 | import tensorflow as tf 30 | from dreamer.training import running 31 | 32 | 33 | class TestExperiment(tf.test.TestCase): 34 | 35 | def test_no_kills(self): 36 | basedir = os.path.join(tf.test.get_temp_dir(), 'test_no_kills') 37 | processes = [] 38 | for worker_name in range(20): 39 | processes.append(threading.Thread( 40 | target=_worker_normal, args=(basedir, str(worker_name)))) 41 | processes[-1].start() 42 | for process in processes: 43 | process.join() 44 | filepaths = tf.gfile.Glob(os.path.join(basedir, '*/DONE')) 45 | self.assertEqual(100, len(filepaths)) 46 | filepaths = tf.gfile.Glob(os.path.join(basedir, '*/PING')) 47 | self.assertEqual(100, len(filepaths)) 48 | filepaths = tf.gfile.Glob(os.path.join(basedir, '*/started')) 49 | self.assertEqual(100, len(filepaths)) 50 | filepaths = tf.gfile.Glob(os.path.join(basedir, '*/resumed')) 51 | self.assertEqual(0, len(filepaths)) 52 | filepaths = tf.gfile.Glob(os.path.join(basedir, '*/failed')) 53 | self.assertEqual(0, len(filepaths)) 54 | filepaths = tf.gfile.Glob(os.path.join(basedir, '*/numbers')) 55 | self.assertEqual(100, len(filepaths)) 56 | for filepath in filepaths: 57 | with tf.gfile.GFile(filepath, 'rb') as file_: 58 | 
self.assertEqual(10, len(pickle.load(file_))) 59 | 60 | def test_dying_workers(self): 61 | basedir = os.path.join(tf.test.get_temp_dir(), 'test_dying_workers') 62 | processes = [] 63 | for worker_name in range(20): 64 | processes.append(threading.Thread( 65 | target=_worker_dying, args=(basedir, 15, str(worker_name)))) 66 | processes[-1].start() 67 | for process in processes: 68 | process.join() 69 | processes = [] 70 | for worker_name in range(20): 71 | processes.append(threading.Thread( 72 | target=_worker_normal, args=(basedir, str(worker_name)))) 73 | processes[-1].start() 74 | for process in processes: 75 | process.join() 76 | filepaths = tf.gfile.Glob(os.path.join(basedir, '*/DONE')) 77 | self.assertEqual(100, len(filepaths)) 78 | filepaths = tf.gfile.Glob(os.path.join(basedir, '*/PING')) 79 | self.assertEqual(100, len(filepaths)) 80 | filepaths = tf.gfile.Glob(os.path.join(basedir, '*/FAIL')) 81 | self.assertEqual(0, len(filepaths)) 82 | filepaths = tf.gfile.Glob(os.path.join(basedir, '*/started')) 83 | self.assertEqual(100, len(filepaths)) 84 | filepaths = tf.gfile.Glob(os.path.join(basedir, '*/resumed')) 85 | self.assertEqual(20, len(filepaths)) 86 | filepaths = tf.gfile.Glob(os.path.join(basedir, '*/numbers')) 87 | self.assertEqual(100, len(filepaths)) 88 | for filepath in filepaths: 89 | with tf.gfile.GFile(filepath, 'rb') as file_: 90 | self.assertEqual(10, len(pickle.load(file_))) 91 | 92 | 93 | def _worker_normal(basedir, worker_name): 94 | experiment = running.Experiment( 95 | basedir, _process_fn, _start_fn, _resume_fn, 96 | num_runs=100, worker_name=worker_name, ping_every=1.0) 97 | for run in experiment: 98 | for score in run: 99 | pass 100 | 101 | 102 | def _worker_dying(basedir, die_at_step, worker_name): 103 | experiment = running.Experiment( 104 | basedir, _process_fn, _start_fn, _resume_fn, 105 | num_runs=100, worker_name=worker_name, ping_every=1.0) 106 | step = 0 107 | for run in experiment: 108 | for score in run: 109 | step += 1 110 | if 
step >= die_at_step: 111 | return 112 | 113 | 114 | def _start_fn(logdir): 115 | assert not tf.gfile.Exists(os.path.join(logdir, 'DONE')) 116 | assert not tf.gfile.Exists(os.path.join(logdir, 'started')) 117 | assert not tf.gfile.Exists(os.path.join(logdir, 'resumed')) 118 | with tf.gfile.GFile(os.path.join(logdir, 'started'), 'w') as file_: 119 | file_.write('\n') 120 | with tf.gfile.GFile(os.path.join(logdir, 'numbers'), 'wb') as file_: 121 | pickle.dump([], file_) 122 | return [] 123 | 124 | 125 | def _resume_fn(logdir): 126 | assert not tf.gfile.Exists(os.path.join(logdir, 'DONE')) 127 | assert tf.gfile.Exists(os.path.join(logdir, 'started')) 128 | with tf.gfile.GFile(os.path.join(logdir, 'resumed'), 'w') as file_: 129 | file_.write('\n') 130 | with tf.gfile.GFile(os.path.join(logdir, 'numbers'), 'rb') as file_: 131 | numbers = pickle.load(file_) 132 | if len(numbers) != 5: 133 | raise Exception('Expected to be resumed in the middle for this test.') 134 | return numbers 135 | 136 | 137 | def _process_fn(logdir, numbers): 138 | assert tf.gfile.Exists(os.path.join(logdir, 'started')) 139 | while len(numbers) < 10: 140 | number = np.random.uniform(0, 0.1) 141 | time.sleep(number) 142 | numbers.append(number) 143 | with tf.gfile.GFile(os.path.join(logdir, 'numbers'), 'wb') as file_: 144 | pickle.dump(numbers, file_) 145 | yield number 146 | 147 | 148 | if __name__ == '__main__': 149 | tf.test.main() 150 | --------------------------------------------------------------------------------