├── .gitignore ├── AUTHORS ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── planet ├── __init__.py ├── control │ ├── __init__.py │ ├── batch_env.py │ ├── dummy_env.py │ ├── in_graph_batch_env.py │ ├── mpc_agent.py │ ├── planning.py │ ├── random_episodes.py │ ├── simulate.py │ ├── temporal_difference.py │ └── wrappers.py ├── models │ ├── __init__.py │ ├── base.py │ ├── drnn.py │ ├── rssm.py │ └── ssm.py ├── networks │ ├── __init__.py │ ├── basic.py │ └── conv_ha.py ├── scripts │ ├── __init__.py │ ├── configs.py │ ├── create_video.py │ ├── fetch_events.py │ ├── objectives.py │ ├── sync.py │ ├── tasks.py │ ├── test_planet.py │ └── train.py ├── tools │ ├── __init__.py │ ├── attr_dict.py │ ├── bind.py │ ├── bound_action.py │ ├── chunk_sequence.py │ ├── copy_weights.py │ ├── count_dataset.py │ ├── count_weights.py │ ├── custom_optimizer.py │ ├── filter_variables_lib.py │ ├── gif_summary.py │ ├── image_strip_summary.py │ ├── mask.py │ ├── nested.py │ ├── numpy_episodes.py │ ├── overshooting.py │ ├── preprocess.py │ ├── reshape_as.py │ ├── schedule.py │ ├── shape.py │ ├── streaming_mean.py │ ├── summary.py │ ├── target_network.py │ ├── test_nested.py │ ├── test_overshooting.py │ └── unroll.py └── training │ ├── __init__.py │ ├── define_model.py │ ├── define_summaries.py │ ├── running.py │ ├── test_running.py │ ├── trainer.py │ └── utility.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *.egg-info 4 | /dist 5 | MUJOCO_LOG.TXT 6 | -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | # This is the list of PlaNet authors for copyright purposes. 2 | # 3 | # This does not necessarily list everyone who has contributed code, since in 4 | # some cases, their employer may be the copyright holder. 
To see the full list 5 | # of contributors, see the revision history in source control. 6 | Google LLC 7 | Danijar Hafner 8 | Timothy Lillicrap 9 | Ian Fischer 10 | Ruben Villegas 11 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows 28 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/). 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 
9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. 
For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Planning Network 2 | 3 | Danijar Hafner, Timothy Lillicrap, Ian Fischer, Ruben Villegas, David Ha, Honglak Lee, James Davidson 4 | 5 | ![PlaNet policies and predictions](https://imgur.com/UeeQIfo.gif) 6 | 7 | This project provides the open source implementation of the PlaNet agent 8 | introduced in [Learning Latent Dynamics for Planning from Pixels][paper]. 9 | PlaNet is a purely model-based reinforcement learning algorithm that solves 10 | control tasks from images by efficient planning in a learned latent space. 11 | PlaNet competes with top model-free methods in terms of final performance and 12 | training time while using substantially less interaction with the environment. 
13 | 14 | If you find this open source release useful, please reference in your paper: 15 | 16 | ``` 17 | @inproceedings{hafner2019planet, 18 | title={Learning Latent Dynamics for Planning from Pixels}, 19 | author={Hafner, Danijar and Lillicrap, Timothy and Fischer, Ian and Villegas, Ruben and Ha, David and Lee, Honglak and Davidson, James}, 20 | booktitle={International Conference on Machine Learning}, 21 | pages={2555--2565}, 22 | year={2019} 23 | } 24 | ``` 25 | 26 | ## Method 27 | 28 | ![PlaNet model diagram](https://i.imgur.com/fpvrAqw.png) 29 | 30 | PlaNet models the world as a compact sequence of hidden states. For planning, 31 | we first encode the history of past images into the current state. From there, 32 | we efficiently predict future rewards for multiple action sequences in latent 33 | space. We execute the first action of the best sequence found and replan after 34 | observing the next image. 35 | 36 | Find more information: 37 | 38 | - [Google AI Blog post][blog] 39 | - [Project website][website] 40 | - [PDF paper][paper] 41 | 42 | [blog]: https://ai.googleblog.com/2019/02/introducing-planet-deep-planning.html 43 | [website]: https://danijar.com/project/planet/ 44 | [paper]: https://arxiv.org/pdf/1811.04551.pdf 45 | 46 | ## Instructions 47 | 48 | To train an agent, install the dependencies and then run: 49 | 50 | ```sh 51 | python3 -m planet.scripts.train --logdir /path/to/logdir --params '{tasks: [cheetah_run]}' 52 | ``` 53 | 54 | The code prints `nan` as the score for iterations during which no summaries 55 | were computed. 56 | 57 | The available tasks are listed in `scripts/tasks.py`. The default parameters 58 | can be found in `scripts/configs.py`. To run the experiments from our 59 | paper, pass the following parameters to `--params {...}` in addition to the 60 | list of tasks: 61 | 62 | | Experiment | Parameters | 63 | | :--------- | :--------- | 64 | | PlaNet | No additional parameters. 
| 65 | | Random data collection | `planner_iterations: 0, train_action_noise: 1.0` | 66 | | Purely deterministic | `mean_only: True, divergence_scale: 0.0` | 67 | | Purely stochastic | `model: ssm` | 68 | | One agent all tasks | `collect_every: 30000` | 69 | 70 | Please note that the agent has seen some improvements so the results may be a 71 | bit different now. 72 | 73 | ## Modifications 74 | 75 | These are good places to start when modifying the code: 76 | 77 | | Directory | Description | 78 | | :-------- | :---------- | 79 | | `scripts/configs.py` | Add new parameters or change defaults. | 80 | | `scripts/tasks.py` | Add or modify environments. | 81 | | `models` | Add or modify latent transition models. | 82 | | `networks` | Add or modify encoder and decoder networks. | 83 | 84 | Tips for development: 85 | 86 | - You can set `--config debug` to reduce the episode length, batch size, and 87 | collect data more freqnently. This helps to quickly reach all parts of the 88 | code. 89 | - You can use `--num_runs 1000 --resume_runs False` to automatically start new 90 | runs in sub directories of the logdir every time to execute the script. 91 | - Environments live in separate processes by default. Some environments work 92 | better when separated into threads instead by specifying `--params 93 | '{isolate_envs: thread}'`. 94 | 95 | ## Dependencies 96 | 97 | The code was tested under Ubuntu 18 and uses these packages: 98 | 99 | - tensorflow-gpu==1.13.1 100 | - tensorflow_probability==0.6.0 101 | - dm_control (`egl` [rendering option][dmc-rendering] recommended) 102 | - gym 103 | - scikit-image 104 | - scipy 105 | - ruamel.yaml 106 | - matplotlib 107 | 108 | [dmc-rendering]: https://github.com/deepmind/dm_control#rendering 109 | 110 | Disclaimer: This is not an official Google product. 
111 | -------------------------------------------------------------------------------- /planet/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The PlaNet Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | from . import control 20 | from . import models 21 | from . import networks 22 | from . import scripts 23 | from . import tools 24 | from . import training 25 | -------------------------------------------------------------------------------- /planet/control/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The PlaNet Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | from . import planning 20 | from .batch_env import BatchEnv 21 | from .dummy_env import DummyEnv 22 | from .in_graph_batch_env import InGraphBatchEnv 23 | from .mpc_agent import MPCAgent 24 | from .random_episodes import random_episodes 25 | from .simulate import simulate 26 | from .temporal_difference import discounted_return 27 | from .temporal_difference import fixed_step_return 28 | from .temporal_difference import lambda_return 29 | -------------------------------------------------------------------------------- /planet/control/batch_env.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The PlaNet Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import numpy as np 20 | 21 | 22 | class BatchEnv(object): 23 | """Combine multiple environments to step them in batch.""" 24 | 25 | def __init__(self, envs, blocking): 26 | """Combine multiple environments to step them in batch. 
27 | 28 | To step environments in parallel, environments must support a 29 | `blocking=False` argument to their step and reset functions that makes them 30 | return callables instead to receive the result at a later time. 31 | 32 | Args: 33 | envs: List of environments. 34 | blocking: Step environments after another rather than in parallel. 35 | 36 | Raises: 37 | ValueError: Environments have different observation or action spaces. 38 | """ 39 | self._envs = envs 40 | self._blocking = blocking 41 | observ_space = self._envs[0].observation_space 42 | if not all(env.observation_space == observ_space for env in self._envs): 43 | raise ValueError('All environments must use the same observation space.') 44 | action_space = self._envs[0].action_space 45 | if not all(env.action_space == action_space for env in self._envs): 46 | raise ValueError('All environments must use the same observation space.') 47 | 48 | def __len__(self): 49 | """Number of combined environments.""" 50 | return len(self._envs) 51 | 52 | def __getitem__(self, index): 53 | """Access an underlying environment by index.""" 54 | return self._envs[index] 55 | 56 | def __getattr__(self, name): 57 | """Forward unimplemented attributes to one of the original environments. 58 | 59 | Args: 60 | name: Attribute that was accessed. 61 | 62 | Returns: 63 | Value behind the attribute name one of the wrapped environments. 64 | """ 65 | return getattr(self._envs[0], name) 66 | 67 | def step(self, actions): 68 | """Forward a batch of actions to the wrapped environments. 69 | 70 | Args: 71 | actions: Batched action to apply to the environment. 72 | 73 | Raises: 74 | ValueError: Invalid actions. 75 | 76 | Returns: 77 | Batch of observations, rewards, and done flags. 
78 | """ 79 | for index, (env, action) in enumerate(zip(self._envs, actions)): 80 | if not env.action_space.contains(action): 81 | message = 'Invalid action at index {}: {}' 82 | raise ValueError(message.format(index, action)) 83 | if self._blocking: 84 | transitions = [ 85 | env.step(action) 86 | for env, action in zip(self._envs, actions)] 87 | else: 88 | transitions = [ 89 | env.step(action, blocking=False) 90 | for env, action in zip(self._envs, actions)] 91 | transitions = [transition() for transition in transitions] 92 | observs, rewards, dones, infos = zip(*transitions) 93 | observ = np.stack(observs) 94 | reward = np.stack(rewards).astype(np.float32) 95 | done = np.stack(dones) 96 | info = tuple(infos) 97 | return observ, reward, done, info 98 | 99 | def reset(self, indices=None): 100 | """Reset the environment and convert the resulting observation. 101 | 102 | Args: 103 | indices: The batch indices of environments to reset; defaults to all. 104 | 105 | Returns: 106 | Batch of observations. 107 | """ 108 | if indices is None: 109 | indices = np.arange(len(self._envs)) 110 | if self._blocking: 111 | observs = [self._envs[index].reset() for index in indices] 112 | else: 113 | observs = [self._envs[index].reset(blocking=False) for index in indices] 114 | observs = [observ() for observ in observs] 115 | observ = np.stack(observs) 116 | return observ 117 | 118 | def close(self): 119 | """Send close messages to the external process and join them.""" 120 | for env in self._envs: 121 | if hasattr(env, 'close'): 122 | env.close() 123 | -------------------------------------------------------------------------------- /planet/control/dummy_env.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The PlaNet Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import gym 20 | import numpy as np 21 | 22 | 23 | class DummyEnv(object): 24 | 25 | def __init__(self): 26 | self._random = np.random.RandomState(seed=0) 27 | self._step = None 28 | 29 | @property 30 | def observation_space(self): 31 | low = np.zeros([64, 64, 3], dtype=np.float32) 32 | high = np.ones([64, 64, 3], dtype=np.float32) 33 | spaces = {'image': gym.spaces.Box(low, high)} 34 | return gym.spaces.Dict(spaces) 35 | 36 | @property 37 | def action_space(self): 38 | low = -np.ones([5], dtype=np.float32) 39 | high = np.ones([5], dtype=np.float32) 40 | return gym.spaces.Box(low, high) 41 | 42 | def reset(self): 43 | self._step = 0 44 | obs = self.observation_space.sample() 45 | return obs 46 | 47 | def step(self, action): 48 | obs = self.observation_space.sample() 49 | reward = self._random.uniform(0, 1) 50 | self._step += 1 51 | done = self._step >= 1000 52 | info = {} 53 | return obs, reward, done, info 54 | -------------------------------------------------------------------------------- /planet/control/in_graph_batch_env.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The PlaNet Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
"""Batch of environments inside the TensorFlow graph."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gym
import numpy as np
import tensorflow as tf


class InGraphBatchEnv(object):
  """Batch of environments inside the TensorFlow graph.

  The batch of environments will be stepped and reset inside of the graph using
  a tf.py_func(). The current batch of observations, actions, rewards, and done
  flags are held in according variables.
  """

  def __init__(self, batch_env):
    """Batch of environments inside the TensorFlow graph.

    Args:
      batch_env: Batch environment.
    """
    self._batch_env = batch_env
    # Leading batch dimension shared by all the state variables below.
    batch_dims = (len(self._batch_env),)
    observ_shape = self._parse_shape(self._batch_env.observation_space)
    observ_dtype = self._parse_dtype(self._batch_env.observation_space)
    action_shape = self._parse_shape(self._batch_env.action_space)
    action_dtype = self._parse_dtype(self._batch_env.action_space)
    with tf.variable_scope('env_temporary'):
      # Non-trainable variables that cache the most recent transition of each
      # environment so other graph ops can read it without re-stepping.
      self._observ = tf.get_variable(
          'observ', batch_dims + observ_shape, observ_dtype,
          tf.constant_initializer(0), trainable=False)
      self._action = tf.get_variable(
          'action', batch_dims + action_shape, action_dtype,
          tf.constant_initializer(0), trainable=False)
      self._reward = tf.get_variable(
          'reward', batch_dims, tf.float32,
          tf.constant_initializer(0), trainable=False)
      # This variable should be boolean, but tf.scatter_update() does not
      # support boolean resource variables yet.
      self._done = tf.get_variable(
          'done', batch_dims, tf.int32,
          tf.constant_initializer(False), trainable=False)

  def __getattr__(self, name):
    """Forward unimplemented attributes to one of the original environments.

    Args:
      name: Attribute that was accessed.

    Returns:
      Value behind the attribute name in one of the original environments.
    """
    return getattr(self._batch_env, name)

  def __len__(self):
    """Number of combined environments."""
    return len(self._batch_env)

  def __getitem__(self, index):
    """Access an underlying environment by index."""
    return self._batch_env[index]

  def step(self, action):
    """Step the batch of environments.

    The results of the step can be accessed from the variables defined below.

    Args:
      action: Tensor holding the batch of actions to apply.

    Returns:
      Operation.
    """
    with tf.name_scope('environment/simulate'):
      observ_dtype = self._parse_dtype(self._batch_env.observation_space)
      # Only the first three outputs (observ, reward, done) of the underlying
      # step are kept; any info dict is dropped inside the py_func.
      observ, reward, done = tf.py_func(
          lambda a: self._batch_env.step(a)[:3], [action],
          [observ_dtype, tf.float32, tf.bool], name='step')
      # reward = tf.cast(reward, tf.float32)
      return tf.group(
          self._observ.assign(observ),
          self._action.assign(action),
          self._reward.assign(reward),
          self._done.assign(tf.to_int32(done)))

  def reset(self, indices=None):
    """Reset the batch of environments.

    Args:
      indices: The batch indices of the environments to reset; defaults to all.

    Returns:
      Batch tensor of the new observations.
    """
    if indices is None:
      indices = tf.range(len(self._batch_env))
    observ_dtype = self._parse_dtype(self._batch_env.observation_space)
    observ = tf.py_func(
        self._batch_env.reset, [indices], observ_dtype, name='reset')
    # Reward and done are cleared for the reset environments.
    reward = tf.zeros_like(indices, tf.float32)
    done = tf.zeros_like(indices, tf.int32)
    with tf.control_dependencies([
        tf.scatter_update(self._observ, indices, observ),
        tf.scatter_update(self._reward, indices, reward),
        tf.scatter_update(self._done, indices, tf.to_int32(done))]):
      return tf.identity(observ)

  @property
  def observ(self):
    """Access the variable holding the current observation."""
    # Adding zero yields a tensor copy so callers do not alias the variable.
    return self._observ + 0

  @property
  def action(self):
    """Access the variable holding the last received action."""
    return self._action + 0

  @property
  def reward(self):
    """Access the variable holding the current reward."""
    return self._reward + 0

  @property
  def done(self):
    """Access the variable indicating whether the episode is done."""
    # Stored as int32 (see __init__); expose as boolean to callers.
    return tf.cast(self._done, tf.bool)

  def close(self):
    """Send close messages to the external process and join them."""
    self._batch_env.close()

  def _parse_shape(self, space):
    """Get a tensor shape from a OpenAI Gym space.

    Args:
      space: Gym space.

    Raises:
      NotImplementedError: For spaces other than Box and Discrete.

    Returns:
      Shape tuple.
    """
    if isinstance(space, gym.spaces.Discrete):
      return ()
    if isinstance(space, gym.spaces.Box):
      return space.shape
    raise NotImplementedError("Unsupported space '{}.'".format(space))

  def _parse_dtype(self, space):
    """Get a tensor dtype from a OpenAI Gym space.

    Args:
      space: Gym space.

    Raises:
      NotImplementedError: For spaces other than Box and Discrete.

    Returns:
      TensorFlow data type.
    """
    if isinstance(space, gym.spaces.Discrete):
      return tf.int32
    if isinstance(space, gym.spaces.Box):
      # uint8 boxes are image observations; everything else becomes float32.
      if space.low.dtype == np.uint8:
        return tf.uint8
      else:
        return tf.float32
    raise NotImplementedError()
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow_probability import distributions as tfd
import tensorflow as tf

from planet.tools import nested


class MPCAgent(object):
  """Agent that plans actions with a model-predictive-control planner.

  Keeps the latent model state and the previous action in local variables so
  consecutive perform() calls can filter observations through the cell.
  """

  def __init__(self, batch_env, step, is_training, should_log, config):
    self._batch_env = batch_env
    self._step = step  # Trainer step, not environment step.
    self._is_training = is_training
    self._should_log = should_log
    self._config = config
    self._cell = config.cell
    state = self._cell.zero_state(len(batch_env), tf.float32)
    # Mirror each state tensor in a zero-initialized local resource variable.
    var_like = lambda x: tf.get_local_variable(
        x.name.split(':')[0].replace('/', '_') + '_var',
        shape=x.shape,
        initializer=lambda *_, **__: tf.zeros_like(x), use_resource=True)
    self._state = nested.map(var_like, state)
    self._prev_action = tf.get_local_variable(
        'prev_action_var', shape=self._batch_env.action.shape,
        initializer=lambda *_, **__: tf.zeros_like(self._batch_env.action),
        use_resource=True)

  def begin_episode(self, agent_indices):
    """Zero out model state for the agents starting a new episode."""
    state = nested.map(
        lambda tensor: tf.gather(tensor, agent_indices),
        self._state)
    reset_state = nested.map(
        lambda var, val: tf.scatter_update(var, agent_indices, 0 * val),
        self._state, state, flatten=True)
    # NOTE(review): this zeroes prev_action for ALL agents, not only
    # agent_indices; confirm this is intended when episodes reset partially.
    reset_prev_action = self._prev_action.assign(
        tf.zeros_like(self._prev_action))
    with tf.control_dependencies(reset_state + (reset_prev_action,)):
      return tf.constant('')

  def perform(self, agent_indices, observ):
    """Filter the observation, plan an action, and remember the new state.

    Args:
      agent_indices: 1-D tensor of batch indices the observation belongs to.
      observ: Batch of raw observations to act on.

    Returns:
      Tuple of the chosen action tensor and an (empty) summary string.
    """
    observ = self._config.preprocess_fn(observ)
    # Encoder expects a time axis; add it for one step and strip afterwards.
    embedded = self._config.encoder({'image': observ[:, None]})[:, 0]
    state = nested.map(
        lambda tensor: tf.gather(tensor, agent_indices),
        self._state)
    prev_action = self._prev_action + 0
    with tf.control_dependencies([prev_action]):
      use_obs = tf.ones(tf.shape(agent_indices), tf.bool)[:, None]
      _, state = self._cell((embedded, prev_action, use_obs), state)
    action = self._config.planner(
        self._cell, self._config.objective, state,
        embedded.shape[1:].as_list(),
        prev_action.shape[1:].as_list())
    # Planner returns a sequence; execute only its first action (MPC).
    action = action[:, 0]
    if self._config.exploration:
      scale = self._config.exploration.scale
      if self._config.exploration.schedule:
        scale *= self._config.exploration.schedule(self._step)
      # Gaussian exploration noise, clipped back into the valid action range.
      action = tfd.Normal(action, scale).sample()
    action = tf.clip_by_value(action, -1, 1)
    remember_action = self._prev_action.assign(action)
    remember_state = nested.map(
        lambda var, val: tf.scatter_update(var, agent_indices, val),
        self._state, state, flatten=True)
    with tf.control_dependencies(remember_state + (remember_action,)):
      return tf.identity(action), tf.constant('')

  def experience(self, agent_indices, *experience):
    """Model-based agent does not learn from experience here; no-op."""
    return tf.constant('')

  def end_episode(self, agent_indices):
    """No per-episode cleanup needed; no-op."""
    return tf.constant('')
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from planet import tools


def cross_entropy_method(
    cell, objective_fn, state, obs_shape, action_shape, horizon, graph,
    amount=1000, topk=100, iterations=10, min_action=-1, max_action=1):
  """Plan an action sequence with the cross-entropy method (CEM).

  Iteratively samples `amount` action sequences from a Gaussian belief,
  rolls them out through `cell` without observations, and re-fits the
  belief to the `topk` sequences with the highest objective.

  Args:
    cell: Latent dynamics RNN cell to simulate rollouts with.
    objective_fn: Maps a final rollout state to a scalar return per sequence.
    state: Nested initial latent state, leading batch dimension.
    obs_shape: Shape of one observation embedding (without batch/time).
    action_shape: Shape of one action (without batch/time).
    horizon: Number of future steps to plan over.
    graph: Unused here; kept for planner-interface compatibility.
    amount: Number of candidate sequences per batch element.
    topk: Number of elites the belief is re-fit to.
    iterations: Number of CEM refinement iterations.
    min_action: Lower bound for clipping sampled actions.
    max_action: Upper bound for clipping sampled actions.

  Returns:
    Mean action sequence of shape (batch, horizon) + action_shape.
  """
  obs_shape, action_shape = tuple(obs_shape), tuple(action_shape)
  original_batch = tools.shape(tools.nested.flatten(state)[0])[0]
  # Tile the state so every candidate sequence starts from its batch element.
  initial_state = tools.nested.map(lambda tensor: tf.tile(
      tensor, [amount] + [1] * (tensor.shape.ndims - 1)), state)
  extended_batch = tools.shape(tools.nested.flatten(initial_state)[0])[0]
  # Rollouts are open-loop: observations are disabled for every step.
  use_obs = tf.zeros([extended_batch, horizon, 1], tf.bool)
  obs = tf.zeros((extended_batch, horizon) + obs_shape)

  def iteration(mean_and_stddev, _):
    mean, stddev = mean_and_stddev
    # Sample action proposals from belief.
    normal = tf.random_normal((original_batch, amount, horizon) + action_shape)
    action = normal * stddev[:, None] + mean[:, None]
    action = tf.clip_by_value(action, min_action, max_action)
    # Evaluate proposal actions.
    action = tf.reshape(
        action, (extended_batch, horizon) + action_shape)
    (_, state), _ = tf.nn.dynamic_rnn(
        cell, (0 * obs, action, use_obs), initial_state=initial_state)
    return_ = objective_fn(state)
    return_ = tf.reshape(return_, (original_batch, amount))
    # Re-fit belief to the best ones.
    _, indices = tf.nn.top_k(return_, topk, sorted=False)
    # Offset per-batch indices into the flattened (extended_batch) axis.
    indices += tf.range(original_batch)[:, None] * amount
    best_actions = tf.gather(action, indices)
    mean, variance = tf.nn.moments(best_actions, 1)
    # Epsilon keeps stddev strictly positive for the next sampling round.
    stddev = tf.sqrt(variance + 1e-6)
    return mean, stddev

  mean = tf.zeros((original_batch, horizon) + action_shape)
  stddev = tf.ones((original_batch, horizon) + action_shape)
  if iterations < 1:
    return mean
  mean, stddev = tf.scan(
      iteration, tf.range(iterations), (mean, stddev), back_prop=False)
  mean, stddev = mean[-1], stddev[-1]  # Select belief at last iterations.
  return mean
14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | from planet.control import wrappers 20 | 21 | 22 | def random_episodes(env_ctor, num_episodes, outdir=None): 23 | env = env_ctor() 24 | env = wrappers.CollectGymDataset(env, outdir) 25 | episodes = [] if outdir else None 26 | for _ in range(num_episodes): 27 | policy = lambda env, obs: env.action_space.sample() 28 | done = False 29 | obs = env.reset() 30 | while not done: 31 | action = policy(env, obs) 32 | obs, _, done, info = env.step(action) 33 | if outdir is None: 34 | episodes.append(info['episode']) 35 | try: 36 | env.close() 37 | except AttributeError: 38 | pass 39 | return episodes 40 | -------------------------------------------------------------------------------- /planet/control/simulate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The PlaNet Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
"""In-graph simulation step of a vectorized algorithm with environments."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

import tensorflow as tf

from planet import tools
from planet.control import batch_env
from planet.control import in_graph_batch_env
from planet.control import mpc_agent
from planet.control import wrappers
from planet.tools import streaming_mean


def simulate(
    step, env_ctor, duration, num_agents, agent_config,
    isolate_envs='none', expensive_summaries=False,
    gif_summary=True, name='simulate'):
  """Collect rollouts with an MPC agent and build summary ops for them.

  Returns:
    Tuple of merged summary tensor, mean return tensor, and a cleanup
    callable that closes the environments.
  """
  summaries = []
  with tf.variable_scope(name):
    return_, image, action, reward, cleanup = collect_rollouts(
        step=step,
        env_ctor=env_ctor,
        duration=duration,
        num_agents=num_agents,
        agent_config=agent_config,
        isolate_envs=isolate_envs)
    return_mean = tf.reduce_mean(return_)
    summaries.append(tf.summary.scalar('return', return_mean))
    if expensive_summaries:
      summaries.append(tf.summary.histogram('return_hist', return_))
      summaries.append(tf.summary.histogram('reward_hist', reward))
      summaries.append(tf.summary.histogram('action_hist', action))
      summaries.append(tools.image_strip_summary(
          'image', image, max_length=duration))
    if gif_summary:
      summaries.append(tools.gif_summary(
          'animation', image, max_outputs=1, fps=20))
  summary = tf.summary.merge(summaries)
  return summary, return_mean, cleanup


def collect_rollouts(
    step, env_ctor, duration, num_agents, agent_config, isolate_envs):
  """Unroll the agent in the batch environment for `duration` steps."""
  batch_env = define_batch_env(env_ctor, num_agents, isolate_envs)
  agent = mpc_agent.MPCAgent(batch_env, step, False, False, agent_config)
  cleanup = lambda: batch_env.close()

  def simulate_fn(unused_last, step):
    # Reset all environments on the first iteration of the scan.
    done, score, unused_summary = simulate_step(
        batch_env, agent,
        log=False,
        reset=tf.equal(step, 0))
    with tf.control_dependencies([done, score]):
      image = batch_env.observ
      batch_action = batch_env.action
      batch_reward = batch_env.reward
    return done, score, image, batch_action, batch_reward

  initializer = (
      tf.zeros([num_agents], tf.bool),
      tf.zeros([num_agents], tf.float32),
      0 * batch_env.observ,
      0 * batch_env.action,
      tf.zeros([num_agents], tf.float32))
  done, score, image, action, reward = tf.scan(
      simulate_fn, tf.range(duration),
      initializer, parallel_iterations=1)
  # Keep only the scores of time steps where an episode finished.
  score = tf.boolean_mask(score, done)
  # Transpose scan outputs from time-major to batch-major.
  image = tf.transpose(image, [1, 0, 2, 3, 4])
  action = tf.transpose(action, [1, 0, 2])
  reward = tf.transpose(reward)
  return score, image, action, reward, cleanup


def define_batch_env(env_ctor, num_agents, isolate_envs):
  """Create `num_agents` environments, optionally isolated per thread/process."""
  with tf.variable_scope('environments'):
    if isolate_envs == 'none':
      factory = lambda ctor: ctor()
      blocking = True
    elif isolate_envs == 'thread':
      factory = functools.partial(wrappers.Async, strategy='thread')
      blocking = False
    elif isolate_envs == 'process':
      factory = functools.partial(wrappers.Async, strategy='process')
      blocking = False
    else:
      raise NotImplementedError(isolate_envs)
    envs = [factory(env_ctor) for _ in range(num_agents)]
    env = batch_env.BatchEnv(envs, blocking)
    env = in_graph_batch_env.InGraphBatchEnv(env)
  return env


def simulate_step(batch_env, algo, log=True, reset=False):
  """Simulation step of a vectorized algorithm with in-graph environments.

  Integrates the operations implemented by the algorithm and the environments
  into a combined operation.

  Args:
    batch_env: In-graph batch environment.
    algo: Algorithm instance implementing required operations.
    log: Tensor indicating whether to compute and return summaries.
    reset: Tensor causing all environments to reset.

  Returns:
    Tuple of tensors containing done flags for the current episodes, possibly
    intermediate scores for the episodes, and a summary tensor.
  """

  def _define_begin_episode(agent_indices):
    """Reset environments, intermediate scores and durations for new episodes.

    Args:
      agent_indices: Tensor containing batch indices starting an episode.

    Returns:
      Summary tensor, new score tensor, and new length tensor.
    """
    assert agent_indices.shape.ndims == 1
    zero_scores = tf.zeros_like(agent_indices, tf.float32)
    zero_durations = tf.zeros_like(agent_indices)
    update_score = tf.scatter_update(score_var, agent_indices, zero_scores)
    update_length = tf.scatter_update(
        length_var, agent_indices, zero_durations)
    reset_ops = [
        batch_env.reset(agent_indices), update_score, update_length]
    with tf.control_dependencies(reset_ops):
      return algo.begin_episode(agent_indices), update_score, update_length

  def _define_step():
    """Request actions from the algorithm and apply them to the environments.

    Increments the lengths of all episodes and increases their scores by the
    current reward. After stepping the environments, provides the full
    transition tuple to the algorithm.

    Returns:
      Summary tensor, new score tensor, and new length tensor.
    """
    prevob = batch_env.observ + 0  # Ensure a copy of the variable value.
    agent_indices = tf.range(len(batch_env))
    action, step_summary = algo.perform(agent_indices, prevob)
    action.set_shape(batch_env.action.shape)
    with tf.control_dependencies([batch_env.step(action)]):
      add_score = score_var.assign_add(batch_env.reward)
      inc_length = length_var.assign_add(tf.ones(len(batch_env), tf.int32))
    with tf.control_dependencies([add_score, inc_length]):
      agent_indices = tf.range(len(batch_env))
      experience_summary = algo.experience(
          agent_indices, prevob,
          batch_env.action,
          batch_env.reward,
          batch_env.done,
          batch_env.observ)
      summary = tf.summary.merge([step_summary, experience_summary])
    return summary, add_score, inc_length

  def _define_end_episode(agent_indices):
    """Notify the algorithm of ending episodes.

    Also updates the mean score and length counters used for summaries.

    Args:
      agent_indices: Tensor holding batch indices that end their episodes.

    Returns:
      Summary tensor.
    """
    assert agent_indices.shape.ndims == 1
    submit_score = mean_score.submit(tf.gather(score, agent_indices))
    submit_length = mean_length.submit(
        tf.cast(tf.gather(length, agent_indices), tf.float32))
    with tf.control_dependencies([submit_score, submit_length]):
      return algo.end_episode(agent_indices)

  def _define_summaries():
    """Reset the average score and duration, and return them as summary.

    Returns:
      Summary string.
    """
    # Only emit scalars when logging is on and at least one value was
    # submitted; `str` yields an empty summary string otherwise.
    score_summary = tf.cond(
        tf.logical_and(log, tf.cast(mean_score.count, tf.bool)),
        lambda: tf.summary.scalar('mean_score', mean_score.clear()), str)
    length_summary = tf.cond(
        tf.logical_and(log, tf.cast(mean_length.count, tf.bool)),
        lambda: tf.summary.scalar('mean_length', mean_length.clear()), str)
    return tf.summary.merge([score_summary, length_summary])

  with tf.name_scope('simulate'):
    log = tf.convert_to_tensor(log)
    reset = tf.convert_to_tensor(reset)
    with tf.variable_scope('simulate_temporary'):
      # Running episode score and length per environment; local variables so
      # they are not checkpointed.
      score_var = tf.get_variable(
          'score', (len(batch_env),), tf.float32,
          tf.constant_initializer(0),
          trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
      length_var = tf.get_variable(
          'length', (len(batch_env),), tf.int32,
          tf.constant_initializer(0),
          trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
    mean_score = streaming_mean.StreamingMean((), tf.float32, 'mean_score')
    mean_length = streaming_mean.StreamingMean((), tf.float32, 'mean_length')
    # Begin episodes for all environments on reset, otherwise for done ones.
    agent_indices = tf.cond(
        reset,
        lambda: tf.range(len(batch_env)),
        lambda: tf.cast(tf.where(batch_env.done)[:, 0], tf.int32))
    begin_episode, score, length = tf.cond(
        tf.cast(tf.shape(agent_indices)[0], tf.bool),
        lambda: _define_begin_episode(agent_indices),
        lambda: (str(), score_var, length_var))
    with tf.control_dependencies([begin_episode]):
      step, score, length = _define_step()
    with tf.control_dependencies([step]):
      agent_indices = tf.cast(tf.where(batch_env.done)[:, 0], tf.int32)
      end_episode = tf.cond(
          tf.cast(tf.shape(agent_indices)[0], tf.bool),
          lambda: _define_end_episode(agent_indices), str)
    with tf.control_dependencies([end_episode]):
      summary = tf.summary.merge([
          _define_summaries(), begin_episode, step, end_episode])
    with tf.control_dependencies([summary]):
      # Adding zero forces a fresh tensor read after all dependencies ran.
      score = 0.0 + score
      done = batch_env.done
    return done, score, summary
"""Compute discounted return."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


def discounted_return(reward, discount, bootstrap, axis, stop_gradient=True):
  """Discounted Monte Carlo return.

  Args:
    reward: Tensor of rewards; `axis` is the time dimension to aggregate.
    discount: Scalar discount factor.
    bootstrap: Value estimate appended after the last step, or None.
    axis: Index of the time dimension in `reward`.
    stop_gradient: Whether to block gradients through the result.

  Returns:
    Tensor of discounted returns with the same shape as `reward`.
  """
  if discount == 1 and bootstrap is None:
    return tf.reduce_sum(reward, axis)
  if discount == 1:
    return tf.reduce_sum(reward, axis) + bootstrap
  # Bring the aggregation dimension front.
  dims = list(range(reward.shape.ndims))
  dims = [axis] + dims[1:axis] + [0] + dims[axis + 1:]
  reward = tf.transpose(reward, dims)
  if bootstrap is None:
    bootstrap = tf.zeros_like(reward[-1])
  # Reverse-time scan accumulates reward + discount * future return.
  return_ = tf.scan(
      fn=lambda agg, cur: cur + discount * agg,
      elems=reward,
      initializer=bootstrap,
      back_prop=not stop_gradient,
      reverse=True)
  # The permutation only swaps positions 0 and `axis`, so it is its own
  # inverse and can be reused to transpose back.
  return_ = tf.transpose(return_, dims)
  if stop_gradient:
    return_ = tf.stop_gradient(return_)
  return return_


def lambda_return(
    reward, value, bootstrap, discount, lambda_, axis, stop_gradient=True):
  """Average of different multi-step returns.

  Setting lambda=1 gives a discounted Monte Carlo return.
  Setting lambda=0 gives a fixed 1-step return.
  """
  assert reward.shape.ndims == value.shape.ndims, (reward.shape, value.shape)
  # Bring the aggregation dimension front.
  dims = list(range(reward.shape.ndims))
  dims = [axis] + dims[1:axis] + [0] + dims[axis + 1:]
  reward = tf.transpose(reward, dims)
  value = tf.transpose(value, dims)
  if bootstrap is None:
    bootstrap = tf.zeros_like(value[-1])
  # Shift values one step forward; bootstrap stands in for the value after
  # the final step.
  next_values = tf.concat([value[1:], bootstrap[None]], 0)
  inputs = reward + discount * next_values * (1 - lambda_)
  return_ = tf.scan(
      fn=lambda agg, cur: cur + discount * lambda_ * agg,
      elems=inputs,
      initializer=bootstrap,
      back_prop=not stop_gradient,
      reverse=True)
  return_ = tf.transpose(return_, dims)
  if stop_gradient:
    return_ = tf.stop_gradient(return_)
  return return_


def fixed_step_return(
    reward, value, discount, steps, axis, stop_gradient=True):
  """Discounted N-step returns for fixed-length sequences."""
  # Brings the aggregation dimension front.
  dims = list(range(reward.shape.ndims))
  dims = [axis] + dims[1:axis] + [0] + dims[axis + 1:]
  reward = tf.transpose(reward, dims)
  length = tf.shape(reward)[0]
  # Accumulate `steps` discounted reward slices; the result is shorter than
  # the input by `steps` along the time axis.
  _, return_ = tf.while_loop(
      cond=lambda i, p: i < steps + 1,
      body=lambda i, p: (i + 1, reward[steps - i: length - i] + discount * p),
      loop_vars=[tf.constant(1), tf.zeros_like(reward[steps:])],
      back_prop=not stop_gradient)
  if value is not None:
    # Bootstrap with the value estimate `steps` ahead.
    return_ += discount ** steps * tf.transpose(value, dims)[steps:]
  return_ = tf.transpose(return_, dims)
  if stop_gradient:
    return_ = tf.stop_gradient(return_)
  return return_
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from planet import tools


class Base(tf.nn.rnn_cell.RNNCell):
  """Base class for latent state-space model cells.

  Subclasses provide templates for the transition (prior) and posterior
  networks; one step of the cell returns both distributions and carries the
  posterior forward as the recurrent state.
  """

  def __init__(self, transition_tpl, posterior_tpl, reuse=None):
    super(Base, self).__init__(_reuse=reuse)
    self._posterior_tpl = posterior_tpl
    self._transition_tpl = transition_tpl
    # When True, call() asserts that use_obs is constant across the batch.
    self._debug = False

  @property
  def state_size(self):
    raise NotImplementedError

  @property
  def updates(self):
    # Subclasses may return update ops to run alongside training.
    return []

  @property
  def losses(self):
    # Subclasses may return auxiliary loss tensors.
    return []

  @property
  def output_size(self):
    # One step outputs a (prior, posterior) pair of states.
    return (self.state_size, self.state_size)

  def zero_state(self, batch_size, dtype):
    return tools.nested.map(
        lambda size: tf.zeros([batch_size, size], dtype),
        self.state_size)

  def call(self, inputs, prev_state):
    """Compute prior and posterior states for one time step.

    Args:
      inputs: Tuple of (obs, prev_action, use_obs) where use_obs selects
          whether the posterior conditions on the observation.
      prev_state: Nested previous posterior state.

    Returns:
      Tuple of ((prior, posterior), posterior) as (output, next state).
    """
    obs, prev_action, use_obs = inputs
    if self._debug:
      with tf.control_dependencies([tf.assert_equal(use_obs, use_obs[0, 0])]):
        use_obs = tf.identity(use_obs)
    # use_obs is batch-constant by contract; reduce it to a scalar predicate.
    use_obs = use_obs[0, 0]
    zero_obs = tools.nested.map(tf.zeros_like, obs)
    # The prior ignores the observation by receiving zeros in its place.
    prior = self._transition_tpl(prev_state, prev_action, zero_obs)
    posterior = tf.cond(
        use_obs,
        lambda: self._posterior_tpl(prev_state, prev_action, obs),
        lambda: prior)
    return (prior, posterior), posterior
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow_probability import distributions as tfd

from planet import tools
from planet.models import base


class DRNN(base.Base):
  r"""Doubly recurrent state-space model.

  Prior:    Posterior:

  (a)  (a)  (a,o)  (a,o)
   |    |     :      :
   v    v     v      v
  [e]--->[e] [e]...>[e]
   |    |     :      :
   v    v     v      v
  (s)--->(s) (s)--->(s)
   |    |     |      |
   v    v     v      v
  [d]--->[d] [d]--->[d]
   |    |     |      |
   v    v     v      v
  (o)  (o)   (o)    (o)
  """

  def __init__(
      self, state_size, belief_size, embed_size,
      mean_only=False, min_stddev=1e-1, activation=tf.nn.elu,
      encoder_to_decoder=False, sample_to_sample=True,
      sample_to_encoder=True, decoder_to_encoder=False,
      decoder_to_sample=True, action_to_decoder=False):
    self._state_size = state_size
    self._belief_size = belief_size
    self._embed_size = embed_size
    self._encoder_cell = tf.contrib.rnn.GRUBlockCell(self._belief_size)
    self._decoder_cell = tf.contrib.rnn.GRUBlockCell(self._belief_size)
    # NOTE(review): the `activation` argument is not used; the dense layers
    # below hard-code tf.nn.relu. Confirm whether this is intentional.
    self._kwargs = dict(units=self._embed_size, activation=tf.nn.relu)
    self._mean_only = mean_only
    self._min_stddev = min_stddev
    # Flags selecting which recurrent connections are wired up.
    self._encoder_to_decoder = encoder_to_decoder
    self._sample_to_sample = sample_to_sample
    self._sample_to_encoder = sample_to_encoder
    self._decoder_to_encoder = decoder_to_encoder
    self._decoder_to_sample = decoder_to_sample
    self._action_to_decoder = action_to_decoder
    posterior_tpl = tf.make_template('posterior', self._posterior)
    # The same template serves as both transition and posterior; the base
    # class feeds zeroed observations to it when computing the prior.
    super(DRNN, self).__init__(posterior_tpl, posterior_tpl)

  @property
  def state_size(self):
    return {
        'encoder_state': self._encoder_cell.state_size,
        'decoder_state': self._decoder_cell.state_size,
        'mean': self._state_size,
        'stddev': self._state_size,
        'sample': self._state_size,
    }

  def dist_from_state(self, state, mask=None):
    """Extract the latent distribution from a prior or posterior state."""
    if mask is not None:
      # Masked-out steps get unit stddev so their KL terms are well defined.
      stddev = tools.mask(state['stddev'], mask, value=1)
    else:
      stddev = state['stddev']
    dist = tfd.MultivariateNormalDiag(state['mean'], stddev)
    return dist

  def features_from_state(self, state):
    """Extract features for the decoder network from a prior or posterior."""
    return state['decoder_state']

  def divergence_from_states(self, lhs, rhs, mask=None):
    """Compute the divergence measure between two states."""
    lhs = self.dist_from_state(lhs, mask)
    rhs = self.dist_from_state(rhs, mask)
    divergence = tfd.kl_divergence(lhs, rhs)
    if mask is not None:
      divergence = tools.mask(divergence, mask)
    return divergence

  def _posterior(self, prev_state, prev_action, obs):
    """Compute posterior state from previous state and current observation."""

    # Recurrent encoder.
    encoder_inputs = [obs, prev_action]
    if self._sample_to_encoder:
      encoder_inputs.append(prev_state['sample'])
    if self._decoder_to_encoder:
      encoder_inputs.append(prev_state['decoder_state'])
    encoded, encoder_state = self._encoder_cell(
        tf.concat(encoder_inputs, -1), prev_state['encoder_state'])

    # Sample sequence.
    sample_inputs = [encoded]
    if self._sample_to_sample:
      sample_inputs.append(prev_state['sample'])
    if self._decoder_to_sample:
      sample_inputs.append(prev_state['decoder_state'])
    hidden = tf.layers.dense(
        tf.concat(sample_inputs, -1), **self._kwargs)
    mean = tf.layers.dense(hidden, self._state_size, None)
    stddev = tf.layers.dense(hidden, self._state_size, tf.nn.softplus)
    # Floor on stddev keeps the distribution away from collapse.
    stddev += self._min_stddev
    if self._mean_only:
      sample = mean
    else:
      sample = tfd.MultivariateNormalDiag(mean, stddev).sample()

    # Recurrent decoder.
    decoder_inputs = [sample]
    if self._encoder_to_decoder:
      decoder_inputs.append(prev_state['encoder_state'])
    if self._action_to_decoder:
      decoder_inputs.append(prev_action)
    decoded, decoder_state = self._decoder_cell(
        tf.concat(decoder_inputs, -1), prev_state['decoder_state'])

    return {
        'encoder_state': encoder_state,
        'decoder_state': decoder_state,
        'mean': mean,
        'stddev': stddev,
        'sample': sample,
    }
class RSSM(base.Base):
  """Deterministic and stochastic state model.

  The stochastic latent is computed from the hidden state at the same time
  step. If an observation is present, the posterior latent is computed from
  both the hidden state and the observation.

  Prior:    Posterior:

  (a)       (a)
     \         \
      v         v
  [h]->[h]  [h]->[h]
    ^ |       ^ :
   /  v      /  v
  (s) (s)  (s) (s)
                ^
                :
               (o)
  """

  def __init__(
      self, state_size, belief_size, embed_size,
      future_rnn=True, mean_only=False, min_stddev=0.1, activation=tf.nn.elu,
      num_layers=1):
    """Configure layer sizes and build the transition/posterior templates."""
    self._state_size = state_size
    self._belief_size = belief_size
    self._embed_size = embed_size
    self._future_rnn = future_rnn
    self._cell = tf.contrib.rnn.GRUBlockCell(self._belief_size)
    self._kwargs = dict(units=self._embed_size, activation=activation)
    self._mean_only = mean_only
    self._min_stddev = min_stddev
    self._num_layers = num_layers
    super(RSSM, self).__init__(
        tf.make_template('transition', self._transition),
        tf.make_template('posterior', self._posterior))

  @property
  def state_size(self):
    """Sizes of the per-step state dictionary entries."""
    return {
        'mean': self._state_size,
        'stddev': self._state_size,
        'sample': self._state_size,
        'belief': self._belief_size,
        'rnn_state': self._belief_size,
    }

  def dist_from_state(self, state, mask=None):
    """Extract the latent distribution from a prior or posterior state."""
    if mask is None:
      stddev = state['stddev']
    else:
      # Masked-out steps get unit stddev so downstream KL terms stay finite.
      stddev = tools.mask(state['stddev'], mask, value=1)
    return tfd.MultivariateNormalDiag(state['mean'], stddev)

  def features_from_state(self, state):
    """Extract features for the decoder network from a prior or posterior."""
    return tf.concat([state['sample'], state['belief']], -1)

  def divergence_from_states(self, lhs, rhs, mask=None):
    """Compute the KL divergence between the two state distributions."""
    divergence = tfd.kl_divergence(
        self.dist_from_state(lhs, mask), self.dist_from_state(rhs, mask))
    return divergence if mask is None else tools.mask(divergence, mask)

  def _transition(self, prev_state, prev_action, zero_obs):
    """Compute prior next state by applying the transition dynamics."""
    hidden = tf.concat([prev_state['sample'], prev_action], -1)
    for _ in range(self._num_layers):
      hidden = tf.layers.dense(hidden, **self._kwargs)
    belief, rnn_state = self._cell(hidden, prev_state['rnn_state'])
    if self._future_rnn:
      # Predict the latent from the updated belief rather than the inputs.
      hidden = belief
    for _ in range(self._num_layers):
      hidden = tf.layers.dense(hidden, **self._kwargs)
    mean = tf.layers.dense(hidden, self._state_size, None)
    stddev = tf.layers.dense(hidden, self._state_size, tf.nn.softplus)
    stddev += self._min_stddev
    sample = mean if self._mean_only else tfd.MultivariateNormalDiag(
        mean, stddev).sample()
    return {
        'mean': mean,
        'stddev': stddev,
        'sample': sample,
        'belief': belief,
        'rnn_state': rnn_state,
    }

  def _posterior(self, prev_state, prev_action, obs):
    """Compute posterior state from previous state and current observation."""
    prior = self._transition_tpl(prev_state, prev_action, tf.zeros_like(obs))
    hidden = tf.concat([prior['belief'], obs], -1)
    for _ in range(self._num_layers):
      hidden = tf.layers.dense(hidden, **self._kwargs)
    mean = tf.layers.dense(hidden, self._state_size, None)
    stddev = tf.layers.dense(hidden, self._state_size, tf.nn.softplus)
    stddev += self._min_stddev
    sample = mean if self._mean_only else tfd.MultivariateNormalDiag(
        mean, stddev).sample()
    # The deterministic path is shared with the prior.
    return {
        'mean': mean,
        'stddev': stddev,
        'sample': sample,
        'belief': prior['belief'],
        'rnn_state': prior['rnn_state'],
    }
All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import tensorflow as tf 20 | from tensorflow_probability import distributions as tfd 21 | 22 | from planet import tools 23 | from planet.models import base 24 | 25 | 26 | class SSM(base.Base): 27 | """Gaussian state space model. 28 | 29 | Implements the transition function and encoder using feed forward networks. 
30 | 31 | Prior: Posterior: 32 | 33 | (a) (a) 34 | \ \ 35 | v v 36 | (s)->(s) (s)->(s) 37 | ^ 38 | : 39 | (o) 40 | """ 41 | 42 | def __init__( 43 | self, state_size, embed_size, 44 | mean_only=False, activation=tf.nn.elu, min_stddev=1e-5): 45 | self._state_size = state_size 46 | self._embed_size = embed_size 47 | self._mean_only = mean_only 48 | self._min_stddev = min_stddev 49 | super(SSM, self).__init__( 50 | tf.make_template('transition', self._transition), 51 | tf.make_template('posterior', self._posterior)) 52 | self._kwargs = dict(units=self._embed_size, activation=activation) 53 | 54 | @property 55 | def state_size(self): 56 | return { 57 | 'mean': self._state_size, 58 | 'stddev': self._state_size, 59 | 'sample': self._state_size, 60 | } 61 | 62 | def dist_from_state(self, state, mask=None): 63 | """Extract the latent distribution from a prior or posterior state.""" 64 | if mask is not None: 65 | stddev = tools.mask(state['stddev'], mask, value=1) 66 | else: 67 | stddev = state['stddev'] 68 | dist = tfd.MultivariateNormalDiag(state['mean'], stddev) 69 | return dist 70 | 71 | def features_from_state(self, state): 72 | """Extract features for the decoder network from a prior or posterior.""" 73 | return state['sample'] 74 | 75 | def divergence_from_states(self, lhs, rhs, mask=None): 76 | """Compute the divergence measure between two states.""" 77 | lhs = self.dist_from_state(lhs, mask) 78 | rhs = self.dist_from_state(rhs, mask) 79 | divergence = tfd.kl_divergence(lhs, rhs) 80 | if mask is not None: 81 | divergence = tools.mask(divergence, mask) 82 | return divergence 83 | 84 | def _transition(self, prev_state, prev_action, zero_obs): 85 | """Compute prior next state by applying the transition dynamics.""" 86 | inputs = tf.concat([prev_state['sample'], prev_action], -1) 87 | hidden = tf.layers.dense(inputs, **self._kwargs) 88 | mean = tf.layers.dense(hidden, self._state_size, None) 89 | stddev = tf.layers.dense(hidden, self._state_size, tf.nn.softplus) 90 | 
stddev += self._min_stddev 91 | if self._mean_only: 92 | sample = mean 93 | else: 94 | sample = tfd.MultivariateNormalDiag(mean, stddev).sample() 95 | return { 96 | 'mean': mean, 97 | 'stddev': stddev, 98 | 'sample': sample, 99 | } 100 | 101 | def _posterior(self, prev_state, prev_action, obs): 102 | """Compute posterior state from previous state and current observation.""" 103 | prior = self._transition_tpl(prev_state, prev_action, tf.zeros_like(obs)) 104 | inputs = tf.concat([prior['mean'], prior['stddev'], obs], -1) 105 | hidden = tf.layers.dense(inputs, **self._kwargs) 106 | mean = tf.layers.dense(hidden, self._state_size, None) 107 | stddev = tf.layers.dense(hidden, self._state_size, tf.nn.softplus) 108 | stddev += self._min_stddev 109 | if self._mean_only: 110 | sample = mean 111 | else: 112 | sample = tfd.MultivariateNormalDiag(mean, stddev).sample() 113 | return { 114 | 'mean': mean, 115 | 'stddev': stddev, 116 | 'sample': sample, 117 | } 118 | -------------------------------------------------------------------------------- /planet/networks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The PlaNet Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | from . 
# --- planet/networks/basic.py ---
# Copyright 2019 The PlaNet Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd

from planet import tools


def feed_forward(
    state, data_shape, num_layers=2, activation=tf.nn.relu,
    mean_activation=None, stop_gradient=False, trainable=True, units=100,
    std=1.0, low=-1.0, high=1.0, dist='normal'):
  """Create a model returning unnormalized MSE distribution.

  The MLP predicts the distribution mean (and optionally a learned stddev
  when std == 'learned'); `dist` selects the output distribution family.
  Raises NotImplementedError for unknown `dist` values.
  """
  batch_shape = tools.shape(state)[:-1]
  num_elements = int(np.prod(data_shape))
  features = tf.stop_gradient(state) if stop_gradient else state
  for _ in range(num_layers):
    features = tf.layers.dense(features, units, activation)
  mean = tf.layers.dense(
      features, num_elements, mean_activation, trainable=trainable)
  mean = tf.reshape(mean, batch_shape + data_shape)
  if std == 'learned':
    std = tf.layers.dense(features, num_elements, None, trainable=trainable)
    # Shifted softplus keeps the stddev positive and near 1 at init.
    std = tf.nn.softplus(std + 0.55) + 0.01
    std = tf.reshape(std, batch_shape + data_shape)
  if dist == 'normal':
    output = tfd.Normal(mean, std)
  elif dist == 'truncated_normal':
    # https://www.desmos.com/calculator/3o96eyqxib
    output = tfd.TruncatedNormal(mean, std, low, high)
  elif dist == 'tanh_normal':
    # https://www.desmos.com/calculator/sxpp7ectjv
    output = tfd.TransformedDistribution(
        tfd.Normal(mean, std), tfp.bijectors.Tanh())
  elif dist == 'deterministic':
    output = tfd.Deterministic(mean)
  else:
    raise NotImplementedError(dist)
  return tfd.Independent(output, len(data_shape))
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf
from tensorflow_probability import distributions as tfd

from planet import tools


def encoder(obs):
  """Extract deterministic features from an observation.

  Flattens the (batch, time) dimensions, applies the four-layer conv stack
  from Ha & Schmidhuber's world models, and restores the leading dims.
  """
  kwargs = dict(strides=2, activation=tf.nn.relu)
  hidden = tf.reshape(obs['image'], [-1] + obs['image'].shape[2:].as_list())
  for filters in (32, 64, 128, 256):
    hidden = tf.layers.conv2d(hidden, filters, 4, **kwargs)
  hidden = tf.layers.flatten(hidden)
  # The conv stack over 64x64x3 inputs must yield exactly 1024 features.
  assert hidden.shape[1:].as_list() == [1024], hidden.shape.as_list()
  hidden = tf.reshape(hidden, tools.shape(obs['image'])[:2] + [
      np.prod(hidden.shape[1:].as_list())])
  return hidden


def decoder(state, data_shape):
  """Compute the data distribution of an observation from its state.

  Returns an Independent Normal with unit stddev over `data_shape`.
  """
  kwargs = dict(strides=2, activation=tf.nn.relu)
  hidden = tf.layers.dense(state, 1024, None)
  hidden = tf.reshape(hidden, [-1, 1, 1, hidden.shape[-1].value])
  for filters, kernel in ((128, 5), (64, 5), (32, 6)):
    hidden = tf.layers.conv2d_transpose(hidden, filters, kernel, **kwargs)
  mean = tf.layers.conv2d_transpose(hidden, 3, 6, strides=2)
  # The transposed conv stack must reproduce 64x64 RGB frames.
  assert mean.shape[1:].as_list() == [64, 64, 3], mean.shape
  mean = tf.reshape(mean, tools.shape(state)[:-1] + data_shape)
  return tfd.Independent(tfd.Normal(mean, 1.0), len(data_shape))
Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | from . import configs 20 | from . import tasks 21 | from . import train 22 | -------------------------------------------------------------------------------- /planet/scripts/create_video.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The PlaNet Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import shutil

from matplotlib import animation
import matplotlib as mpl
import matplotlib.pyplot as plt
import scipy.misc


def create_animation(frames, size, fps=10, **kwargs):
  """Build a FuncAnimation that plays `frames` at `fps`.

  Extra keyword arguments are forwarded to `imshow` (e.g. cmap).
  """
  fig = plt.figure(figsize=size, frameon=False)
  ax = fig.add_axes([0, 0, 1, 1])
  img = ax.imshow(frames[0], **kwargs)
  ax.set_xticks([])
  ax.set_yticks([])

  def update(frame):
    img.set_data(frame)
    return (img,)

  anim_kwargs = dict(frames=frames, interval=1000 / fps, blit=True)
  return animation.FuncAnimation(fig, update, **anim_kwargs)


def save_animation(filepath, anim, fps=10, overwrite=False):
  """Render `anim` to an .mp4 (ffmpeg) or .gif (imagemagick) file."""
  if filepath.endswith('.mp4'):
    if not shutil.which('ffmpeg'):
      raise RuntimeError("The 'ffmpeg' executable must be in PATH.")
    mpl.rcParams['animation.ffmpeg_path'] = 'ffmpeg'
    writer_kwargs = dict(
        fps=fps, writer='ffmpeg', extra_args=['-vcodec', 'libx264'])
  elif filepath.endswith('.gif'):
    writer_kwargs = dict(fps=fps, writer='imagemagick')
  else:
    message = "Unknown video format; filename should end in '.mp4' or '.gif'."
    raise NotImplementedError(message)
  if os.path.exists(filepath) and not overwrite:
    message = 'Skip rendering animation because {} already exists.'
    print(message.format(filepath))
    return
  print('Render animation to {}.'.format(filepath))
  anim.save(filepath, **writer_kwargs)
54 | print(message.format(filepath)) 55 | return 56 | print('Render animation to {}.'.format(filepath)) 57 | anim.save(filepath, **kwargs) 58 | 59 | 60 | def unpack_image_strip(image, tile_width, tile_height): 61 | image = image.reshape(( 62 | image.shape[0] // tile_height, 63 | tile_height, 64 | image.shape[1] // tile_width, 65 | tile_width, 66 | image.shape[2], 67 | )) 68 | image = image.transpose((0, 2, 1, 3, 4)) 69 | return image 70 | 71 | 72 | def pack_animation_frames(image): 73 | image = image.transpose((1, 2, 0, 3, 4)) 74 | image = image.reshape(( 75 | image.shape[0], 76 | image.shape[1], 77 | image.shape[2] * image.shape[3], 78 | image.shape[4], 79 | )) 80 | return image 81 | 82 | 83 | def main(args): 84 | image = scipy.misc.imread(args.image_path) 85 | image = unpack_image_strip(image, args.tile_width, args.tile_height) 86 | frames = pack_animation_frames(image) 87 | size = frames.shape[2] / args.dpi, frames.shape[1] / args.dpi 88 | anim = create_animation(frames, size, args.fps) 89 | save_animation(args.animation_path, anim, args.fps, args.overwrite) 90 | 91 | 92 | if __name__ == '__main__': 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument('-i', '--image-path', required=True) 95 | parser.add_argument('-o', '--animation-path', required=True) 96 | parser.add_argument('-f', '--overwrite', action='store_true', default=False) 97 | parser.add_argument('-x', '--tile-width', type=int, default=64) 98 | parser.add_argument('-y', '--tile-height', type=int, default=64) 99 | parser.add_argument('-r', '--fps', type=int, default=10) 100 | parser.add_argument('-d', '--dpi', type=float, default=50) 101 | args = parser.parse_args() 102 | main(args) 103 | -------------------------------------------------------------------------------- /planet/scripts/fetch_events.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The PlaNet Authors. All rights reserved. 
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import csv
import fnmatch
import functools
import multiprocessing.dummy as multiprocessing
import os
import re
import sys
import traceback

# import imageio
import numpy as np
import skimage.io
import tensorflow as tf
from tensorboard.backend.event_processing import (
    plugin_event_multiplexer as event_multiplexer)


# Serializes print() across the thread pool so output lines do not interleave.
lock = multiprocessing.Lock()


def safe_print(*args, **kwargs):
  """Print atomically with respect to the other worker threads."""
  with lock:
    print(*args, **kwargs)


def create_reader(logdir):
  """Load all events of a log directory under the run name 'run'."""
  reader = event_multiplexer.EventMultiplexer()
  reader.AddRun(logdir, 'run')
  reader.Reload()
  return reader


def extract_values(reader, tag):
  """Return parallel lists of steps, wall times, and tensor values for a tag."""
  events = reader.Tensors('run', tag)
  steps = [event.step for event in events]
  times = [event.wall_time for event in events]
  values = [tf.make_ndarray(event.tensor_proto) for event in events]
  return steps, times, values


def export_scalar(basename, steps, times, values):
  """Write scalar events to '<basename>.csv' with a header row."""
  safe_print('Writing', basename + '.csv')
  values = [value.item() for value in values]
  with tf.gfile.Open(basename + '.csv', 'w') as outfile:
    writer = csv.writer(outfile)
    writer.writerow(('wall_time', 'step', 'value'))
    for row in zip(times, steps, values):
      writer.writerow(row)


def export_image(basename, steps, times, values):
  """Decode image events and save them as PNG and NPY files."""
  tf.reset_default_graph()
  tf_string = tf.placeholder(tf.string)
  tf_tensor = tf.image.decode_image(tf_string)
  with tf.Session() as sess:
    for step, time_, value in zip(steps, times, values):
      filename = '{}-{}-{}.png'.format(basename, step, time_)
      # Image summaries store (width, height, encoded_string).
      width, height, string = value[0], value[1], value[2]
      del width
      del height
      tensor = sess.run(tf_tensor, {tf_string: string})
      # imageio.imsave(filename, tensor)
      skimage.io.imsave(filename, tensor)
      filename = '{}-{}-{}.npy'.format(basename, step, time_)
      np.save(filename, tensor)


def process_logdir(logdir, args):
  """Export all matching scalar tags of one log directory to CSV files."""
  clean = lambda text: re.sub('[^A-Za-z0-9_]', '_', text)
  basename = os.path.join(args.outdir, clean(logdir))
  if len(tf.gfile.Glob(basename + '*')) > 0 and not args.force:
    safe_print('Exists', logdir)
    return
  try:
    safe_print('Start', logdir)
    reader = create_reader(logdir)
    for tag in reader.Runs()['run']['tensors']:  # tensors -> scalars
      if fnmatch.fnmatch(tag, args.tags):
        steps, times, values = extract_values(reader, tag)
        filename = '{}___{}'.format(basename, clean(tag))
        export_scalar(filename, steps, times, values)
    # for tag in tags['images']:
    #   if fnmatch.fnmatch(tag, args.tags):
    #     steps, times, values = extract_values(reader, tag)
    #     filename = '{}___{}'.format(basename, clean(tag))
    #     export_image(filename, steps, times, values)
    del reader
    safe_print('Done', logdir)
  except Exception:
    safe_print('Exception', logdir)
    # BUG FIX: traceback.print_exc() writes to stderr and returns None, so
    # the original safe_print(traceback.print_exc()) printed 'None' and the
    # traceback bypassed the print lock. format_exc() returns the text.
    safe_print(traceback.format_exc())


def main(args):
  """Fetch events from all matching log directories using a thread pool."""
  logdirs = tf.gfile.Glob(args.logdirs)
  print(len(logdirs), 'logdirs.')
  assert logdirs
  tf.gfile.MakeDirs(args.outdir)
  np.random.shuffle(logdirs)
  pool = multiprocessing.Pool(args.workers)
  worker_fn = functools.partial(process_logdir, args=args)
  pool.map(worker_fn, logdirs)
np.random.shuffle(logdirs) 119 | pool = multiprocessing.Pool(args.workers) 120 | worker_fn = functools.partial(process_logdir, args=args) 121 | pool.map(worker_fn, logdirs) 122 | 123 | 124 | if __name__ == '__main__': 125 | boolean = lambda x: ['False', 'True'].index(x) 126 | parser = argparse.ArgumentParser() 127 | parser.add_argument( 128 | '--logdirs', required=True, 129 | help='glob for log directories to fetch') 130 | parser.add_argument( 131 | '--tags', default='trainer*score', 132 | help='glob for tags to save') 133 | parser.add_argument( 134 | '--outdir', required=True, 135 | help='output directory to store values') 136 | parser.add_argument( 137 | '--force', type=boolean, default=False, 138 | help='overwrite existing files') 139 | parser.add_argument( 140 | '--workers', type=int, default=10, 141 | help='number of worker threads') 142 | args_, remaining = parser.parse_known_args() 143 | args_.logdirs = os.path.expanduser(args_.logdirs) 144 | args_.outdir = os.path.expanduser(args_.outdir) 145 | remaining.insert(0, sys.argv[0]) 146 | tf.app.run(lambda _: main(args_), remaining) 147 | -------------------------------------------------------------------------------- /planet/scripts/objectives.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The PlaNet Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


def reward(state, graph, params):
  """Planning objective: expected reward summed along axis 1 (presumably
  the planning horizon — confirm against the planner)."""
  features = graph.cell.features_from_state(state)
  expected_reward = graph.heads.reward(features).mean()
  return tf.reduce_sum(expected_reward, 1)


# --- planet/scripts/sync.py ---
#!/usr/bin/python3

# Copyright 2019 The PlaNet Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import os
import shutil


def find_source_files(directory):
  """Return all *.py files in `directory` and any subdirectory, deduplicated.

  BUG FIX: the original used `directory + '*.py'`, which breaks when the
  directory has no trailing slash, and `glob.glob(directory + '/**/*.py')`
  without recursive=True, which matches exactly one directory level instead
  of recursing. Use os.path.join and recursive globbing; dedupe because the
  recursive pattern also matches the top level.
  """
  top_level = glob.glob(os.path.join(directory, '*.py'))
  nested = glob.glob(os.path.join(directory, '**', '*.py'), recursive=True)
  return sorted(set(top_level + nested))


def copy_source_tree(source_dir, target_dir):
  """Mirror the Python sources of source_dir into target_dir."""
  for source in find_source_files(source_dir):
    target = os.path.join(target_dir, os.path.relpath(source, source_dir))
    os.makedirs(os.path.dirname(target), exist_ok=True)
    if os.path.exists(target):
      print('Override', os.path.relpath(target, target_dir))
    else:
      print('Add', os.path.relpath(target, target_dir))
    shutil.copy(source, target)
  # Remove files that no longer exist in the source tree.
  for target in find_source_files(target_dir):
    source = os.path.join(source_dir, os.path.relpath(target, target_dir))
    if not os.path.exists(source):
      print('Remove', os.path.relpath(target, target_dir))
      os.remove(target)


def infer_headers(directory):
  """Read the leading comment header lines from the first source file found.

  Raises RuntimeError when the directory contains no Python files.
  """
  try:
    filename = find_source_files(directory)[0]
  except IndexError:
    raise RuntimeError('No code files found in {}.'.format(directory))
  header = []
  with open(filename, 'r') as f:
    for index, line in enumerate(f):
      if index == 0 and not line.startswith('#'):
        break
      # Stop at the first non-comment line with visible content.
      if not line.startswith('#') and line.strip(' \n'):
        break
      header.append(line)
  return header


def add_headers(directory, header):
  """Prepend `header` lines to every source file in the directory."""
  for filename in find_source_files(directory):
    with open(filename, 'r') as f:
      text = f.readlines()
    with open(filename, 'w') as f:
      f.write(''.join(header + text))


def remove_headers(directory, header):
  """Strip `header` lines from the top of every source file that has them."""
  for filename in find_source_files(directory):
    with open(filename, 'r') as f:
      text = f.readlines()
    if text[:len(header)] == header:
      text = text[len(header):]
    with open(filename, 'w') as f:
      f.write(''.join(text))


def main(args):
  """Synchronize two source trees and reconcile their license headers."""
  print('Inferring headers.\n')
  source_header = infer_headers(args.source)
  print('{} Source header {}\n{}{}\n'.format(
      '-' * 32, '-' * 32, ''.join(source_header), '-' * 79))
  target_header = infer_headers(args.target)
  print('{} Target header {}\n{}{}\n'.format(
      '-' * 32, '-' * 32, ''.join(target_header), '-' * 79))
  print('Synchronizing directories.')
  copy_source_tree(args.source, args.target)
  if target_header and not source_header:
    print('Adding headers.')
    add_headers(args.target, target_header)
  if source_header and not target_header:
    print('Removing headers.')
    remove_headers(args.target, source_header)
  print('Done.')


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument('--source', type=os.path.expanduser, required=True)
  parser.add_argument('--target', type=os.path.expanduser, required=True)
  main(parser.parse_args())


# --- planet/scripts/tasks.py ---
# Copyright 2019 The PlaNet Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

import numpy as np

from planet import control
from planet import tools


# A task bundles the environment constructor with the episode length cap and
# the names of the flat state components exposed for summaries/objectives.
Task = collections.namedtuple(
    'Task', 'name, env_ctor, max_length, state_components')


def dummy(config, params):
  """Trivial in-repo environment for fast smoke tests (no simulator needed)."""
  action_repeat = params.get('action_repeat', 1)
  # Episode budgets below are in raw simulator steps, so repeating actions
  # shortens the agent-visible episode proportionally.
  max_length = 1000 // action_repeat
  state_components = ['reward']
  env_ctor = lambda: control.wrappers.ActionRepeat(
      control.DummyEnv, action_repeat)
  return Task('dummy', env_ctor, max_length, state_components)


def cartpole_balance(config, params):
  """DeepMind Control Suite cartpole, balance task."""
  action_repeat = params.get('action_repeat', 8)
  max_length = 1000 // action_repeat
  state_components = ['reward', 'position', 'velocity']
  env_ctor = tools.bind(
      _dm_control_env, action_repeat, max_length, 'cartpole', 'balance',
      params)
  return Task('cartpole_balance', env_ctor, max_length, state_components)


def cartpole_swingup(config, params):
  """DeepMind Control Suite cartpole, swingup task."""
  action_repeat = params.get('action_repeat', 8)
  max_length = 1000 // action_repeat
  state_components = ['reward', 'position', 'velocity']
  env_ctor = tools.bind(
      _dm_control_env, action_repeat, max_length, 'cartpole', 'swingup',
      params)
  return Task('cartpole_swingup', env_ctor, max_length, state_components)


def finger_spin(config, params):
  """DeepMind Control Suite finger, spin task."""
  action_repeat = params.get('action_repeat', 2)
  max_length = 1000 // action_repeat
  state_components = ['reward', 'position', 'velocity', 'touch']
  env_ctor = tools.bind(
      _dm_control_env, action_repeat, max_length, 'finger', 'spin', params)
  return Task('finger_spin', env_ctor, max_length, state_components)


def cheetah_run(config, params):
  """DeepMind Control Suite cheetah, run task."""
  action_repeat = params.get('action_repeat', 4)
  max_length = 1000 // action_repeat
  state_components = ['reward', 'position', 'velocity']
  env_ctor = tools.bind(
      _dm_control_env, action_repeat, max_length, 'cheetah', 'run', params)
  return Task('cheetah_run', env_ctor, max_length, state_components)


def cup_catch(config, params):
  """DeepMind Control Suite ball-in-cup, catch task."""
  action_repeat = params.get('action_repeat', 4)
  max_length = 1000 // action_repeat
  state_components = ['reward', 'position', 'velocity']
  env_ctor = tools.bind(
      _dm_control_env, action_repeat, max_length, 'ball_in_cup', 'catch',
      params)
  return Task('cup_catch', env_ctor, max_length, state_components)


def walker_walk(config, params):
  """DeepMind Control Suite walker, walk task."""
  action_repeat = params.get('action_repeat', 2)
  max_length = 1000 // action_repeat
  state_components = ['reward', 'height', 'orientations', 'velocity']
  env_ctor = tools.bind(
      _dm_control_env, action_repeat, max_length, 'walker', 'walk', params)
  return Task('walker_walk', env_ctor, max_length, state_components)


def reacher_easy(config, params):
  """DeepMind Control Suite reacher, easy task."""
  action_repeat = params.get('action_repeat', 4)
  max_length = 1000 // action_repeat
  state_components = ['reward', 'position', 'velocity', 'to_target']
  env_ctor = tools.bind(
      _dm_control_env, action_repeat, max_length, 'reacher', 'easy', params)
  return Task('reacher_easy', env_ctor, max_length, state_components)


def gym_cheetah(config, params):
  """OpenAI Gym HalfCheetah with low-dimensional state observations."""
  # Works with `isolate_envs: process`.
  action_repeat = params.get('action_repeat', 1)
  max_length = 1000 // action_repeat
  state_components = ['reward', 'state']
  env_ctor = tools.bind(
      _gym_env, action_repeat, config.batch_shape[1], max_length,
      'HalfCheetah-v3')
  return Task('gym_cheetah', env_ctor, max_length, state_components)


def gym_racecar(config, params):
  """OpenAI Gym CarRacing with image observations."""
  # Works with `isolate_envs: thread`.
  action_repeat = params.get('action_repeat', 1)
  max_length = 1000 // action_repeat
  state_components = ['reward']
  env_ctor = tools.bind(
      _gym_env, action_repeat, config.batch_shape[1], max_length,
      'CarRacing-v0', obs_is_image=True)
  # NOTE(review): the Task name 'gym_racing' differs from the function name
  # 'gym_racecar' — confirm whether any configs or logs rely on this string.
  return Task('gym_racing', env_ctor, max_length, state_components)


def _dm_control_env(
    action_repeat, max_length, domain, task, params, normalize=False,
    camera_id=None):
  """Construct a wrapped dm_control environment.

  `domain` may be a suite domain name or a callable returning an environment
  (in which case `task` must be None). Wrapper order matters: action repeat
  and duration limits are applied before pixel rendering.
  """
  if isinstance(domain, str):
    # Lazy import so the module can be loaded without dm_control installed.
    from dm_control import suite
    env = suite.load(domain, task)
  else:
    assert task is None
    env = domain()
  if camera_id is None:
    camera_id = int(params.get('camera_id', 0))
  env = control.wrappers.DeepMindWrapper(env, (64, 64), camera_id=camera_id)
  env = control.wrappers.ActionRepeat(env, action_repeat)
  if normalize:
    env = control.wrappers.NormalizeActions(env)
  env = control.wrappers.MaximumDuration(env, max_length)
  env = control.wrappers.PixelObservations(env, (64, 64), np.uint8, 'image')
  env = control.wrappers.ConvertTo32Bit(env)
  return env


def _gym_env(action_repeat, min_length, max_length, name, obs_is_image=False):
  """Construct a wrapped OpenAI Gym environment.

  Episodes are padded to at least `min_length` and cut at `max_length`.
  When `obs_is_image` is set, the rendered frame is the observation itself;
  otherwise the state observation is kept and a 64x64 render is attached
  under the 'image' key.
  """
  # Lazy import so the module can be loaded without gym installed.
  import gym
  env = gym.make(name)
  env = control.wrappers.ActionRepeat(env, action_repeat)
  env = control.wrappers.NormalizeActions(env)
  env = control.wrappers.MinimumDuration(env, min_length)
  env = control.wrappers.MaximumDuration(env, max_length)
  if obs_is_image:
    env = control.wrappers.ObservationDict(env, 'image')
    env = control.wrappers.ObservationToRender(env)
  else:
    env = control.wrappers.ObservationDict(env, 'state')
    env = control.wrappers.PixelObservations(env, (64, 64), np.uint8, 'image')
  env = control.wrappers.ConvertTo32Bit(env)
  return env
class PlanetTest(tf.test.TestCase):
  """End-to-end smoke tests: a few training steps on tiny debug configs.

  Each test only checks that the training loop runs to completion for a
  given task / environment-isolation combination; no scores are asserted.
  """

  def _run_short_training(self, task, isolate_envs):
    """Run the debug training config for `task` with the given isolation.

    `tf.app.run` raises SystemExit when the supplied main returns, so the
    expected exit is swallowed here.
    """
    args = tools.AttrDict(
        logdir=self.get_temp_dir(),
        num_runs=1,
        config='debug',
        params=tools.AttrDict(
            task=task,
            isolate_envs=isolate_envs,
            max_steps=30),
        ping_every=0,
        resume_runs=False)
    try:
      tf.app.run(lambda _: train.main(args), [sys.argv[0]])
    except SystemExit:
      pass

  def test_dummy_isolate_none(self):
    self._run_short_training('dummy', 'none')

  def test_dummy_isolate_thread(self):
    self._run_short_training('dummy', 'thread')

  def test_dummy_isolate_process(self):
    self._run_short_training('dummy', 'process')

  def test_dm_control_isolate_none(self):
    self._run_short_training('cup_catch', 'none')

  def test_dm_control_isolate_thread(self):
    self._run_short_training('cup_catch', 'thread')

  def test_dm_control_isolate_process(self):
    self._run_short_training('cup_catch', 'process')


if __name__ == '__main__':
  tf.test.main()
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Train a Deep Planning Network agent. 16 | 17 | Full training run: 18 | 19 | python3 -m planet.scripts.train \ 20 | --logdir /path/to/logdir \ 21 | --config default \ 22 | --params '{tasks: [cheetah_run]}' 23 | 24 | For debugging: 25 | 26 | python3 -m planet.scripts.train \ 27 | --logdir /path/to/logdir \ 28 | --resume_runs False \ 29 | --num_runs 1000 \ 30 | --config debug \ 31 | --params '{tasks: [cheetah_run]}' 32 | """ 33 | 34 | from __future__ import absolute_import 35 | from __future__ import division 36 | from __future__ import print_function 37 | 38 | import argparse 39 | import functools 40 | import os 41 | import sys 42 | 43 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname( 44 | os.path.abspath(__file__))))) 45 | 46 | # Need offline backend to render summaries from within tf.py_func. 
def process(logdir, args):
  """Build the config for one run, collect seed data, and train.

  Generator that yields evaluation scores as training progresses so the
  caller (the Experiment loop in main) can consume them incrementally.
  """
  with args.params.unlocked:
    args.params.logdir = logdir
  config = tools.AttrDict()
  with config.unlocked:
    # The selected config function both mutates and returns the config.
    config = getattr(configs, args.config)(config, args.params)
  # Seed the dataset directories with random-policy episodes before training.
  training.utility.collect_initial_episodes(config)
  tf.reset_default_graph()
  dataset = tools.numpy_episodes.numpy_episodes(
      config.train_dir, config.test_dir, config.batch_shape,
      reader=config.data_reader,
      loader=config.data_loader,
      num_chunks=config.num_chunks,
      preprocess_fn=config.preprocess_fn)
  for score in training.utility.train(
      training.define_model, dataset, logdir, config):
    yield score


def main(args):
  """Run one or more training runs under the experiment manager."""
  training.utility.set_up_logging()
  experiment = training.Experiment(
      args.logdir,
      process_fn=functools.partial(process, args=args),
      num_runs=args.num_runs,
      ping_every=args.ping_every,
      resume_runs=args.resume_runs)
  for run in experiment:
    for unused_score in run:
      # Scores are consumed only to drive the process generator forward.
      pass


if __name__ == '__main__':
  # Maps 'False'/'True' to booleans; any other string raises ValueError,
  # which argparse reports as an invalid argument value.
  boolean = lambda x: bool(['False', 'True'].index(x))
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--logdir', required=True)
  parser.add_argument(
      '--num_runs', type=int, default=1)
  parser.add_argument(
      '--config', default='default',
      help='Select a configuration function from scripts/configs.py.')
  parser.add_argument(
      '--params', default='{}',
      help='YAML formatted dictionary to be used by the config.')
  parser.add_argument(
      '--ping_every', type=int, default=0,
      help='Used to prevent conflicts between multiple workers; 0 to disable.')
  parser.add_argument(
      '--resume_runs', type=boolean, default=True,
      help='Whether to resume unfinished runs in the log directory.')
  args_, remaining = parser.parse_known_args()
  # '#' is accepted in place of ',' inside --params.
  # NOTE(review): presumably a convenience for shells/launchers where commas
  # are awkward to quote — confirm before relying on it.
  args_.params = tools.AttrDict(yaml.safe_load(args_.params.replace('#', ',')))
  args_.logdir = args_.logdir and os.path.expanduser(args_.logdir)
  # Forward any unparsed flags to tf.app.run, restoring argv[0].
  remaining.insert(0, sys.argv[0])
  tf.app.run(lambda _: main(args_), remaining)
import unroll 25 | from .attr_dict import AttrDict 26 | from .bind import bind 27 | from .bound_action import bound_action 28 | from .copy_weights import soft_copy_weights 29 | from .count_dataset import count_dataset 30 | from .count_weights import count_weights 31 | from .custom_optimizer import CustomOptimizer 32 | from .filter_variables_lib import filter_variables 33 | from .gif_summary import gif_summary 34 | from .image_strip_summary import image_strip_summary 35 | from .mask import mask 36 | from .overshooting import overshooting 37 | from .reshape_as import reshape_as 38 | from .shape import shape 39 | from .streaming_mean import StreamingMean 40 | from .target_network import track_network 41 | -------------------------------------------------------------------------------- /planet/tools/attr_dict.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The PlaNet Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
class AttrDict(dict):
  """Dictionary whose keys can also be read and written as attributes.

  Extra machinery on top of ``dict``:

  * Locking: a locked instance rejects item/attribute assignment with
    RuntimeError; use ``lock()``/``unlock()`` or the ``unlocked`` context
    manager. Instances created with initial content start locked; empty
    instances start unlocked.
  * Touched tracking: every key read (including via ``get``) is recorded,
    and ``untouched`` lists keys that were never accessed — useful for
    detecting unused config entries.
  * Defaults: a fallback mapping consulted when a key is missing.
  """

  def __init__(self, *args, **kwargs):
    # Instances constructed with content start locked by default; the
    # bookkeeping keyword arguments are stripped before dict.__init__.
    unlocked = kwargs.pop('_unlocked', not (args or kwargs))
    defaults = kwargs.pop('_defaults', {})
    touched = kwargs.pop('_touched', set())
    # Temporarily unlock with empty bookkeeping so the base-class
    # initialization below may insert items, then restore the real state.
    super(AttrDict, self).__setattr__('_unlocked', True)
    super(AttrDict, self).__setattr__('_touched', set())
    super(AttrDict, self).__setattr__('_defaults', {})
    super(AttrDict, self).__init__(*args, **kwargs)
    super(AttrDict, self).__setattr__('_unlocked', unlocked)
    super(AttrDict, self).__setattr__('_defaults', defaults)
    super(AttrDict, self).__setattr__('_touched', touched)

  def __getattr__(self, name):
    try:
      return self[name]
    except KeyError:
      raise AttributeError(name)

  def __setattr__(self, name, value):
    self[name] = value

  def __missing__(self, name):
    # Not used; __getitem__ is fully overridden below.
    raise KeyError(name)

  def __getitem__(self, name):
    # NOTE: deviates from the dict contract by raising AttributeError (not
    # KeyError) for missing keys, so attribute-style access reports cleanly.
    self._touched.add(name)
    if name in self:
      return super(AttrDict, self).__getitem__(name)
    if name in self._defaults:
      return self._defaults[name]
    raise AttributeError(name)

  def __setitem__(self, name, value):
    # Magic attributes would shadow protocol lookups; refuse them.
    if name.startswith('__'):
      raise AttributeError("Cannot set magic attribute '{}'".format(name))
    if not self._unlocked:
      message = 'Use obj.unlock() before setting {}'
      raise RuntimeError(message.format(name))
    super(AttrDict, self).__setitem__(name, value)

  def __repr__(self):
    items = [
        '{}: {}'.format(key, self._format_value(value))
        for key, value in self.items()]
    return '{' + ', '.join(items) + '}'

  def get(self, key, default=None):
    """Like dict.get, but records the key as touched."""
    self._touched.add(key)
    if key not in self:
      return default
    return self[key]

  @property
  def untouched(self):
    """Sorted list of keys that were never read via this instance."""
    return sorted(set(self.keys()) - self._touched)

  @property
  @contextlib.contextmanager
  def unlocked(self):
    """Context manager that unlocks for the duration of the block.

    The lock is restored in a finally clause, so an exception inside the
    block no longer leaves the instance permanently unlocked.
    """
    self.unlock()
    try:
      yield
    finally:
      self.lock()

  def lock(self):
    """Recursively forbid assignment on this dict and nested AttrDicts."""
    super(AttrDict, self).__setattr__('_unlocked', False)
    for value in self.values():
      if isinstance(value, AttrDict):
        value.lock()

  def unlock(self):
    """Recursively allow assignment on this dict and nested AttrDicts."""
    super(AttrDict, self).__setattr__('_unlocked', True)
    for value in self.values():
      if isinstance(value, AttrDict):
        value.unlock()

  def summarize(self):
    """Return one 'key: value' line per entry."""
    items = [
        '{}: {}'.format(key, self._format_value(value))
        for key, value in self.items()]
    return '\n'.join(items)

  def update(self, mapping):
    """Like dict.update, but honors the lock and returns self for chaining."""
    if not self._unlocked:
      raise RuntimeError('Use obj.unlock() before updating')
    super(AttrDict, self).update(mapping)
    return self

  def copy(self, _unlocked=False):
    """Shallow copy; locked by default regardless of this instance's state."""
    return type(self)(super(AttrDict, self).copy(), _unlocked=_unlocked)

  def save(self, filename):
    """Write the mapping as YAML to `filename` (must end in .yaml)."""
    assert str(filename).endswith('.yaml')
    directory = os.path.dirname(str(filename))
    # A bare filename has no directory component; makedirs('') would raise.
    if directory:
      os.makedirs(directory, exist_ok=True)
    with open(filename, 'w') as f:
      yaml.dump(collections.OrderedDict(self), f)

  @classmethod
  def load(cls, filename):
    """Read a YAML file saved by `save` into a new instance.

    NOTE(review): uses yaml.Loader, which can construct arbitrary Python
    objects — only load trusted files.
    """
    assert str(filename).endswith('.yaml')
    with open(filename, 'r') as f:
      return cls(yaml.load(f, Loader=yaml.Loader))

  def _format_value(self, value):
    """Render a value compactly for repr() and summarize()."""
    if isinstance(value, np.ndarray):
      # Summarize arrays by shape/dtype/statistics instead of their content.
      # NOTE(review): the original template string was empty (its angle
      # brackets were likely stripped in transit), which silently rendered
      # arrays as ''; restored with one placeholder per argument.
      template = '<ndarray shape={} dtype={} min={} mean={} max={}>'
      min_ = self._format_value(value.min())
      mean = self._format_value(value.mean())
      max_ = self._format_value(value.max())
      return template.format(value.shape, value.dtype, min_, mean, max_)
    if isinstance(value, float) and 1e-3 < abs(value) < 1e6:
      return '{:.3f}'.format(value)
    if isinstance(value, float):
      return '{:4.1e}'.format(value)
    if hasattr(value, '__name__'):
      return value.__name__
    return str(value)
class bind(object):
  """Freeze leading positional and default keyword arguments of a callable.

  Similar to ``functools.partial``, but with a readable ``repr`` showing
  the wrapped function's name. Keyword arguments supplied at call time
  override the frozen ones.
  """

  def __init__(self, fn, *args, **kwargs):
    self._func = fn
    self._frozen_args = args
    self._frozen_kwargs = kwargs

  def __call__(self, *args, **kwargs):
    merged_kwargs = dict(self._frozen_kwargs)
    merged_kwargs.update(kwargs)
    return self._func(*(self._frozen_args + args), **merged_kwargs)

  def __repr__(self):
    return 'bind({})'.format(self._func.__name__)
def bound_action(action, strategy):
  """Keep an action tensor within [-1, 1] according to `strategy`.

  'none' returns the action unchanged, 'clip' clips with a straight-through
  gradient, and 'tanh' squashes smoothly. Any other strategy raises
  NotImplementedError.
  """
  if strategy == 'none':
    return action
  if strategy == 'clip':
    # Straight-through estimator: the forward pass uses the clipped value,
    # while gradients flow through the unclipped action.
    clipped = tf.stop_gradient(tf.clip_by_value(action, -1.0, 1.0))
    return action - tf.stop_gradient(action) + clipped
  if strategy == 'tanh':
    return tf.tanh(action)
  raise NotImplementedError(strategy)
def _pad_tensor(tensor, length, value):
  """Append `length` frames filled with `value` along the first axis.

  The fill frame is derived from the tensor itself (zeroed first element
  plus `value`), so dtype and trailing shape are preserved automatically.
  Not called by chunk_sequence above; presumably kept for external use —
  TODO(review) confirm.
  """
  repeats = [length] + [1] * (tensor.shape.ndims - 1)
  fill_frame = 0 * tensor[:1] + value
  return tf.concat([tensor, tf.tile(fill_frame, repeats)], 0)
def soft_copy_weights(source_pattern, target_pattern, amount):
  """Return an op that moves target variables toward matching source ones.

  Variables are selected by regex pattern and paired by sorted name order,
  so both patterns must select the same number of variables in matching
  order. With amount == 1 the copy is exact; otherwise the target becomes
  the convex combination (1 - amount) * target + amount * source.
  """
  assert 0 < amount <= 1
  sources = sorted(
      filter_variables_lib.filter_variables(include=source_pattern),
      key=lambda var: var.name)
  targets = sorted(
      filter_variables_lib.filter_variables(include=target_pattern),
      key=lambda var: var.name)
  assert len(sources) == len(targets)
  updates = []
  for source, target in zip(sources, targets):
    # A variable matching both patterns would indicate a bad pattern pair.
    assert source.name != target.name
    if amount == 1.0:
      updates.append(target.assign(source))
    else:
      blended = (1 - amount) * target + amount * source
      updates.append(target.assign(blended))
  return tf.group(*updates)
def count_dataset(directory, key='reward'):
  """Return an int32 op counting the episode files (*.npz) in `directory`.

  The count runs inside tf.py_func, so it is re-evaluated every time the op
  executes and reflects files written after graph construction. `key` is
  currently unused; kept for interface compatibility.
  """
  directory = os.path.expanduser(directory)
  if not tf.gfile.Exists(directory):
    raise ValueError(
        "Data set directory '{}' does not exist.".format(directory))
  pattern = os.path.join(directory, '*.npz')

  def count_files():
    return np.array(len(tf.gfile.Glob(pattern)), dtype=np.int32)

  return tf.py_func(count_files, [], tf.int32)
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re

import numpy as np
import tensorflow as tf


def count_weights(scope=None, exclude=None, graph=None):
  """Count learnable parameters.

  Args:
    scope: Restrict the count to a variable scope.
    exclude: Regex to match variable names to exclude.
    graph: Operate on a graph other than the current default graph.

  Returns:
    Number of learnable parameters as integer.
  """
  if scope:
    # Normalize the scope to end with a slash so prefixes match exactly.
    scope = scope if scope.endswith('/') else scope + '/'
  graph = graph or tf.get_default_graph()
  variables = graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
  if scope:
    variables = [var for var in variables if var.name.startswith(scope)]
  if exclude:
    pattern = re.compile(exclude)
    variables = [var for var in variables if not pattern.match(var.name)]
  total = 0
  for var in variables:
    if not var.shape.is_fully_defined():
      raise ValueError(
          "Trainable variable '{}' has undefined shape '{}'.".format(
              var.name, var.shape))
    total += int(np.prod(var.shape.as_list()))
  return total

# ======================================================================
# planet/tools/custom_optimizer.py
# ======================================================================
# Copyright 2019 The PlaNet Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from planet.tools import filter_variables_lib


class CustomOptimizer(object):
  """Optimizer wrapper adding variable filtering, clipping, and summaries."""

  def __init__(
      self, optimizer_cls, step, log, learning_rate,
      include=None, exclude=None, clipping=None, schedule=None,
      debug=False, name='custom_optimizer'):
    """Create the wrapped optimizer.

    Args:
      optimizer_cls: TensorFlow optimizer class to instantiate.
      step: Global step tensor, passed to the learning rate schedule.
      log: Boolean tensor controlling whether summaries are created.
      learning_rate: Base learning rate, possibly scaled by the schedule.
      include: Regexes of variable names to train; defaults to all.
      exclude: Regexes of variable names to exclude from training.
      clipping: Global gradient norm to clip to; falsy disables clipping.
      schedule: Callable mapping the step to a learning rate multiplier.
      debug: Enable numeric checks and zero-gradient assertions.
      name: Name used for scopes and the underlying optimizer.
    """
    if schedule:
      learning_rate *= schedule(step)
    self._step = step
    self._log = log
    self._learning_rate = learning_rate
    self._variables = filter_variables_lib.filter_variables(include, exclude)
    # float(None) raises a TypeError, so only convert truthy values. Falsy
    # clipping (None or 0) disables clipping in minimize() either way.
    self._clipping = float(clipping) if clipping else None
    self._debug = debug
    self._name = name
    self._optimizer = optimizer_cls(learning_rate, name=name)

  def maybe_minimize(self, condition, loss):
    """Minimize the loss only when the condition holds.

    Args:
      condition: Boolean tensor; when false a no-op runs instead.
      loss: Scalar loss tensor.

    Returns:
      Tuple of summary string tensor and global gradient norm tensor.
    """
    # loss = tf.cond(condition, lambda: loss, float)
    update_op, grad_norm = tf.cond(
        condition,
        lambda: self.minimize(loss),
        lambda: (tf.no_op(), 0.0))
    with tf.control_dependencies([update_op]):
      # `str` yields an empty string tensor for the false branch.
      summary = tf.cond(
          tf.logical_and(condition, self._log),
          lambda: self.summarize(grad_norm), str)
    if self._debug:
      # print_op = tf.print('{}_grad_norm='.format(self._name), grad_norm)
      message = 'Zero gradient norm in {} optimizer.'.format(self._name)
      assertion = lambda: tf.assert_greater(grad_norm, 0.0, message=message)
      assert_op = tf.cond(condition, assertion, tf.no_op)
      with tf.control_dependencies([assert_op]):
        summary = tf.identity(summary)
    return summary, grad_norm

  def minimize(self, loss):
    """Compute, optionally clip, and apply gradients for the loss."""
    with tf.name_scope('optimizer_{}'.format(self._name)):
      if self._debug:
        loss = tf.check_numerics(loss, '{}_loss'.format(self._name))
      gradients, variables = zip(*self._optimizer.compute_gradients(
          loss, self._variables, colocate_gradients_with_ops=True))
      grad_norm = tf.global_norm(gradients)
      if self._clipping:
        gradients, _ = tf.clip_by_global_norm(
            gradients, self._clipping, grad_norm)
      optimize = self._optimizer.apply_gradients(zip(gradients, variables))
      return optimize, grad_norm

  def summarize(self, grad_norm):
    """Create scalar summaries for learning rate and gradient norms."""
    summaries = []
    with tf.name_scope('optimizer_{}'.format(self._name)):
      summaries.append(tf.summary.scalar('learning_rate', self._learning_rate))
      summaries.append(tf.summary.scalar('grad_norm', grad_norm))
      if self._clipping:
        clipped = tf.minimum(grad_norm, self._clipping)
        summaries.append(tf.summary.scalar('clipped_gradient_norm', clipped))
      summary = tf.summary.merge(summaries)
    return summary

# ======================================================================
# planet/tools/filter_variables_lib.py
# ======================================================================
# Copyright 2019 The PlaNet Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re

import tensorflow as tf


def filter_variables(include=None, exclude=None):
  """Return global variables whose names match the include regexes.

  Args:
    include: Regex or list of regexes a name must match; defaults to
        matching every variable.
    exclude: Regex or list of regexes a name must not match.

  Returns:
    List of matching variables.

  Raises:
    RuntimeError: If the graph holds no variables, if an include regex
        matches nothing, or if no variables survive the filtering.
  """
  # Normalize both arguments to lists of compiled regexes.
  include = (r'.*',) if include is None else include
  exclude = () if exclude is None else exclude
  if not isinstance(include, (tuple, list)):
    include = (include,)
  if not isinstance(exclude, (tuple, list)):
    exclude = (exclude,)
  include = [re.compile(regex) for regex in include]
  exclude = [re.compile(regex) for regex in exclude]
  variables = tf.global_variables()
  if not variables:
    raise RuntimeError('There are no variables to filter.')
  # Fail loudly on include patterns that match nothing, since that usually
  # indicates a typo in a scope name.
  for regex in include:
    message = "Regex r'{}' does not match any variables in the graph.\n"
    message += 'All variables:\n'
    message += '\n'.join('- {}'.format(var.name) for var in variables)
    if not any(regex.match(variable.name) for variable in variables):
      raise RuntimeError(message.format(regex.pattern))
  filtered = [
      variable for variable in variables
      if any(regex.match(variable.name) for regex in include) and
      not any(regex.match(variable.name) for regex in exclude)]
  if not filtered:
    message = 'No variables left after filtering.'
    message += '\nIncludes:\n' + '\n'.join(regex.pattern for regex in include)
    message += '\nExcludes:\n' + '\n'.join(regex.pattern for regex in exclude)
    raise RuntimeError(message)
  return filtered

# ======================================================================
# planet/tools/gif_summary.py
# ======================================================================
# Copyright 2019 The PlaNet Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf
from tensorflow.python.ops import summary_op_util


def encode_gif(images, fps):
  """Encodes numpy images into gif string.

  Args:
    images: A 5-D `uint8` `np.array` (or a list of 4-D images) of shape
      `[batch_size, time, height, width, channels]` where `channels` is 1 or 3.
    fps: frames per second of the animation

  Returns:
    The encoded gif string.

  Raises:
    IOError: If the ffmpeg command returns an error.
  """
  from subprocess import Popen, PIPE
  h, w, c = images[0].shape
  cmd = [
      'ffmpeg', '-y',
      '-f', 'rawvideo',
      '-vcodec', 'rawvideo',
      '-r', '%.02f' % fps,
      '-s', '%dx%d' % (w, h),
      '-pix_fmt', {1: 'gray', 3: 'rgb24'}[c],
      '-i', '-',
      '-filter_complex',
      '[0:v]split[x][z];[z]palettegen[y];[x]fifo[x];[x][y]paletteuse',
      '-r', '%.02f' % fps,
      '-f', 'gif',
      '-']
  proc = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE)
  for image in images:
    # tobytes() replaces the deprecated ndarray.tostring(), which has been
    # removed from recent NumPy releases; the produced bytes are identical.
    proc.stdin.write(image.tobytes())
  out, err = proc.communicate()
  if proc.returncode:
    err = '\n'.join([' '.join(cmd), err.decode('utf8')])
    raise IOError(err)
  del proc
  return out


def py_gif_summary(tag, images, max_outputs, fps):
  """Outputs a `Summary` protocol buffer with gif animations.

  Args:
    tag: Name of the summary.
    images: A 5-D `uint8` `np.array` of shape `[batch_size, time, height,
      width, channels]` where `channels` is 1 or 3.
    max_outputs: Max number of batch elements to generate gifs for.
    fps: frames per second of the animation

  Returns:
    The serialized `Summary` protocol buffer.

  Raises:
    ValueError: If `images` is not a 5-D `uint8` array with 1 or 3 channels.
  """
  is_bytes = isinstance(tag, bytes)
  if is_bytes:
    tag = tag.decode("utf-8")
  images = np.asarray(images)
  if images.dtype != np.uint8:
    raise ValueError("Tensor must have dtype uint8 for gif summary.")
  if images.ndim != 5:
    raise ValueError("Tensor must be 5-D for gif summary.")
  batch_size, _, height, width, channels = images.shape
  if channels not in (1, 3):
    raise ValueError("Tensors must have 1 or 3 channels for gif summary.")
  summ = tf.Summary()
  num_outputs = min(batch_size, max_outputs)
  for i in range(num_outputs):
    image_summ = tf.Summary.Image()
    image_summ.height = height
    image_summ.width = width
    image_summ.colorspace = channels  # 1: grayscale, 3: RGB
    try:
      image_summ.encoded_image_string = encode_gif(images[i], fps)
    except (IOError, OSError) as e:
      tf.logging.warning(
          "Unable to encode images to a gif string because either ffmpeg is "
          "not installed or ffmpeg returned an error: %s. Falling back to an "
          "image summary of the first frame in the sequence.", e)
      try:
        from PIL import Image  # pylint: disable=g-import-not-at-top
        import io  # pylint: disable=g-import-not-at-top
        with io.BytesIO() as output:
          Image.fromarray(images[i][0]).save(output, "PNG")
          image_summ.encoded_image_string = output.getvalue()
      except Exception:
        tf.logging.warning(
            "Gif summaries requires ffmpeg or PIL to be installed: %s", e)
        image_summ.encoded_image_string = (
            "".encode('utf-8') if is_bytes else "")
    if num_outputs == 1:
      summ_tag = "{}/gif".format(tag)
    else:
      summ_tag = "{}/gif/{}".format(tag, i)
    summ.value.add(tag=summ_tag, image=image_summ)
  summ_str = summ.SerializeToString()
  return summ_str


def gif_summary(name, tensor, max_outputs, fps, collections=None, family=None):
  """Outputs a `Summary` protocol buffer with gif animations.

  Args:
    name: Name of the summary.
    tensor: A 5-D `uint8` `Tensor` of shape `[batch_size, time, height, width,
      channels]` where `channels` is 1 or 3.
    max_outputs: Max number of batch elements to generate gifs for.
    fps: frames per second of the animation
    collections: Optional list of tf.GraphKeys. The collections to add the
      summary to. Defaults to [tf.GraphKeys.SUMMARIES]
    family: Optional; if provided, used as the prefix of the summary tag name,
      which controls the tab name used for display on Tensorboard.

  Returns:
    A scalar `Tensor` of type `string`. The serialized `Summary` protocol
    buffer.
  """
  tensor = tf.convert_to_tensor(tensor)
  # Float inputs are assumed to be normalized to [0, 1] and scaled to uint8.
  if tensor.dtype in (tf.float32, tf.float64):
    tensor = tf.cast(255.0 * tensor, tf.uint8)
  with summary_op_util.summary_scope(
      name, family, values=[tensor]) as (tag, scope):
    val = tf.py_func(
        py_gif_summary,
        [tag, tensor, max_outputs, fps],
        tf.string,
        stateful=False,
        name=scope)
    summary_op_util.collect(val, collections, [tf.GraphKeys.SUMMARIES])
  return val

# ======================================================================
# planet/tools/image_strip_summary.py
# ======================================================================
# Copyright 2019 The PlaNet Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


def image_strip_summary(name, images, max_length=100, max_batch=10):
  """Create an image summary that places frames of a video tensor side by side.

  Args:
    name: Name tag of the summary.
    images: Tensor with the dimensions batch, time, height, width, channels.
    max_length: Maximum number of frames per sequence to include.
    max_batch: Maximum number of sequences to include.

  Returns:
    Summary string tensor.
  """
  if max_batch:
    images = images[:max_batch]
  if max_length:
    images = images[:, :max_length]
  if images.dtype == tf.uint8:
    images = tf.to_float(images) / 255.0
  length, width = tf.shape(images)[1], tf.shape(images)[3]
  # Read the channel count dynamically instead of assuming RGB, so that
  # grayscale videos (1 channel) are also supported.
  channels = tf.shape(images)[4]
  # Concatenate all frames of a sequence horizontally into one image row.
  images = tf.transpose(images, [0, 2, 1, 3, 4])
  images = tf.reshape(images, [1, -1, length * width, channels])
  images = tf.clip_by_value(images, 0., 1.)
  return tf.summary.image(name, images)

# ======================================================================
# planet/tools/mask.py
# ======================================================================
# Copyright 2019 The PlaNet Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


def mask(tensor, mask=None, length=None, value=0, debug=False):
  """Set padding elements of a batch of sequences to a constant.

  Useful for setting padding elements to zero before summing along the time
  dimension, or for preventing infinite results in padding elements. Either
  mask or length must be provided.

  Args:
    tensor: Tensor of sequences.
    mask: Boolean mask of valid indices.
    length: Batch of sequence lengths.
    value: Value to write into padding elements.
    debug: Test for infinite values; slows down performance.

  Raises:
    KeyError: If both or none of `mask` and `length` are provided.

  Returns:
    Masked sequences.
  """
  if len([x for x in (mask, length) if x is not None]) != 1:
    raise KeyError('Exactly one of mask and length must be provided.')
  with tf.name_scope('mask'):
    if mask is None:
      # Build the mask from the lengths: time step t is valid iff t < length.
      range_ = tf.range(tensor.shape[1].value)
      mask = range_[None, :] < length[:, None]
    # Append singleton axes and tile so the mask covers every trailing
    # dimension of the tensor beyond the batch/time dimensions.
    batch_dims = mask.shape.ndims
    while tensor.shape.ndims > mask.shape.ndims:
      mask = mask[..., None]
    multiples = [1] * batch_dims + tensor.shape[batch_dims:].as_list()
    mask = tf.tile(mask, multiples)
    masked = tf.where(mask, tensor, value * tf.ones_like(tensor))
    if debug:
      masked = tf.check_numerics(masked, 'masked')
    return masked

# ======================================================================
# planet/tools/nested.py
# ======================================================================
# Copyright 2019 The PlaNet Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tools for manipulating nested tuples, lists, and dictionaries."""

from __future__ import absolute_import
from __future__ import division

_builtin_zip = zip
_builtin_map = map
_builtin_filter = filter


def zip_(*structures, **kwargs):
  """Combine corresponding leaves of several nested structures into tuples.

  The nested structures can consist of any combination of lists, tuples,
  and dicts, and all provided structures must share the same nesting. When
  only one structure is given, its leaves are returned unchanged.

  Args:
    *structures: Nested structures.
    flatten: Whether to flatten the resulting structure into a tuple. Keys
        of dictionaries will be discarded.

  Returns:
    Nested structure.
  """
  # Named keyword arguments are not allowed after *args in Python 2.
  flatten = kwargs.pop('flatten', False)
  assert not kwargs, 'zip() got unexpected keyword arguments.'
  combine = lambda *leaves: leaves if len(leaves) > 1 else leaves[0]
  return map_(combine, *structures, flatten=flatten)


def map_(function, *structures, **kwargs):
  """Apply a function to every leaf of one or more nested structures.

  When several structures are given, their nesting must match and the
  function receives one argument per structure for every group of
  corresponding leaves.

  Args:
    function: The function to apply to the leaves of the structure.
    *structures: One or more nested structures.
    flatten: Whether to flatten the resulting structure into a tuple. Keys
        of dictionaries will be discarded.

  Returns:
    Nested structure.
  """
  # Named keyword arguments are not allowed after *args in Python 2.
  flatten = kwargs.pop('flatten', False)
  assert not kwargs, 'map() got unexpected keyword arguments.'

  def recurse(*nodes):
    if not nodes:
      return nodes
    if all(isinstance(node, (tuple, list)) for node in nodes):
      if len({len(node) for node in nodes}) > 1:
        raise ValueError('Cannot merge tuples or lists of different length.')
      mapped = tuple(recurse(*group) for group in _builtin_zip(*nodes))
      if hasattr(nodes[0], '_fields'):  # namedtuple
        return type(nodes[0])(*mapped)
      return type(nodes[0])(mapped)  # tuple, list
    if all(isinstance(node, dict) for node in nodes):
      if len({frozenset(node.keys()) for node in nodes}) > 1:
        raise ValueError('Cannot merge dicts with different keys.')
      mapped = {
          key: recurse(*(node[key] for node in nodes))
          for key in nodes[0]}
      return type(nodes[0])(mapped)
    return function(*nodes)

  result = recurse(*structures)
  if flatten:
    result = flatten_(result)
  return result


def flatten_(structure):
  """Combine all leaves of a nested structure into a tuple.

  The nested structure can consist of any combination of tuples, lists,
  and dicts. Dictionary keys will be discarded but values will be ordered
  by the sorting of the keys.

  Args:
    structure: Nested structure.

  Returns:
    Flat tuple.
  """
  if isinstance(structure, dict):
    ordered = [structure[key] for key in sorted(structure.keys())]
    return sum((flatten_(value) for value in ordered), ())
  if isinstance(structure, (tuple, list)):
    return sum((flatten_(element) for element in structure), ())
  return (structure,)


def filter_(predicate, *structures, **kwargs):
  """Keep only the leaves of a nested structure that satisfy a predicate.

  Containers left empty by the filtering are removed, except inside
  namedtuples, where removed fields become None so the type stays intact.
  When multiple structures are provided their nesting must match and the
  predicate receives one argument per structure.

  Args:
    predicate: The function to determine whether an element should be kept.
        Receives one argument for every structure that is provided.
    *structures: One or more nested structures.
    flatten: Whether to flatten the resulting structure into a tuple. Keys
        of dictionaries will be discarded.

  Returns:
    Nested structure.
  """
  # Named keyword arguments are not allowed after *args in Python 2.
  flatten = kwargs.pop('flatten', False)
  assert not kwargs, 'filter() got unexpected keyword arguments.'

  def recurse(*nodes):
    if not nodes:
      return nodes
    if all(isinstance(node, (tuple, list)) for node in nodes):
      if len({len(node) for node in nodes}) > 1:
        raise ValueError('Cannot merge tuples or lists of different length.')
      # Only group elements when more than one structure was provided.
      if len(nodes) > 1:
        kept = (recurse(*group) for group in _builtin_zip(*nodes))
      else:
        kept = (recurse(element) for element in nodes[0])
      if hasattr(nodes[0], '_fields'):  # namedtuple
        kept = (value if value != () else None for value in kept)
        return type(nodes[0])(*kept)
      kept = (
          value for value in kept
          if not isinstance(value, (tuple, list, dict)) or value)
      return type(nodes[0])(kept)  # tuple, list
    if all(isinstance(node, dict) for node in nodes):
      if len({frozenset(node.keys()) for node in nodes}) > 1:
        raise ValueError('Cannot merge dicts with different keys.')
      if len(nodes) > 1:
        kept = {
            key: recurse(*(node[key] for node in nodes))
            for key in nodes[0]}
      else:
        kept = {key: recurse(value) for key, value in nodes[0].items()}
      kept = {
          key: value for key, value in kept.items()
          if not isinstance(value, (tuple, list, dict)) or value}
      return type(nodes[0])(kept)
    if len(nodes) > 1:
      return nodes if predicate(*nodes) else ()
    return nodes[0] if predicate(nodes[0]) else ()

  result = recurse(*structures)
  if flatten:
    result = flatten_(result)
  return result


zip = zip_
map = map_
flatten = flatten_
filter = filter_

# ======================================================================
# planet/tools/numpy_episodes.py
# ======================================================================
# Copyright 2019 The PlaNet Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Load tensors from a directory of numpy files."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import os
import random

from scipy.ndimage import interpolation
import numpy as np
import tensorflow as tf

from planet.tools import attr_dict
from planet.tools import chunk_sequence


def numpy_episodes(
    train_dir, test_dir, shape, reader=None, loader=None,
    num_chunks=None, preprocess_fn=None):
  """Read sequences stored as compressed Numpy files as a TensorFlow dataset.

  Args:
    train_dir: Directory containing NPZ files of the training dataset.
    test_dir: Directory containing NPZ files of the testing dataset.
    shape: Tuple of batch size and chunk length for the datasets.
    reader: Callable that reads an episode from a NPZ filename.
    loader: Generator that yields episodes.
    num_chunks: Optional number of chunks sampled from each episode.
    preprocess_fn: Optional function applied to the image tensor of every
        batched sequence.

  Returns:
    Structured data from numpy episodes as Tensors.
  """
  reader = reader or episode_reader
  loader = loader or cache_loader
  try:
    dtypes, shapes = _read_spec(reader, train_dir)
  except ZeroDivisionError:
    # NOTE(review): presumably raised by a custom reader/loader when the
    # training directory is still empty; fall back to the test directory.
    dtypes, shapes = _read_spec(reader, test_dir)
  train = tf.data.Dataset.from_generator(
      functools.partial(loader, reader, train_dir, shape[0]),
      dtypes, shapes)
  test = tf.data.Dataset.from_generator(
      functools.partial(loader, reader, test_dir, shape[0]),
      dtypes, shapes)
  chunking = lambda x: tf.data.Dataset.from_tensor_slices(
      chunk_sequence.chunk_sequence(x, shape[1], True, num_chunks))
  def sequence_preprocess_fn(sequence):
    if preprocess_fn:
      sequence['image'] = preprocess_fn(sequence['image'])
    return sequence
  train = train.flat_map(chunking)
  train = train.batch(shape[0], drop_remainder=True)
  train = train.map(sequence_preprocess_fn, 10).prefetch(10)
  test = test.flat_map(chunking)
  test = test.batch(shape[0], drop_remainder=True)
  test = test.map(sequence_preprocess_fn, 10).prefetch(10)
  return attr_dict.AttrDict(train=train, test=test)


def cache_loader(reader, directory, batch_size, every=1000):
  """Yield cached episodes, rescanning the directory every `every` yields.

  A default for `every` is provided so the loader can be called without
  pre-binding the argument, as `numpy_episodes` does; configs may bind a
  different value. TODO: confirm the default against the training configs.
  """
  cache = {}
  while True:
    episodes = _sample(cache.values(), every)
    for episode in _permuted(episodes, every):
      yield episode
    filenames = tf.gfile.Glob(os.path.join(directory, '*.npz'))
    filenames = [filename for filename in filenames if filename not in cache]
    for filename in filenames:
      cache[filename] = reader(filename)


def recent_loader(reader, directory, batch_size, every=1000):
  """Yield episodes, sampling half from recent files and half from older ones.

  See `cache_loader` about the default for `every`.
  """
  recent = {}
  cache = {}
  while True:
    episodes = []
    episodes += _sample(recent.values(), every // 2)
    episodes += _sample(cache.values(), every // 2)
    for episode in _permuted(episodes, every):
      yield episode
    # Demote the previous batch of recent files to the long-term cache.
    cache.update(recent)
    recent = {}
    filenames = tf.gfile.Glob(os.path.join(directory, '*.npz'))
    filenames = [filename for filename in filenames if filename not in cache]
    for filename in filenames:
      recent[filename] = reader(filename)


def reload_loader(reader, directory, batch_size):
  """Yield episodes by re-globbing and re-reading files in random order."""
  directory = os.path.expanduser(directory)
  while True:
    filenames = tf.gfile.Glob(os.path.join(directory, '*.npz'))
    random.shuffle(filenames)
    for filename in filenames:
      yield reader(filename)


def dummy_loader(reader, directory, batch_size):
  """Yield random episodes matching the spec of the files in the directory."""
  # Use a dedicated name for the generator so the `random` module used by
  # the other loaders is not shadowed.
  rng = np.random.RandomState(seed=0)
  dtypes, shapes, length = _read_spec(reader, directory, True, True)
  while True:
    episode = {}
    for key in dtypes:
      dtype, shape = dtypes[key], (length,) + shapes[key][1:]
      if dtype in (np.float32, np.float64):
        episode[key] = rng.uniform(0, 1, shape).astype(dtype)
      elif dtype in (np.int32, np.int64, np.uint8):
        episode[key] = rng.uniform(0, 255, shape).astype(dtype)
      else:
        raise NotImplementedError('Unsupported dtype {}.'.format(dtype))
    yield episode


def episode_reader(filename, resize=None, max_length=None, action_noise=None):
  """Load one episode from an NPZ file and post-process it.

  Args:
    filename: Path of the NPZ file to read.
    resize: Optional spatial zoom factor for the image sequence.
    max_length: Optional number of initial time steps to keep.
    action_noise: Optional stddev of Gaussian noise added to the actions,
        deterministic per filename.

  Returns:
    Dict of numpy arrays, including a 'return' key of cumulative rewards.
  """
  with tf.gfile.Open(filename, 'rb') as file_:
    episode = np.load(file_)
    episode = {key: _convert_type(episode[key]) for key in episode.keys()}
  episode['return'] = np.cumsum(episode['reward'])
  if max_length:
    episode = {key: value[:max_length] for key, value in episode.items()}
  if resize and resize != 1:
    factors = (1, resize, resize, 1)
    episode['image'] = interpolation.zoom(episode['image'], factors)
  if action_noise:
    # Seed from the filename bytes so re-reading the same file produces the
    # same noise. np.frombuffer over the encoded name replaces the
    # deprecated np.fromstring and yields the same byte values.
    seed = np.frombuffer(filename.encode('utf-8'), dtype=np.uint8)
    episode['action'] += np.random.RandomState(seed).normal(
        0, action_noise, episode['action'].shape)
  return episode


def _read_spec(
    reader, directory, return_length=False, numpy_types=False):
  """Read one episode to determine dtypes, shapes, and optionally length."""
  episodes = reload_loader(reader, directory, batch_size=1)
  episode = next(episodes)
  episodes.close()
  dtypes = {key: value.dtype for key, value in episode.items()}
  if not numpy_types:
    dtypes = {key: tf.as_dtype(value) for key, value in dtypes.items()}
  shapes = {key: value.shape for key, value in episode.items()}
  # The leading time dimension varies per episode, so leave it unknown.
  shapes = {key: (None,) + shape[1:] for key, shape in shapes.items()}
  if return_length:
    length = len(episode[list(shapes.keys())[0]])
    return dtypes, shapes, length
  else:
    return dtypes, shapes


def _convert_type(array):
  """Downcast 64-bit arrays to their 32-bit equivalents."""
  if array.dtype == np.float64:
    return array.astype(np.float32)
  if array.dtype == np.int64:
    return array.astype(np.int32)
  return array


def _sample(sequence, amount):
  """Sample up to `amount` elements without replacement."""
  sequence = list(sequence)
  amount = min(amount, len(sequence))
  return random.sample(sequence, amount)


def _permuted(sequence, amount):
  """Yield exactly `amount` elements, repeating permutations as needed."""
  sequence = list(sequence)
  if not sequence:
    return
  index = 0
  while True:
    for element in np.random.permutation(sequence):
      if index >= amount:
        return
      yield element
      index += 1

# ======================================================================
# planet/tools/overshooting.py
# ======================================================================
# Copyright 2019 The PlaNet Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
def overshooting(
    cell, target, embedded, prev_action, length, amount, posterior=None,
    ignore_input=False):
  """Perform open loop rollouts from the posteriors at every step.

  First, we apply the encoder to embed raw inputs and apply the model to obtain
  posterior states for every time step. Then, we perform `amount` long open
  loop rollouts from these posteriors.

  Note that the actions should be those leading to the current time step. So
  under common convention, it contains the last actions while observations are
  the current ones.

  Input:

    target, embedded:
      [A B C D E F] [A B C D E  ]

    prev_action:
      [0 A B C D E] [0 A B C D  ]

    length:
      [6 5]

    amount:
      3

  Output:

    prior, posterior, target:
      [A B C D E F] [A B C D E  ]
      [B C D E F  ] [B C D E    ]
      [C D E F    ] [C D E      ]
      [D E F      ] [D E        ]

    mask:
      [1 1 1 1 1 1] [1 1 1 1 1 0]
      [1 1 1 1 1 0] [1 1 1 1 0 0]
      [1 1 1 1 0 0] [1 1 1 0 0 0]
      [1 1 1 0 0 0] [1 1 0 0 0 0]

  """
  # Closed loop unroll to get posterior states, which are the starting points
  # for open loop unrolls. We don't need the last time step, since we have no
  # targets for unrolls from it.
  if posterior is None:
    use_obs = tf.ones(tf.shape(
        nested.flatten(embedded)[0][:, :, :1])[:3], tf.bool)
    use_obs = tf.cond(
        tf.convert_to_tensor(ignore_input),
        lambda: tf.zeros_like(use_obs, tf.bool),
        lambda: use_obs)
    (_, posterior), _ = tf.nn.dynamic_rnn(
        cell, (embedded, prev_action, use_obs), length, dtype=tf.float32,
        swap_memory=True)

  # Arrange inputs for every iteration in the open loop unroll. Every loop
  # iteration below corresponds to one row in the docstring illustration.
  max_length = shape.shape(nested.flatten(embedded)[0])[1]
  first_output = {
      # 'observ': embedded,
      'prev_action': prev_action,
      'posterior': posterior,
      'target': target,
      'mask': tf.sequence_mask(length, max_length, tf.int32),
  }

  # Each scan step shifts every sequence left by one and zero-pads the end.
  progress_fn = lambda tensor: tf.concat([tensor[:, 1:], 0 * tensor[:, :1]], 1)
  other_outputs = tf.scan(
      lambda past_output, _: nested.map(progress_fn, past_output),
      tf.range(amount), first_output)
  sequences = nested.map(
      lambda lhs, rhs: tf.concat([lhs[None], rhs], 0),
      first_output, other_outputs)

  # Merge batch and time dimensions of steps to compute unrolls from every
  # time step as one batch. The time dimension becomes the number of
  # overshooting distances.
  sequences = nested.map(
      lambda tensor: _merge_dims(tensor, [1, 2]),
      sequences)
  sequences = nested.map(
      lambda tensor: tf.transpose(
          tensor, [1, 0] + list(range(2, tensor.shape.ndims))),
      sequences)
  merged_length = tf.reduce_sum(sequences['mask'], 1)

  # Mask out padding frames; unnecessary if the input is already masked.
  sequences = nested.map(
      lambda tensor: tensor * tf.cast(
          _pad_dims(sequences['mask'], tensor.shape.ndims),
          tensor.dtype),
      sequences)

  # Compute open loop rollouts. Observations are zeroed and use_obs is False,
  # so the cell relies on its transition model only.
  use_obs = tf.zeros(tf.shape(sequences['mask']), tf.bool)[..., None]
  embed_size = nested.flatten(embedded)[0].shape[2].value
  obs = tf.zeros(shape.shape(sequences['mask']) + [embed_size])
  prev_state = nested.map(
      lambda tensor: tf.concat([0 * tensor[:, :1], tensor[:, :-1]], 1),
      posterior)
  prev_state = nested.map(
      lambda tensor: _merge_dims(tensor, [0, 1]), prev_state)
  (priors, _), _ = tf.nn.dynamic_rnn(
      cell, (obs, sequences['prev_action'], use_obs),
      merged_length,
      prev_state)

  # Restore batch dimension.
  target, prior, posterior, mask = nested.map(
      functools.partial(_restore_batch_dim, batch_size=shape.shape(length)[0]),
      (sequences['target'], priors, sequences['posterior'], sequences['mask']))

  mask = tf.cast(mask, tf.bool)
  return target, prior, posterior, mask


def _merge_dims(tensor, dims):
  """Flatten consecutive axes of a tensor trying to preserve static shapes.

  Args:
    tensor: Tensor (or nested structure of tensors) to reshape.
    dims: Consecutive, ascending axis indices to merge into one axis.

  Raises:
    ValueError: If `dims` are not consecutive.
  """
  if isinstance(tensor, (list, tuple, dict)):
    # nested.map expects the function as its first argument, matching the
    # calls above and the tests in test_nested.py.
    return nested.map(lambda x: _merge_dims(x, dims), tensor)
  tensor = tf.convert_to_tensor(tensor)
  # Reject if ANY entry is out of place. The previous `.all()` only raised
  # when every entry mismatched, so e.g. dims=[1, 3] slipped through.
  if (np.array(dims) - min(dims) != np.arange(len(dims))).any():
    raise ValueError('Dimensions to merge must all follow each other.')
  start, end = dims[0], dims[-1]
  output = tf.reshape(tensor, tf.concat([
      tf.shape(tensor)[:start],
      [tf.reduce_prod(tf.shape(tensor)[start: end + 1])],
      tf.shape(tensor)[end + 1:]], axis=0))
  merged = tensor.shape[start: end + 1].as_list()
  output.set_shape(
      tensor.shape[:start].as_list() +
      [None if None in merged else np.prod(merged)] +
      tensor.shape[end + 1:].as_list())
  return output


def _pad_dims(tensor, rank):
  """Append empty dimensions to the tensor until it is of the given rank."""
  for _ in range(rank - tensor.shape.ndims):
    tensor = tensor[..., None]
  return tensor
def _restore_batch_dim(tensor, batch_size):
  """Split batch dimension out of the first dimension of a tensor."""
  dims = shape.shape(tensor)
  new_shape = [batch_size, dims[0] // batch_size] + dims[1:]
  return tf.reshape(tensor, new_shape)


def preprocess(image, bits):
  """Quantize an image to `bits` bits and map it into [-0.5, 0.5).

  Dequantization noise of one bin width is added so the values cover the
  interval uniformly rather than sitting on the bin edges.
  """
  bins = 2 ** bits
  image = tf.to_float(image)
  if bits < 8:
    # Drop the least significant bits of the 8-bit input.
    image = tf.floor(image / 2 ** (8 - bits))
  noise = tf.random_uniform(tf.shape(image), 0, 1.0 / bins)
  return image / bins + noise - 0.5


def postprocess(image, bits, dtype=tf.float32):
  """Invert `preprocess`, snapping values back onto the bin grid.

  Returns floats on the bin grid for `tf.float32`, or pixel values in
  [0, 255] for `tf.uint8`.
  """
  bins = 2 ** bits
  if dtype == tf.float32:
    return tf.floor(bins * (image + 0.5)) / bins
  if dtype == tf.uint8:
    shifted = image + 0.5
    scaled = tf.floor(bins * shifted) * (256.0 / bins)
    return tf.cast(tf.clip_by_value(scaled, 0, 255), tf.uint8)
  raise NotImplementedError(dtype)
def reshape_as(tensor, reference):
  """Reshape a tensor (or nested structure) to the shape of a reference.

  Known static dimensions of the reference are used directly; unknown ones
  fall back to the reference's dynamic shape at runtime.

  Args:
    tensor: Tensor or nested structure of tensors to reshape.
    reference: Tensor whose shape each tensor is reshaped to.

  Returns:
    Reshaped tensor or structure of reshaped tensors.
  """
  if isinstance(tensor, (list, tuple, dict)):
    # nested.map takes the function first (see test_nested.py); the previous
    # argument order passed the structure where the function belongs.
    return nested.map(lambda x: reshape_as(x, reference), tensor)
  tensor = tf.convert_to_tensor(tensor)
  reference = tf.convert_to_tensor(reference)
  statics = reference.shape.as_list()
  dynamics = tf.shape(reference)
  shape = [
      static if static is not None else dynamics[index]
      for index, static in enumerate(statics)]
  return tf.reshape(tensor, shape)
def binary(step, batch_size, after, every, until):
  """Boolean schedule that fires for one batch every `every` steps.

  Active starting at `after` and, if `until > 0`, only before `until`.
  # https://www.desmos.com/calculator/csbhr5cjzz
  """
  offset_step = step - after
  phase = tf.less(offset_step % every, batch_size)
  active = tf.greater_equal(step, after)
  if until > 0:
    active = tf.logical_and(active, tf.less(step, until))
  result = tf.logical_and(phase, active)
  result.set_shape(tf.TensorShape([]))
  return result


def linear(step, ramp, min=None, max=None):
  """Linearly ramp from 0 to 1 over `ramp` steps, optionally clipped.

  A negative `ramp` ramps down from 1 to 0; `ramp == 0` is constant 1.
  # https://www.desmos.com/calculator/nrumhgvxql
  """
  if ramp == 0:
    result = tf.constant(1, tf.float32)
  if ramp > 0:
    result = tf.minimum(tf.to_float(step) / tf.to_float(ramp), 1)
  if ramp < 0:
    result = 1 - linear(step, abs(ramp))
  if min is not None and max is not None:
    assert min <= max
  if min is not None:
    assert 0 <= min <= 1
    result = tf.maximum(min, result)
  if max is not None:
    # Bug fix: this branch previously asserted on `min`, which raised a
    # TypeError when only `max` was provided and never validated `max`.
    assert 0 <= max <= 1
    result = tf.minimum(result, max)
  result.set_shape(tf.TensorShape([]))
  return result


def linear_reset(step, ramp, after, every):
  """Linear ramp that restarts from zero every `every` steps after `after`.

  # https://www.desmos.com/calculator/motbnqhacw
  """
  assert every > ramp, (every, ramp)  # Would never reach max value.
  assert not (every != np.inf and after == np.inf), (every, after)
  step, ramp, after, every = [
      tf.to_float(x) for x in (step, ramp, after, every)]
  before = tf.to_float(tf.less(step, after)) * step
  after = tf.to_float(tf.greater_equal(step, after)) * ((step - after) % every)
  result = tf.minimum((before + after) / ramp, 1)
  result.set_shape(tf.TensorShape([]))
  return result


def shape(tensor):
  """Return the tensor shape as a list of ints where static, tensors where not."""
  static = tensor.get_shape().as_list()
  dynamic = tf.unstack(tf.shape(tensor))
  assert len(static) == len(dynamic)
  combined = [d if s is None else s for s, d in zip(static, dynamic)]
  return combined
class StreamingMean(object):
  """Compute a streaming estimation of the mean of submitted tensors."""

  def __init__(self, shape, dtype, name):
    """Specify the shape and dtype of the mean to be estimated.

    Note that a float mean to zero submitted elements is NaN, while computing
    the integer mean of zero elements raises a division by zero error.

    Args:
      shape: Shape of the mean to compute.
      dtype: Data type of the mean to compute.
      name: Variable scope under which the sum and count variables live.
    """
    self._dtype = dtype
    with tf.variable_scope(name):
      self._sum = tf.get_variable(
          'sum', shape, dtype,
          tf.constant_initializer(0),
          trainable=False)
      self._count = tf.get_variable(
          'count', (), tf.int32,
          tf.constant_initializer(0),
          trainable=False)

  @property
  def value(self):
    """The current value of the mean."""
    return self._sum / tf.cast(self._count, self._dtype)

  @property
  def count(self):
    """The number of submitted samples."""
    return self._count

  def submit(self, value):
    """Submit a single or batch tensor to refine the streaming mean."""
    value = tf.convert_to_tensor(value)
    # Add a batch dimension if necessary.
    if value.shape.ndims == self._sum.shape.ndims:
      value = value[None, ...]
    if str(value.shape[1:]) != str(self._sum.shape):
      message = 'Value shape ({}) does not fit tracked tensor ({}).'
      raise ValueError(message.format(value.shape[1:], self._sum.shape))
    def assign():
      return tf.group(
          self._sum.assign_add(tf.reduce_sum(value, 0)),
          self._count.assign_add(tf.shape(value)[0]))
    # Skip updates for empty batches so the count is not incremented.
    not_empty = tf.cast(tf.reduce_prod(tf.shape(value)), tf.bool)
    return tf.cond(not_empty, assign, tf.no_op)

  def clear(self):
    """Return the mean estimate and reset the streaming statistics."""
    value = self._sum / tf.cast(self._count, self._dtype)
    with tf.control_dependencies([value]):
      reset_value = self._sum.assign(tf.zeros_like(self._sum))
      reset_count = self._count.assign(0)
    with tf.control_dependencies([reset_value, reset_count]):
      return tf.identity(value)
def plot_summary(titles, lines, labels, name):
  """Plot lines using matplotlib and create a TensorFlow summary from it.

  Note that only one instance of this summary can be computed at the same time.
  This is because matplotlib uses global state. A workaround is to make earlier
  plot summaries control dependences of later ones.

  Args:
    titles: List of titles for the subplots.
    lines: Nested list of tensors. Each list contains the lines of another
        subplot in the figure.
    labels: Nested list of strings. Each list contains the names for the lines
        of another subplot in the figure. Can be None for any of the sub plots.
    name: Name of the summary.

  Returns:
    Summary tensor.
  """

  def body_fn(lines):
    fig, axes = plt.subplots(
        nrows=len(titles), ncols=1, sharex=True, sharey=False,
        squeeze=False, figsize=(6, 3 * len(lines)))
    axes = axes[:, 0]
    for index, ax in enumerate(axes):
      ax.set_title(titles[index])
      for line, label in zip(lines[index], labels[index]):
        ax.plot(line, label=label)
      if any(labels[index]):
        ax.legend(frameon=False)
    fig.tight_layout()
    fig.canvas.draw()
    image = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    # Close explicitly; matplotlib otherwise keeps the figure alive.
    plt.close(fig)
    return image

  image = tf.py_func(body_fn, (lines,), tf.uint8)
  image = image[None]
  summary = tf.summary.image(name, image)
  return summary


def data_summaries(data, postprocess_fn, histograms=False, name='data'):
  """Summarize a data batch: histograms per key and an image strip."""
  summaries = []
  with tf.variable_scope(name):
    if histograms:
      for key, value in data.items():
        if key in ('image',):
          continue
        summaries.append(tf.summary.histogram(key, data[key]))
    image = data['image']
    if postprocess_fn:
      image = postprocess_fn(image)
    summaries.append(image_strip_summary.image_strip_summary('image', image))
  return summaries


def dataset_summaries(directory, name='dataset'):
  """Summarize the number of episodes stored in `directory`."""
  summaries = []
  with tf.variable_scope(name):
    episodes = count_dataset.count_dataset(directory)
    summaries.append(tf.summary.scalar('episodes', episodes))
  return summaries


def state_summaries(
    cell, prior, posterior, mask, histograms=False, name='state'):
  """Summarize entropies, stddevs, and divergence of prior/posterior states."""
  summaries = []
  divergence = cell.divergence_from_states(posterior, prior, mask)
  prior = cell.dist_from_state(prior, mask)
  posterior = cell.dist_from_state(posterior, mask)
  prior_entropy = prior.entropy()
  posterior_entropy = posterior.entropy()
  # Histogram summaries fail on NaN entries, so zero them out first.
  nan_to_num = lambda x: tf.where(tf.is_nan(x), tf.zeros_like(x), x)
  with tf.variable_scope(name):
    if histograms:
      summaries.append(tf.summary.histogram(
          'prior_entropy_hist', nan_to_num(prior_entropy)))
    summaries.append(tf.summary.scalar(
        'prior_entropy', tf.reduce_mean(prior_entropy)))
    summaries.append(tf.summary.scalar(
        'prior_std', tf.reduce_mean(prior.stddev())))
    if histograms:
      summaries.append(tf.summary.histogram(
          'posterior_entropy_hist', nan_to_num(posterior_entropy)))
    summaries.append(tf.summary.scalar(
        'posterior_entropy', tf.reduce_mean(posterior_entropy)))
    summaries.append(tf.summary.scalar(
        'posterior_std', tf.reduce_mean(posterior.stddev())))
    summaries.append(tf.summary.scalar(
        'divergence', tf.reduce_mean(divergence)))
  return summaries


def dist_summaries(dists, obs, mask, name='dist_summaries'):
  """Summarize mode/stddev moments and errors of predicted distributions."""
  summaries = []
  with tf.variable_scope(name):
    # Loop variable renamed from `name` to `key` so it no longer shadows the
    # scope parameter; summary tags are unchanged since they use dict keys.
    for key, dist in dists.items():
      mode = dist.mode()
      mode_mean, mode_var = tf.nn.moments(mode, list(range(mode.shape.ndims)))
      mode_std = tf.sqrt(mode_var)
      summaries.append(tf.summary.scalar(key + '_mode_mean', mode_mean))
      summaries.append(tf.summary.scalar(key + '_mode_std', mode_std))
      std = dist.stddev()
      std_mean, std_var = tf.nn.moments(std, list(range(std.shape.ndims)))
      std_std = tf.sqrt(std_var)
      summaries.append(tf.summary.scalar(key + '_std_mean', std_mean))
      summaries.append(tf.summary.scalar(key + '_std_std', std_std))
      if key in obs:
        log_prob = tf.reduce_mean(dist.log_prob(obs[key]))
        summaries.append(tf.summary.scalar(key + '_log_prob', log_prob))
        abs_error = tf.reduce_mean(tf.abs(dist.mode() - obs[key]))
        summaries.append(tf.summary.scalar(key + '_abs_error', abs_error))
  return summaries


def image_summaries(dist, target, name='image', max_batch=10):
  """Summarize image predictions: strips, frame deltas, errors, and a GIF."""
  summaries = []
  with tf.variable_scope(name):
    empty_frame = 0 * target[:max_batch, :1]
    image = dist.mode()[:max_batch]
    target = target[:max_batch]
    change = tf.concat([empty_frame, image[:, 1:] - image[:, :-1]], 1)
    error = image - target
    summaries.append(image_strip_summary.image_strip_summary(
        'prediction', image))
    summaries.append(image_strip_summary.image_strip_summary(
        'change', (change + 1) / 2))
    summaries.append(image_strip_summary.image_strip_summary(
        'error', (error + 1) / 2))
    # Concat prediction and target vertically.
    frames = tf.concat([target, image], 2)
    # Stack batch entries horizontally.
    frames = tf.transpose(frames, [1, 2, 0, 3, 4])
    s = shapelib.shape(frames)
    frames = tf.reshape(frames, [s[0], s[1], s[2] * s[3], s[4]])
    summaries.append(gif_summary.gif_summary(
        'animation', frames[None], max_outputs=1, fps=20))
  return summaries


def objective_summaries(objectives, name='objectives'):
  """Summarize the scalar value of each training objective."""
  summaries = []
  with tf.variable_scope(name):
    for objective in objectives:
      summaries.append(tf.summary.scalar(objective.name, objective.value))
  return summaries


def prediction_summaries(dists, data, state, name='state'):
  """Plot per-key prediction-vs-truth trajectories and their log probs."""
  summaries = []
  with tf.variable_scope(name):
    # Predictions.
    log_probs = {}
    for key, dist in dists.items():
      if key in ('image',):
        continue
      if key not in data:
        continue
      # We only look at the first example in the batch.
      log_prob = dist.log_prob(data[key])[0]
      prediction = dist.mode()[0]
      truth = data[key][0]
      # Ensure that there is a feature dimension.
      if prediction.shape.ndims == 1:
        prediction = prediction[:, None]
        truth = truth[:, None]
      prediction = tf.unstack(tf.transpose(prediction, (1, 0)))
      truth = tf.unstack(tf.transpose(truth, (1, 0)))
      lines = list(zip(prediction, truth))
      titles = ['{} {}'.format(key.title(), i) for i in range(len(lines))]
      labels = [['Prediction', 'Truth']] * len(lines)
      # (Removed a dead `plot_name = key` assignment that was immediately
      # overwritten below.)
      plot_name = '{}_trajectory'.format(key)
      # The control dependencies are needed because rendering in matplotlib
      # uses global state, so rendering two plots in parallel interferes.
      with tf.control_dependencies(summaries):
        summaries.append(plot_summary(titles, lines, labels, plot_name))
      log_probs[key] = log_prob
    # NOTE(review): zip(*log_probs) below assumes at least one non-image key
    # was predicted; an empty dict would raise here — confirm with callers.
    log_probs = sorted(log_probs.items(), key=lambda x: x[0])
    titles, lines = zip(*log_probs)
    titles = [title.title() for title in titles]
    lines = [[line] for line in lines]
    labels = [[None]] * len(titles)
    plot_name = 'logprobs'
    with tf.control_dependencies(summaries):
      summaries.append(plot_summary(titles, lines, labels, plot_name))
  return summaries
def track_network(
    trainer, batch_size, source_pattern, target_pattern, every, amount):
  """Periodically move target-network weights toward the source network.

  On the very first global step the target weights are fully copied from the
  source; afterwards, a soft copy with factor `amount` happens once every
  `every` steps, but only during the train phase.
  """
  def full_copy():
    return copy_weights.soft_copy_weights(
        source_pattern, target_pattern, 1.0)

  def soft_copy():
    return copy_weights.soft_copy_weights(
        source_pattern, target_pattern, amount)

  # Initialize the target network with an exact copy at step zero.
  init_op = tf.cond(tf.equal(trainer.global_step, 0), full_copy, tf.no_op)
  schedule = schedule_lib.binary(trainer.step, batch_size, 0, every, -1)
  should_update = tf.logical_and(tf.equal(trainer.phase, 'train'), schedule)
  with tf.control_dependencies([init_op]):
    return tf.cond(should_update, soft_copy, tf.no_op)
class ZipTest(tf.test.TestCase):
  """Exercises nested.zip on scalars, flat containers, and mixed trees."""

  def test_scalar(self):
    self.assertEqual(42, nested.zip(42))
    self.assertEqual((13, 42), nested.zip(13, 42))

  def test_empty(self):
    self.assertEqual({}, nested.zip({}, {}))

  def test_base_case(self):
    self.assertEqual((1, 2, 3), nested.zip(1, 2, 3))

  def test_shallow_list(self):
    xs, ys, zs = [1, 2, 3], [4, 5, 6], [7, 8, 9]
    zipped = nested.zip(xs, ys, zs)
    self.assertEqual([(1, 4, 7), (2, 5, 8), (3, 6, 9)], zipped)

  def test_shallow_tuple(self):
    xs, ys, zs = (1, 2, 3), (4, 5, 6), (7, 8, 9)
    zipped = nested.zip(xs, ys, zs)
    self.assertEqual(((1, 4, 7), (2, 5, 8), (3, 6, 9)), zipped)

  def test_shallow_dict(self):
    xs = {'a': 1, 'b': 2, 'c': 3}
    ys = {'a': 4, 'b': 5, 'c': 6}
    zs = {'a': 7, 'b': 8, 'c': 9}
    expected = {'a': (1, 4, 7), 'b': (2, 5, 8), 'c': (3, 6, 9)}
    self.assertEqual(expected, nested.zip(xs, ys, zs))

  def test_single(self):
    tree = [[1, 2], 3]
    # Zipping one structure returns it unchanged.
    self.assertEqual(tree, nested.zip(tree))

  def test_mixed_structures(self):
    first = [(1, 2), 3, {'foo': [4]}]
    second = [(5, 6), 7, {'foo': [8]}]
    expected = [((1, 5), (2, 6)), (3, 7), {'foo': [(4, 8)]}]
    self.assertEqual(expected, nested.zip(first, second))

  def test_different_types(self):
    numbers = [1, 2, 3]
    letters = 'a b c'.split()
    zipped = nested.zip(numbers, letters)
    self.assertEqual([(1, 'a'), (2, 'b'), (3, 'c')], zipped)

  def test_use_type_of_first(self):
    xs = (1, 2, 3)
    ys = [4, 5, 6]
    zs = [7, 8, 9]
    # The container type of the first argument wins.
    self.assertEqual(((1, 4, 7), (2, 5, 8), (3, 6, 9)), nested.zip(xs, ys, zs))

  def test_namedtuple(self):
    Foo = collections.namedtuple('Foo', 'value')
    foo, bar = Foo(42), Foo(13)
    self.assertEqual(Foo((42, 13)), nested.zip(foo, bar))


class MapTest(tf.test.TestCase):
  """Exercises nested.map over scalars, containers, and namedtuples."""

  def test_scalar(self):
    self.assertEqual(42, nested.map(lambda x: x, 42))

  def test_empty(self):
    self.assertEqual({}, nested.map(lambda x: x, {}))

  def test_shallow_list(self):
    doubled = nested.map(lambda x: 2 * x, [1, 2, 3])
    self.assertEqual([2, 4, 6], doubled)

  def test_shallow_dict(self):
    mapping = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
    self.assertEqual(mapping, nested.map(lambda x: x, mapping))

  def test_mixed_structure(self):
    tree = [(1, 2), 3, {'foo': [4]}]
    doubled = nested.map(lambda x: 2 * x, tree)
    self.assertEqual([(2, 4), 6, {'foo': [8]}], doubled)

  def test_mixed_types(self):
    # Multiplication works on both numbers and strings.
    self.assertEqual([14, 'foofoo'], nested.map(lambda x: x * 2, [7, 'foo']))

  def test_multiple_lists(self):
    xs, ys, zs = [1, 2, 3], [4, 5, 6], [7, 8, 9]
    summed = nested.map(lambda x, y, z: x + y + z, xs, ys, zs)
    self.assertEqual([12, 15, 18], summed)

  def test_namedtuple(self):
    Foo = collections.namedtuple('Foo', 'value')
    foo, bar = [Foo(42)], [Foo(13)]
    swapped = nested.map(lambda x, y: (y, x), foo, bar)
    self.assertEqual([Foo((13, 42))], swapped)
    added = nested.map(lambda x, y: x + y, foo, bar)
    self.assertEqual([Foo(55)], added)


class FlattenTest(tf.test.TestCase):
  """Exercises nested.flatten's tuple conversion and value ordering."""

  def test_scalar(self):
    self.assertEqual((42,), nested.flatten(42))

  def test_empty(self):
    self.assertEqual((), nested.flatten({}))

  def test_base_case(self):
    self.assertEqual((1,), nested.flatten(1))

  def test_convert_type(self):
    # Lists flatten into tuples.
    self.assertEqual((1, 2, 3), nested.flatten([1, 2, 3]))

  def test_mixed_structure(self):
    tree = [(1, 2), 3, {'foo': [4]}]
    self.assertEqual((1, 2, 3, 4), nested.flatten(tree))

  def test_value_ordering(self):
    # Dict leaves come out ordered by key.
    self.assertEqual((1, 2, 3), nested.flatten({'a': 1, 'b': 2, 'c': 3}))


class FilterTest(tf.test.TestCase):
  """Exercises nested.filter on single and multiple structures."""

  def test_empty(self):
    self.assertEqual({}, nested.filter(lambda x: True, {}))
    self.assertEqual({}, nested.filter(lambda x: False, {}))

  def test_base_case(self):
    self.assertEqual((), nested.filter(lambda x: False, 1))

  def test_single_dict(self):
    is_even = lambda x: x % 2 == 0
    mapping = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
    self.assertEqual({'b': 2, 'd': 4}, nested.filter(is_even, mapping))

  def test_multiple_lists(self):
    xs, ys, zs = [1, 2, 3], [4, 5, 6], [7, 8, 9]
    any_multiple_of_four = lambda *args: any(x % 4 == 0 for x in args)
    kept = nested.filter(any_multiple_of_four, xs, ys, zs)
    self.assertEqual([(1, 4, 7), (2, 5, 8)], kept)

  def test_multiple_dicts(self):
    xs = {'a': 1, 'b': 2, 'c': 3}
    ys = {'a': 4, 'b': 5, 'c': 6}
    zs = {'a': 7, 'b': 8, 'c': 9}
    any_multiple_of_four = lambda *args: any(x % 4 == 0 for x in args)
    kept = nested.filter(any_multiple_of_four, xs, ys, zs)
    self.assertEqual({'a': (1, 4, 7), 'b': (2, 5, 8)}, kept)

  def test_mixed_structure(self):
    is_even = lambda x: x % 2 == 0
    tree = [(1, 2), 3, {'foo': [4]}]
    self.assertEqual([(2,), {'foo': [4]}], nested.filter(is_even, tree))

  def test_remove_empty_containers(self):
    tree = [(1, 2, 3), 4, {'foo': [5, 6], 'bar': 7}]
    self.assertEqual([], nested.filter(lambda x: False, tree))

  def test_namedtuple(self):
    Foo = collections.namedtuple('Foo', 'value1, value2')
    filtered = nested.filter(lambda x: x == 1, Foo(1, 2))
    # Namedtuple fields are kept positionally; removed ones become None.
    self.assertEqual(Foo(1, None), filtered)

  def test_namedtuple_multiple(self):
    Foo = collections.namedtuple('Foo', 'value1, value2')
    foo = Foo(1, 2)
    bar = Foo(2, 3)
    result = nested.filter(lambda x, y: x + y > 3, foo, bar)
    self.assertEqual(Foo(None, (2, 3)), result)

  def test_namedtuple_nested(self):
    Foo = collections.namedtuple('Foo', 'value1, value2')
    foo = Foo(1, [1, 2, 3])
    self.assertEqual(Foo(None, [2, 3]), nested.filter(lambda x: x > 1, foo))
31 | """ 32 | 33 | def __init__(self, obs_size): 34 | self._obs_size = obs_size 35 | super(_MockCell, self).__init__( 36 | tf.make_template('transition', self._transition), 37 | tf.make_template('posterior', self._posterior)) 38 | 39 | @property 40 | def state_size(self): 41 | return {'obs': self._obs_size} 42 | 43 | def _transition(self, prev_state, prev_action, zero_obs): 44 | if isinstance(prev_action, (tuple, list)): 45 | prev_action = prev_action[0] 46 | return {'obs': prev_state['obs'] + prev_action} 47 | 48 | def _posterior(self, prev_state, prev_action, obs): 49 | if isinstance(obs, (tuple, list)): 50 | obs = obs[0] 51 | return {'obs': obs} 52 | 53 | 54 | class OvershootingTest(tf.test.TestCase): 55 | 56 | def test_example(self): 57 | obs = tf.constant([ 58 | [10, 20, 30, 40, 50, 60], 59 | [70, 80, 0, 0, 0, 0], 60 | ], dtype=tf.float32)[:, :, None] 61 | prev_action = tf.constant([ 62 | [0.0, 0.1, 0.2, 0.3, 0.4, 0.5], 63 | [9.0, 0.7, 0, 0, 0, 0], 64 | ], dtype=tf.float32)[:, :, None] 65 | length = tf.constant([6, 2], dtype=tf.int32) 66 | cell = _MockCell(1) 67 | _, prior, posterior, mask = overshooting( 68 | cell, obs, obs, prev_action, length, 3) 69 | prior = tf.squeeze(prior['obs'], 3) 70 | posterior = tf.squeeze(posterior['obs'], 3) 71 | mask = tf.to_int32(mask) 72 | with self.test_session(): 73 | # Each column corresponds to a different state step, and each row 74 | # corresponds to a different overshooting distance from there. 
75 | self.assertAllEqual([ 76 | [1, 1, 1, 1, 1, 1], 77 | [1, 1, 1, 1, 1, 0], 78 | [1, 1, 1, 1, 0, 0], 79 | [1, 1, 1, 0, 0, 0], 80 | ], mask.eval()[0].T) 81 | self.assertAllEqual([ 82 | [1, 1, 0, 0, 0, 0], 83 | [1, 0, 0, 0, 0, 0], 84 | [0, 0, 0, 0, 0, 0], 85 | [0, 0, 0, 0, 0, 0], 86 | ], mask.eval()[1].T) 87 | self.assertAllClose([ 88 | [0.0, 10.1, 20.2, 30.3, 40.4, 50.5], 89 | [0.1, 10.3, 20.5, 30.7, 40.9, 0], 90 | [0.3, 10.6, 20.9, 31.2, 0, 0], 91 | [0.6, 11.0, 21.4, 0, 0, 0], 92 | ], prior.eval()[0].T) 93 | self.assertAllClose([ 94 | [10, 20, 30, 40, 50, 60], 95 | [20, 30, 40, 50, 60, 0], 96 | [30, 40, 50, 60, 0, 0], 97 | [40, 50, 60, 0, 0, 0], 98 | ], posterior.eval()[0].T) 99 | self.assertAllClose([ 100 | [9.0, 70.7, 0, 0, 0, 0], 101 | [9.7, 0, 0, 0, 0, 0], 102 | [0, 0, 0, 0, 0, 0], 103 | [0, 0, 0, 0, 0, 0], 104 | ], prior.eval()[1].T) 105 | self.assertAllClose([ 106 | [70, 80, 0, 0, 0, 0], 107 | [80, 0, 0, 0, 0, 0], 108 | [0, 0, 0, 0, 0, 0], 109 | [0, 0, 0, 0, 0, 0], 110 | ], posterior.eval()[1].T) 111 | 112 | def test_nested(self): 113 | obs = (tf.ones((3, 50, 1)), tf.ones((3, 50, 2)), tf.ones((3, 50, 3))) 114 | prev_action = (tf.ones((3, 50, 1)), tf.ones((3, 50, 2))) 115 | length = tf.constant([49, 50, 3], dtype=tf.int32) 116 | cell = _MockCell(1) 117 | overshooting(cell, obs, obs, prev_action, length, 3) 118 | 119 | 120 | if __name__ == '__main__': 121 | tf.test.main() 122 | -------------------------------------------------------------------------------- /planet/tools/unroll.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The PlaNet Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import tensorflow as tf 20 | 21 | from planet.tools import nested 22 | from planet.tools import shape 23 | 24 | 25 | def closed_loop(cell, embedded, prev_action, debug=False): 26 | use_obs = tf.ones(tf.shape(embedded[:, :, :1])[:3], tf.bool) 27 | (prior, posterior), _ = tf.nn.dynamic_rnn( 28 | cell, (embedded, prev_action, use_obs), dtype=tf.float32) 29 | if debug: 30 | with tf.control_dependencies([tf.assert_equal( 31 | tf.shape(nested.flatten(posterior)[0])[1], tf.shape(embedded)[1])]): 32 | prior = nested.map(tf.identity, prior) 33 | posterior = nested.map(tf.identity, posterior) 34 | return prior, posterior 35 | 36 | 37 | def open_loop(cell, embedded, prev_action, context=1, debug=False): 38 | use_obs = tf.ones(tf.shape(embedded[:, :context, :1])[:3], tf.bool) 39 | (_, closed_state), last_state = tf.nn.dynamic_rnn( 40 | cell, (embedded[:, :context], prev_action[:, :context], use_obs), 41 | dtype=tf.float32) 42 | use_obs = tf.zeros(tf.shape(embedded[:, context:, :1])[:3], tf.bool) 43 | (_, open_state), _ = tf.nn.dynamic_rnn( 44 | cell, (0 * embedded[:, context:], prev_action[:, context:], use_obs), 45 | initial_state=last_state) 46 | state = nested.map( 47 | lambda x, y: tf.concat([x, y], 1), 48 | closed_state, open_state) 49 | if debug: 50 | with tf.control_dependencies([tf.assert_equal( 51 | tf.shape(nested.flatten(state)[0])[1], tf.shape(embedded)[1])]): 52 | state = 
nested.map(tf.identity, state) 53 | return state 54 | 55 | 56 | def planned( 57 | cell, objective_fn, embedded, prev_action, planner, context=1, length=20, 58 | amount=1000, debug=False): 59 | use_obs = tf.ones(tf.shape(embedded[:, :context, :1])[:3], tf.bool) 60 | (_, closed_state), last_state = tf.nn.dynamic_rnn( 61 | cell, (embedded[:, :context], prev_action[:, :context], use_obs), 62 | dtype=tf.float32) 63 | _, plan_state, return_ = planner( 64 | cell, objective_fn, last_state, 65 | obs_shape=shape.shape(embedded)[2:], 66 | action_shape=shape.shape(prev_action)[2:], 67 | horizon=length, amount=amount) 68 | state = nested.map( 69 | lambda x, y: tf.concat([x, y], 1), 70 | closed_state, plan_state) 71 | if debug: 72 | with tf.control_dependencies([tf.assert_equal( 73 | tf.shape(nested.flatten(state)[0])[1], context + length)]): 74 | state = nested.map(tf.identity, state) 75 | return_ = tf.identity(return_) 76 | return state, return_ 77 | -------------------------------------------------------------------------------- /planet/training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The PlaNet Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | from . 
import utility 20 | from .define_model import define_model 21 | from .define_summaries import define_summaries 22 | from .running import Experiment 23 | from .trainer import Trainer 24 | -------------------------------------------------------------------------------- /planet/training/define_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The PlaNet Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import functools 20 | 21 | import tensorflow as tf 22 | 23 | from planet import tools 24 | from planet.training import define_summaries 25 | from planet.training import utility 26 | 27 | 28 | def define_model(data, trainer, config): 29 | tf.logging.info('Build TensorFlow compute graph.') 30 | dependencies = [] 31 | cleanups = [] 32 | step = trainer.step 33 | global_step = trainer.global_step 34 | phase = trainer.phase 35 | 36 | # Instantiate network blocks. 
37 | cell = config.cell() 38 | kwargs = dict(create_scope_now_=True) 39 | encoder = tf.make_template('encoder', config.encoder, **kwargs) 40 | heads = tools.AttrDict(_unlocked=True) 41 | dummy_features = cell.features_from_state(cell.zero_state(1, tf.float32)) 42 | for key, head in config.heads.items(): 43 | name = 'head_{}'.format(key) 44 | kwargs = dict(create_scope_now_=True) 45 | if key in data: 46 | kwargs['data_shape'] = data[key].shape[2:].as_list() 47 | elif key == 'action_target': 48 | kwargs['data_shape'] = data['action'].shape[2:].as_list() 49 | heads[key] = tf.make_template(name, head, **kwargs) 50 | heads[key](dummy_features) # Initialize weights. 51 | 52 | # Apply and optimize model. 53 | embedded = encoder(data) 54 | with tf.control_dependencies(dependencies): 55 | embedded = tf.identity(embedded) 56 | graph = tools.AttrDict(locals()) 57 | prior, posterior = tools.unroll.closed_loop( 58 | cell, embedded, data['action'], config.debug) 59 | objectives = utility.compute_objectives( 60 | posterior, prior, data, graph, config) 61 | summaries, grad_norms = utility.apply_optimizers( 62 | objectives, trainer, config) 63 | 64 | # Active data collection. 65 | with tf.variable_scope('collection'): 66 | with tf.control_dependencies(summaries): # Make sure to train first. 67 | for name, params in config.train_collects.items(): 68 | schedule = tools.schedule.binary( 69 | step, config.batch_shape[0], 70 | params.steps_after, params.steps_every, params.steps_until) 71 | summary, _ = tf.cond( 72 | tf.logical_and(tf.equal(trainer.phase, 'train'), schedule), 73 | functools.partial( 74 | utility.simulate_episodes, config, params, graph, cleanups, 75 | expensive_summaries=False, gif_summary=False, name=name), 76 | lambda: (tf.constant(''), tf.constant(0.0)), 77 | name='should_collect_' + name) 78 | summaries.append(summary) 79 | 80 | # Compute summaries. 
81 | graph = tools.AttrDict(locals()) 82 | summary, score = tf.cond( 83 | trainer.log, 84 | lambda: define_summaries.define_summaries(graph, config, cleanups), 85 | lambda: (tf.constant(''), tf.zeros((0,), tf.float32)), 86 | name='summaries') 87 | summaries = tf.summary.merge([summaries, summary]) 88 | dependencies.append(utility.print_metrics( 89 | {ob.name: ob.value for ob in objectives}, 90 | step, config.print_metrics_every, 'objectives')) 91 | dependencies.append(utility.print_metrics( 92 | grad_norms, step, config.print_metrics_every, 'grad_norms')) 93 | with tf.control_dependencies(dependencies): 94 | score = tf.identity(score) 95 | return score, summaries, cleanups 96 | -------------------------------------------------------------------------------- /planet/training/define_summaries.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The PlaNet Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import tensorflow as tf 20 | from tensorflow_probability import distributions as tfd 21 | 22 | from planet import tools 23 | from planet.training import utility 24 | from planet.tools import summary 25 | 26 | 27 | def define_summaries(graph, config, cleanups): 28 | summaries = [] 29 | plot_summaries = [] # Control dependencies for non thread-safe matplot. 30 | length = graph.data['length'] 31 | mask = tf.range(graph.embedded.shape[1].value)[None, :] < length[:, None] 32 | heads = graph.heads.copy() 33 | last_time = tf.Variable(lambda: tf.timestamp(), trainable=False) 34 | last_step = tf.Variable(lambda: 0.0, trainable=False, dtype=tf.float64) 35 | 36 | def transform(dist): 37 | mean = config.postprocess_fn(dist.mean()) 38 | mean = tf.clip_by_value(mean, 0.0, 1.0) 39 | return tfd.Independent(tfd.Normal(mean, 1.0), len(dist.event_shape)) 40 | heads.unlock() 41 | heads['image'] = lambda features: transform(graph.heads['image'](features)) 42 | heads.lock() 43 | 44 | with tf.variable_scope('general'): 45 | summaries += summary.data_summaries(graph.data, config.postprocess_fn) 46 | summaries += summary.dataset_summaries(config.train_dir) 47 | summaries += summary.objective_summaries(graph.objectives) 48 | summaries.append(tf.summary.scalar('step', graph.step)) 49 | new_time, new_step = tf.timestamp(), tf.cast(graph.global_step, tf.float64) 50 | delta_time, delta_step = new_time - last_time, new_step - last_step 51 | with tf.control_dependencies([delta_time, delta_step]): 52 | assign_ops = [last_time.assign(new_time), last_step.assign(new_step)] 53 | with tf.control_dependencies(assign_ops): 54 | summaries.append(tf.summary.scalar( 55 | 'steps_per_second', delta_step / delta_time)) 56 | summaries.append(tf.summary.scalar( 57 | 'seconds_per_step', delta_time / delta_step)) 58 | 59 | with tf.variable_scope('closedloop'): 60 | prior, 
posterior = tools.unroll.closed_loop( 61 | graph.cell, graph.embedded, graph.data['action'], config.debug) 62 | summaries += summary.state_summaries(graph.cell, prior, posterior, mask) 63 | with tf.variable_scope('prior'): 64 | prior_features = graph.cell.features_from_state(prior) 65 | prior_dists = { 66 | name: head(prior_features) 67 | for name, head in heads.items()} 68 | summaries += summary.dist_summaries(prior_dists, graph.data, mask) 69 | summaries += summary.image_summaries( 70 | prior_dists['image'], config.postprocess_fn(graph.data['image'])) 71 | with tf.variable_scope('posterior'): 72 | posterior_features = graph.cell.features_from_state(posterior) 73 | posterior_dists = { 74 | name: head(posterior_features) 75 | for name, head in heads.items()} 76 | summaries += summary.dist_summaries( 77 | posterior_dists, graph.data, mask) 78 | summaries += summary.image_summaries( 79 | posterior_dists['image'], 80 | config.postprocess_fn(graph.data['image'])) 81 | 82 | with tf.variable_scope('openloop'): 83 | state = tools.unroll.open_loop( 84 | graph.cell, graph.embedded, graph.data['action'], 85 | config.open_loop_context, config.debug) 86 | state_features = graph.cell.features_from_state(state) 87 | state_dists = {name: head(state_features) for name, head in heads.items()} 88 | summaries += summary.dist_summaries(state_dists, graph.data, mask) 89 | summaries += summary.image_summaries( 90 | state_dists['image'], config.postprocess_fn(graph.data['image'])) 91 | summaries += summary.state_summaries(graph.cell, state, posterior, mask) 92 | with tf.control_dependencies(plot_summaries): 93 | plot_summary = summary.prediction_summaries( 94 | state_dists, graph.data, state) 95 | plot_summaries += plot_summary 96 | summaries += plot_summary 97 | 98 | with tf.variable_scope('simulation'): 99 | sim_returns = [] 100 | for name, params in config.test_collects.items(): 101 | # These are expensive and equivalent for train and test phases, so only 102 | # do one of them. 
103 | sim_summary, sim_return = tf.cond( 104 | tf.equal(graph.phase, 'test'), 105 | lambda: utility.simulate_episodes( 106 | config, params, graph, cleanups, 107 | expensive_summaries=False, 108 | gif_summary=True, 109 | name=name), 110 | lambda: ('', 0.0), 111 | name='should_simulate_' + params.task.name) 112 | summaries.append(sim_summary) 113 | sim_returns.append(sim_return) 114 | 115 | summaries = tf.summary.merge(summaries) 116 | score = tf.reduce_mean(sim_returns)[None] 117 | return summaries, score 118 | -------------------------------------------------------------------------------- /planet/training/test_running.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The PlaNet Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | import pickle 21 | import threading 22 | import time 23 | 24 | import numpy as np 25 | import tensorflow as tf 26 | 27 | from planet.training import running 28 | 29 | 30 | class TestExperiment(tf.test.TestCase): 31 | 32 | def test_no_kills(self): 33 | tf.logging.set_verbosity(tf.logging.INFO) 34 | basedir = os.path.join(tf.test.get_temp_dir(), 'test_no_kills') 35 | processes = [] 36 | for worker_name in range(20): 37 | processes.append(threading.Thread( 38 | target=_worker_normal, args=(basedir, str(worker_name)))) 39 | processes[-1].start() 40 | for process in processes: 41 | process.join() 42 | filepaths = tf.gfile.Glob(os.path.join(basedir, '*/DONE')) 43 | self.assertEqual(100, len(filepaths)) 44 | filepaths = tf.gfile.Glob(os.path.join(basedir, '*/PING')) 45 | self.assertEqual(100, len(filepaths)) 46 | filepaths = tf.gfile.Glob(os.path.join(basedir, '*/started')) 47 | self.assertEqual(100, len(filepaths)) 48 | filepaths = tf.gfile.Glob(os.path.join(basedir, '*/resumed')) 49 | self.assertEqual(0, len(filepaths)) 50 | filepaths = tf.gfile.Glob(os.path.join(basedir, '*/failed')) 51 | self.assertEqual(0, len(filepaths)) 52 | filepaths = tf.gfile.Glob(os.path.join(basedir, '*/numbers')) 53 | self.assertEqual(100, len(filepaths)) 54 | for filepath in filepaths: 55 | with tf.gfile.GFile(filepath, 'rb') as file_: 56 | self.assertEqual(10, len(pickle.load(file_))) 57 | 58 | def test_dying_workers(self): 59 | tf.logging.set_verbosity(tf.logging.INFO) 60 | basedir = os.path.join(tf.test.get_temp_dir(), 'test_dying_workers') 61 | processes = [] 62 | for worker_name in range(20): 63 | processes.append(threading.Thread( 64 | target=_worker_dying, args=(basedir, 15, str(worker_name)))) 65 | processes[-1].start() 66 | for process in processes: 67 | process.join() 68 | processes = [] 69 | for worker_name in range(20): 70 | 
processes.append(threading.Thread( 71 | target=_worker_normal, args=(basedir, str(worker_name)))) 72 | processes[-1].start() 73 | for process in processes: 74 | process.join() 75 | filepaths = tf.gfile.Glob(os.path.join(basedir, '*/DONE')) 76 | self.assertEqual(100, len(filepaths)) 77 | filepaths = tf.gfile.Glob(os.path.join(basedir, '*/PING')) 78 | self.assertEqual(100, len(filepaths)) 79 | filepaths = tf.gfile.Glob(os.path.join(basedir, '*/FAIL')) 80 | self.assertEqual(0, len(filepaths)) 81 | filepaths = tf.gfile.Glob(os.path.join(basedir, '*/started')) 82 | self.assertEqual(100, len(filepaths)) 83 | filepaths = tf.gfile.Glob(os.path.join(basedir, '*/resumed')) 84 | self.assertEqual(20, len(filepaths)) 85 | filepaths = tf.gfile.Glob(os.path.join(basedir, '*/numbers')) 86 | self.assertEqual(100, len(filepaths)) 87 | for filepath in filepaths: 88 | with tf.gfile.GFile(filepath, 'rb') as file_: 89 | self.assertEqual(10, len(pickle.load(file_))) 90 | 91 | 92 | def _worker_normal(basedir, worker_name): 93 | experiment = running.Experiment( 94 | basedir, _process_fn, _start_fn, _resume_fn, 95 | num_runs=100, worker_name=worker_name, ping_every=1.0) 96 | for run in experiment: 97 | for score in run: 98 | pass 99 | 100 | 101 | def _worker_dying(basedir, die_at_step, worker_name): 102 | experiment = running.Experiment( 103 | basedir, _process_fn, _start_fn, _resume_fn, 104 | num_runs=100, worker_name=worker_name, ping_every=1.0) 105 | step = 0 106 | for run in experiment: 107 | for score in run: 108 | step += 1 109 | if step >= die_at_step: 110 | return 111 | 112 | 113 | def _start_fn(logdir): 114 | assert not tf.gfile.Exists(os.path.join(logdir, 'DONE')) 115 | assert not tf.gfile.Exists(os.path.join(logdir, 'started')) 116 | assert not tf.gfile.Exists(os.path.join(logdir, 'resumed')) 117 | with tf.gfile.GFile(os.path.join(logdir, 'started'), 'w') as file_: 118 | file_.write('\n') 119 | with tf.gfile.GFile(os.path.join(logdir, 'numbers'), 'wb') as file_: 120 | 
pickle.dump([], file_) 121 | return [] 122 | 123 | 124 | def _resume_fn(logdir): 125 | assert not tf.gfile.Exists(os.path.join(logdir, 'DONE')) 126 | assert tf.gfile.Exists(os.path.join(logdir, 'started')) 127 | with tf.gfile.GFile(os.path.join(logdir, 'resumed'), 'w') as file_: 128 | file_.write('\n') 129 | with tf.gfile.GFile(os.path.join(logdir, 'numbers'), 'rb') as file_: 130 | numbers = pickle.load(file_) 131 | if len(numbers) != 5: 132 | raise Exception('Expected to be resumed in the middle for this test.') 133 | return numbers 134 | 135 | 136 | def _process_fn(logdir, numbers): 137 | assert tf.gfile.Exists(os.path.join(logdir, 'started')) 138 | while len(numbers) < 10: 139 | number = np.random.uniform(0, 0.1) 140 | time.sleep(number) 141 | numbers.append(number) 142 | with tf.gfile.GFile(os.path.join(logdir, 'numbers'), 'wb') as file_: 143 | pickle.dump(numbers, file_) 144 | yield number 145 | 146 | 147 | if __name__ == '__main__': 148 | tf.test.main() 149 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The PlaNet Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Setup script for PlaNet.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import setuptools 22 | 23 | 24 | setuptools.setup( 25 | name='planetrl', 26 | version='1.0.0', 27 | description=( 28 | 'Deep Planning Network: Control from pixels by latent planning ' + 29 | 'with learned dynamics.'), 30 | license='Apache 2.0', 31 | url='http://github.com/google-research/planet', 32 | install_requires=[ 33 | 'dm_control', 34 | 'gym', 35 | 'matplotlib', 36 | 'ruamel.yaml', 37 | 'scikit-image', 38 | 'scipy', 39 | 'tensorflow-gpu==1.13.1', 40 | 'tensorflow_probability==0.6.0', 41 | ], 42 | packages=setuptools.find_packages(), 43 | classifiers=[ 44 | 'Programming Language :: Python :: 2', 45 | 'Programming Language :: Python :: 3', 46 | 'License :: OSI Approved :: Apache Software License', 47 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 48 | 'Intended Audience :: Science/Research', 49 | ], 50 | ) 51 | --------------------------------------------------------------------------------