├── .gitignore
├── AUTHORS
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── planet
├── __init__.py
├── control
│ ├── __init__.py
│ ├── batch_env.py
│ ├── dummy_env.py
│ ├── in_graph_batch_env.py
│ ├── mpc_agent.py
│ ├── planning.py
│ ├── random_episodes.py
│ ├── simulate.py
│ ├── temporal_difference.py
│ └── wrappers.py
├── models
│ ├── __init__.py
│ ├── base.py
│ ├── drnn.py
│ ├── rssm.py
│ └── ssm.py
├── networks
│ ├── __init__.py
│ ├── basic.py
│ └── conv_ha.py
├── scripts
│ ├── __init__.py
│ ├── configs.py
│ ├── create_video.py
│ ├── fetch_events.py
│ ├── objectives.py
│ ├── sync.py
│ ├── tasks.py
│ ├── test_planet.py
│ └── train.py
├── tools
│ ├── __init__.py
│ ├── attr_dict.py
│ ├── bind.py
│ ├── bound_action.py
│ ├── chunk_sequence.py
│ ├── copy_weights.py
│ ├── count_dataset.py
│ ├── count_weights.py
│ ├── custom_optimizer.py
│ ├── filter_variables_lib.py
│ ├── gif_summary.py
│ ├── image_strip_summary.py
│ ├── mask.py
│ ├── nested.py
│ ├── numpy_episodes.py
│ ├── overshooting.py
│ ├── preprocess.py
│ ├── reshape_as.py
│ ├── schedule.py
│ ├── shape.py
│ ├── streaming_mean.py
│ ├── summary.py
│ ├── target_network.py
│ ├── test_nested.py
│ ├── test_overshooting.py
│ └── unroll.py
└── training
│ ├── __init__.py
│ ├── define_model.py
│ ├── define_summaries.py
│ ├── running.py
│ ├── test_running.py
│ ├── trainer.py
│ └── utility.py
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.py[cod]
3 | *.egg-info
4 | /dist
5 | MUJOCO_LOG.TXT
6 |
--------------------------------------------------------------------------------
/AUTHORS:
--------------------------------------------------------------------------------
1 | # This is the list of PlaNet authors for copyright purposes.
2 | #
3 | # This does not necessarily list everyone who has contributed code, since in
4 | # some cases, their employer may be the copyright holder. To see the full list
5 | # of contributors, see the revision history in source control.
6 | Google LLC
7 | Danijar Hafner
8 | Timothy Lillicrap
9 | Ian Fischer
10 | Ruben Villegas
11 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows
28 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
29 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep Planning Network
2 |
3 | Danijar Hafner, Timothy Lillicrap, Ian Fischer, Ruben Villegas, David Ha, Honglak Lee, James Davidson
4 |
5 | 
6 |
7 | This project provides the open source implementation of the PlaNet agent
8 | introduced in [Learning Latent Dynamics for Planning from Pixels][paper].
9 | PlaNet is a purely model-based reinforcement learning algorithm that solves
10 | control tasks from images by efficient planning in a learned latent space.
11 | PlaNet competes with top model-free methods in terms of final performance and
12 | training time while using substantially less interaction with the environment.
13 |
14 | If you find this open source release useful, please reference in your paper:
15 |
16 | ```
17 | @inproceedings{hafner2019planet,
18 | title={Learning Latent Dynamics for Planning from Pixels},
19 | author={Hafner, Danijar and Lillicrap, Timothy and Fischer, Ian and Villegas, Ruben and Ha, David and Lee, Honglak and Davidson, James},
20 | booktitle={International Conference on Machine Learning},
21 | pages={2555--2565},
22 | year={2019}
23 | }
24 | ```
25 |
26 | ## Method
27 |
28 | 
29 |
30 | PlaNet models the world as a compact sequence of hidden states. For planning,
31 | we first encode the history of past images into the current state. From there,
32 | we efficiently predict future rewards for multiple action sequences in latent
33 | space. We execute the first action of the best sequence found and replan after
34 | observing the next image.
35 |
36 | Find more information:
37 |
38 | - [Google AI Blog post][blog]
39 | - [Project website][website]
40 | - [PDF paper][paper]
41 |
42 | [blog]: https://ai.googleblog.com/2019/02/introducing-planet-deep-planning.html
43 | [website]: https://danijar.com/project/planet/
44 | [paper]: https://arxiv.org/pdf/1811.04551.pdf
45 |
46 | ## Instructions
47 |
48 | To train an agent, install the dependencies and then run:
49 |
50 | ```sh
51 | python3 -m planet.scripts.train --logdir /path/to/logdir --params '{tasks: [cheetah_run]}'
52 | ```
53 |
54 | The code prints `nan` as the score for iterations during which no summaries
55 | were computed.
56 |
57 | The available tasks are listed in `scripts/tasks.py`. The default parameters
58 | can be found in `scripts/configs.py`. To run the experiments from our
59 | paper, pass the following parameters to `--params {...}` in addition to the
60 | list of tasks:
61 |
62 | | Experiment | Parameters |
63 | | :--------- | :--------- |
64 | | PlaNet | No additional parameters. |
65 | | Random data collection | `planner_iterations: 0, train_action_noise: 1.0` |
66 | | Purely deterministic | `mean_only: True, divergence_scale: 0.0` |
67 | | Purely stochastic | `model: ssm` |
68 | | One agent all tasks | `collect_every: 30000` |
69 |
70 | Please note that the agent has seen some improvements so the results may be a
71 | bit different now.
72 |
73 | ## Modifications
74 |
75 | These are good places to start when modifying the code:
76 |
77 | | Directory | Description |
78 | | :-------- | :---------- |
79 | | `scripts/configs.py` | Add new parameters or change defaults. |
80 | | `scripts/tasks.py` | Add or modify environments. |
81 | | `models` | Add or modify latent transition models. |
82 | | `networks` | Add or modify encoder and decoder networks. |
83 |
84 | Tips for development:
85 |
86 | - You can set `--config debug` to reduce the episode length, batch size, and
87 |   collect data more frequently. This helps to quickly reach all parts of the
88 | code.
89 | - You can use `--num_runs 1000 --resume_runs False` to automatically start new
90 |   runs in sub directories of the logdir every time you execute the script.
91 | - Environments live in separate processes by default. Some environments work
92 | better when separated into threads instead by specifying `--params
93 | '{isolate_envs: thread}'`.
94 |
95 | ## Dependencies
96 |
97 | The code was tested under Ubuntu 18 and uses these packages:
98 |
99 | - tensorflow-gpu==1.13.1
100 | - tensorflow_probability==0.6.0
101 | - dm_control (`egl` [rendering option][dmc-rendering] recommended)
102 | - gym
103 | - scikit-image
104 | - scipy
105 | - ruamel.yaml
106 | - matplotlib
107 |
108 | [dmc-rendering]: https://github.com/deepmind/dm_control#rendering
109 |
110 | Disclaimer: This is not an official Google product.
111 |
--------------------------------------------------------------------------------
/planet/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | from . import control
20 | from . import models
21 | from . import networks
22 | from . import scripts
23 | from . import tools
24 | from . import training
25 |
--------------------------------------------------------------------------------
/planet/control/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | from . import planning
20 | from .batch_env import BatchEnv
21 | from .dummy_env import DummyEnv
22 | from .in_graph_batch_env import InGraphBatchEnv
23 | from .mpc_agent import MPCAgent
24 | from .random_episodes import random_episodes
25 | from .simulate import simulate
26 | from .temporal_difference import discounted_return
27 | from .temporal_difference import fixed_step_return
28 | from .temporal_difference import lambda_return
29 |
--------------------------------------------------------------------------------
/planet/control/batch_env.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import numpy as np
20 |
21 |
class BatchEnv(object):
  """Combine multiple environments to step them in batch."""

  def __init__(self, envs, blocking):
    """Combine multiple environments to step them in batch.

    To step environments in parallel, environments must support a
    `blocking=False` argument to their step and reset functions that makes them
    return callables instead to receive the result at a later time.

    Args:
      envs: List of environments.
      blocking: Step environments after another rather than in parallel.

    Raises:
      ValueError: Environments have different observation or action spaces.
    """
    self._envs = envs
    self._blocking = blocking
    observ_space = self._envs[0].observation_space
    if not all(env.observation_space == observ_space for env in self._envs):
      raise ValueError('All environments must use the same observation space.')
    action_space = self._envs[0].action_space
    if not all(env.action_space == action_space for env in self._envs):
      # Bug fix: this branch previously reused the observation space message.
      raise ValueError('All environments must use the same action space.')

  def __len__(self):
    """Number of combined environments."""
    return len(self._envs)

  def __getitem__(self, index):
    """Access an underlying environment by index."""
    return self._envs[index]

  def __getattr__(self, name):
    """Forward unimplemented attributes to one of the original environments.

    Args:
      name: Attribute that was accessed.

    Returns:
      Value behind the attribute name in the first wrapped environment.
    """
    return getattr(self._envs[0], name)

  def step(self, actions):
    """Forward a batch of actions to the wrapped environments.

    Args:
      actions: Batched action to apply to the environment.

    Raises:
      ValueError: Invalid actions.

    Returns:
      Batch of observations, rewards, and done flags.
    """
    # Validate every action up front so no environment is stepped if any
    # action in the batch is invalid.
    for index, (env, action) in enumerate(zip(self._envs, actions)):
      if not env.action_space.contains(action):
        message = 'Invalid action at index {}: {}'
        raise ValueError(message.format(index, action))
    if self._blocking:
      transitions = [
          env.step(action)
          for env, action in zip(self._envs, actions)]
    else:
      # Non-blocking environments return callables; resolving them after all
      # steps were issued lets the environments run concurrently.
      transitions = [
          env.step(action, blocking=False)
          for env, action in zip(self._envs, actions)]
      transitions = [transition() for transition in transitions]
    observs, rewards, dones, infos = zip(*transitions)
    observ = np.stack(observs)
    # Cast so environments returning float64 or int rewards stay uniform.
    reward = np.stack(rewards).astype(np.float32)
    done = np.stack(dones)
    info = tuple(infos)
    return observ, reward, done, info

  def reset(self, indices=None):
    """Reset the environment and convert the resulting observation.

    Args:
      indices: The batch indices of environments to reset; defaults to all.

    Returns:
      Batch of observations.
    """
    if indices is None:
      indices = np.arange(len(self._envs))
    if self._blocking:
      observs = [self._envs[index].reset() for index in indices]
    else:
      observs = [self._envs[index].reset(blocking=False) for index in indices]
      observs = [observ() for observ in observs]
    observ = np.stack(observs)
    return observ

  def close(self):
    """Send close messages to the external process and join them."""
    for env in self._envs:
      # Not all wrapped environments need cleanup; close is best-effort.
      if hasattr(env, 'close'):
        env.close()
123 |
--------------------------------------------------------------------------------
/planet/control/dummy_env.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import gym
20 | import numpy as np
21 |
22 |
class DummyEnv(object):
  """Placeholder environment producing random observations and rewards.

  Useful for exercising the training pipeline without a real simulator.
  Episodes terminate after exactly 1000 steps.
  """

  def __init__(self):
    # Fixed seed keeps the reward sequence reproducible across runs.
    self._random = np.random.RandomState(seed=0)
    self._step = None

  @property
  def observation_space(self):
    """Dict space holding a single 64x64x3 float image in [0, 1]."""
    zeros = np.zeros([64, 64, 3], dtype=np.float32)
    ones = np.ones([64, 64, 3], dtype=np.float32)
    return gym.spaces.Dict({'image': gym.spaces.Box(zeros, ones)})

  @property
  def action_space(self):
    """Continuous five-dimensional actions in [-1, 1]."""
    ones = np.ones([5], dtype=np.float32)
    return gym.spaces.Box(-ones, ones)

  def reset(self):
    """Begin a new episode and return an initial random observation."""
    self._step = 0
    return self.observation_space.sample()

  def step(self, action):
    """Ignore the action and emit a random transition."""
    observation = self.observation_space.sample()
    reward = self._random.uniform(0, 1)
    self._step += 1
    # Episodes have a fixed length of 1000 steps.
    return observation, reward, self._step >= 1000, {}
54 |
--------------------------------------------------------------------------------
/planet/control/in_graph_batch_env.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Batch of environments inside the TensorFlow graph."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import gym
22 | import numpy as np
23 | import tensorflow as tf
24 |
25 |
26 | class InGraphBatchEnv(object):
27 | """Batch of environments inside the TensorFlow graph.
28 |
29 | The batch of environments will be stepped and reset inside of the graph using
30 | a tf.py_func(). The current batch of observations, actions, rewards, and done
31 | flags are held in according variables.
32 | """
33 |
  def __init__(self, batch_env):
    """Batch of environments inside the TensorFlow graph.

    Creates non-trainable graph variables that cache the latest observation,
    action, reward, and done flag for every environment in the batch.

    Args:
      batch_env: Batch environment whose step/reset run in Python.
    """
    self._batch_env = batch_env
    # Leading batch dimension, sized by the number of wrapped environments.
    batch_dims = (len(self._batch_env),)
    # _parse_shape/_parse_dtype (presumably defined elsewhere in this class)
    # translate the gym spaces into TensorFlow shapes and dtypes.
    observ_shape = self._parse_shape(self._batch_env.observation_space)
    observ_dtype = self._parse_dtype(self._batch_env.observation_space)
    action_shape = self._parse_shape(self._batch_env.action_space)
    action_dtype = self._parse_dtype(self._batch_env.action_space)
    with tf.variable_scope('env_temporary'):
      self._observ = tf.get_variable(
          'observ', batch_dims + observ_shape, observ_dtype,
          tf.constant_initializer(0), trainable=False)
      self._action = tf.get_variable(
          'action', batch_dims + action_shape, action_dtype,
          tf.constant_initializer(0), trainable=False)
      self._reward = tf.get_variable(
          'reward', batch_dims, tf.float32,
          tf.constant_initializer(0), trainable=False)
      # This variable should be boolean, but tf.scatter_update() does not
      # support boolean resource variables yet.
      self._done = tf.get_variable(
          'done', batch_dims, tf.int32,
          tf.constant_initializer(False), trainable=False)
61 |
62 | def __getattr__(self, name):
63 | """Forward unimplemented attributes to one of the original environments.
64 |
65 | Args:
66 | name: Attribute that was accessed.
67 |
68 | Returns:
69 | Value behind the attribute name in one of the original environments.
70 | """
71 | return getattr(self._batch_env, name)
72 |
73 | def __len__(self):
74 | """Number of combined environments."""
75 | return len(self._batch_env)
76 |
77 | def __getitem__(self, index):
78 | """Access an underlying environment by index."""
79 | return self._batch_env[index]
80 |
  def step(self, action):
    """Step the batch of environments.

    The results of the step can be accessed from the variables defined below.

    Args:
      action: Tensor holding the batch of actions to apply.

    Returns:
      Operation.
    """
    with tf.name_scope('environment/simulate'):
      observ_dtype = self._parse_dtype(self._batch_env.observation_space)
      # tf.py_func runs the Python step outside the graph. Only the first
      # three results (observ, reward, done) are kept; info dicts are dropped.
      observ, reward, done = tf.py_func(
          lambda a: self._batch_env.step(a)[:3], [action],
          [observ_dtype, tf.float32, tf.bool], name='step')
      # Store the transition in the state variables; done is cast to int32
      # because the variable cannot be boolean (see __init__).
      return tf.group(
          self._observ.assign(observ),
          self._action.assign(action),
          self._reward.assign(reward),
          self._done.assign(tf.to_int32(done)))
103 |
104 | def reset(self, indices=None):
105 | """Reset the batch of environments.
106 |
107 | Args:
108 | indices: The batch indices of the environments to reset; defaults to all.
109 |
110 | Returns:
111 | Batch tensor of the new observations.
112 | """
113 | if indices is None:
114 | indices = tf.range(len(self._batch_env))
115 | observ_dtype = self._parse_dtype(self._batch_env.observation_space)
116 | observ = tf.py_func(
117 | self._batch_env.reset, [indices], observ_dtype, name='reset')
118 | reward = tf.zeros_like(indices, tf.float32)
119 | done = tf.zeros_like(indices, tf.int32)
120 | with tf.control_dependencies([
121 | tf.scatter_update(self._observ, indices, observ),
122 | tf.scatter_update(self._reward, indices, reward),
123 | tf.scatter_update(self._done, indices, tf.to_int32(done))]):
124 | return tf.identity(observ)
125 |
126 | @property
127 | def observ(self):
128 | """Access the variable holding the current observation."""
129 | return self._observ + 0
130 |
131 | @property
132 | def action(self):
133 | """Access the variable holding the last received action."""
134 | return self._action + 0
135 |
136 | @property
137 | def reward(self):
138 | """Access the variable holding the current reward."""
139 | return self._reward + 0
140 |
141 | @property
142 | def done(self):
143 | """Access the variable indicating whether the episode is done."""
144 | return tf.cast(self._done, tf.bool)
145 |
146 | def close(self):
147 | """Send close messages to the external process and join them."""
148 | self._batch_env.close()
149 |
150 | def _parse_shape(self, space):
151 | """Get a tensor shape from a OpenAI Gym space.
152 |
153 | Args:
154 | space: Gym space.
155 |
156 | Raises:
157 | NotImplementedError: For spaces other than Box and Discrete.
158 |
159 | Returns:
160 | Shape tuple.
161 | """
162 | if isinstance(space, gym.spaces.Discrete):
163 | return ()
164 | if isinstance(space, gym.spaces.Box):
165 | return space.shape
166 | raise NotImplementedError("Unsupported space '{}.'".format(space))
167 |
168 | def _parse_dtype(self, space):
169 | """Get a tensor dtype from a OpenAI Gym space.
170 |
171 | Args:
172 | space: Gym space.
173 |
174 | Raises:
175 | NotImplementedError: For spaces other than Box and Discrete.
176 |
177 | Returns:
178 | TensorFlow data type.
179 | """
180 | if isinstance(space, gym.spaces.Discrete):
181 | return tf.int32
182 | if isinstance(space, gym.spaces.Box):
183 | if space.low.dtype == np.uint8:
184 | return tf.uint8
185 | else:
186 | return tf.float32
187 | raise NotImplementedError()
188 |
--------------------------------------------------------------------------------
/planet/control/mpc_agent.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | from tensorflow_probability import distributions as tfd
20 | import tensorflow as tf
21 |
22 | from planet.tools import nested
23 |
24 |
class MPCAgent(object):
  """Model-predictive control agent that plans actions using a latent model.

  The agent keeps its latent model state and the previously applied action
  in non-trainable local variables, so consecutive calls to perform() can
  filter the state forward across session runs.
  """

  def __init__(self, batch_env, step, is_training, should_log, config):
    """Create an MPC agent for a batch of environments.

    Args:
      batch_env: In-graph batch environment to act in.
      step: Trainer step tensor (not the environment step), used for
        exploration schedules.
      is_training: Stored but not read by the methods of this class.
      should_log: Stored but not read by the methods of this class.
      config: Object providing cell, planner, objective, encoder,
        preprocess_fn, and exploration settings.
    """
    self._batch_env = batch_env
    self._step = step  # Trainer step, not environment step.
    self._is_training = is_training
    self._should_log = should_log
    self._config = config
    self._cell = config.cell
    state = self._cell.zero_state(len(batch_env), tf.float32)
    # Create a non-trainable local variable shaped like each state entry so
    # the latent state persists between session runs.
    var_like = lambda x: tf.get_local_variable(
        x.name.split(':')[0].replace('/', '_') + '_var',
        shape=x.shape,
        initializer=lambda *_, **__: tf.zeros_like(x), use_resource=True)
    self._state = nested.map(var_like, state)
    self._prev_action = tf.get_local_variable(
        'prev_action_var', shape=self._batch_env.action.shape,
        initializer=lambda *_, **__: tf.zeros_like(self._batch_env.action),
        use_resource=True)

  def begin_episode(self, agent_indices):
    """Reset latent state and previous action for the given agents.

    Args:
      agent_indices: 1-D tensor of batch indices that start a new episode.

    Returns:
      Empty summary string, emitted after the resets have been applied.
    """
    state = nested.map(
        lambda tensor: tf.gather(tensor, agent_indices),
        self._state)
    # Zero out only the rows of the state that belong to the given agents.
    reset_state = nested.map(
        lambda var, val: tf.scatter_update(var, agent_indices, 0 * val),
        self._state, state, flatten=True)
    # NOTE(review): this zeroes the previous action of ALL agents, not only
    # those in agent_indices — confirm this is intended.
    reset_prev_action = self._prev_action.assign(
        tf.zeros_like(self._prev_action))
    with tf.control_dependencies(reset_state + (reset_prev_action,)):
      return tf.constant('')

  def perform(self, agent_indices, observ):
    """Compute actions for a subset of agents given their observations.

    Filters the latent state forward with the current observation, plans an
    action sequence, and applies optional exploration noise.

    Args:
      agent_indices: 1-D tensor of batch indices the observations belong to.
      observ: Batch of raw observations for those agents.

    Returns:
      Tuple of chosen actions and an empty summary string.
    """
    observ = self._config.preprocess_fn(observ)
    # Add and strip a time axis of length one around the encoder call.
    embedded = self._config.encoder({'image': observ[:, None]})[:, 0]
    state = nested.map(
        lambda tensor: tf.gather(tensor, agent_indices),
        self._state)
    # Adding zero forces a read of the variable before it is overwritten.
    prev_action = self._prev_action + 0
    with tf.control_dependencies([prev_action]):
      # All-true mask: the posterior may use the observation here.
      use_obs = tf.ones(tf.shape(agent_indices), tf.bool)[:, None]
      _, state = self._cell((embedded, prev_action, use_obs), state)
    action = self._config.planner(
        self._cell, self._config.objective, state,
        embedded.shape[1:].as_list(),
        prev_action.shape[1:].as_list())
    # Execute only the first action of the planned sequence.
    action = action[:, 0]
    if self._config.exploration:
      scale = self._config.exploration.scale
      if self._config.exploration.schedule:
        scale *= self._config.exploration.schedule(self._step)
      # Add Gaussian exploration noise around the planned action.
      action = tfd.Normal(action, scale).sample()
    action = tf.clip_by_value(action, -1, 1)
    remember_action = self._prev_action.assign(action)
    remember_state = nested.map(
        lambda var, val: tf.scatter_update(var, agent_indices, val),
        self._state, state, flatten=True)
    with tf.control_dependencies(remember_state + (remember_action,)):
      return tf.identity(action), tf.constant('')

  def experience(self, agent_indices, *experience):
    """Experience is not used by this agent; returns an empty summary."""
    return tf.constant('')

  def end_episode(self, agent_indices):
    """No cleanup is needed at episode end; returns an empty summary."""
    return tf.constant('')
90 |
--------------------------------------------------------------------------------
/planet/control/planning.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 |
21 | from planet import tools
22 |
23 |
def cross_entropy_method(
    cell, objective_fn, state, obs_shape, action_shape, horizon, graph,
    amount=1000, topk=100, iterations=10, min_action=-1, max_action=1):
  """Plan an action sequence with the cross entropy method (CEM).

  Iteratively refits a diagonal Gaussian belief over action sequences to the
  top-k proposals as scored by the objective function.

  Args:
    cell: Latent dynamics model cell used to predict forward.
    objective_fn: Function mapping predicted states to a score per proposal.
    state: Nested current latent state, batched over agents.
    obs_shape: Shape of a single observation (without batch and time).
    action_shape: Shape of a single action (without batch and time).
    horizon: Number of time steps to plan ahead.
    graph: Not referenced in this body; kept for interface compatibility.
    amount: Number of action sequence proposals per iteration.
    topk: Number of best proposals the belief is refit to.
    iterations: Number of belief refinement iterations.
    min_action: Lower bound for clipping sampled actions.
    max_action: Upper bound for clipping sampled actions.

  Returns:
    Mean action sequence of the final belief, shaped
    (batch, horizon) + action_shape.
  """
  obs_shape, action_shape = tuple(obs_shape), tuple(action_shape)
  original_batch = tools.shape(tools.nested.flatten(state)[0])[0]
  # Evaluate all proposals as one large batch by tiling the current state.
  # NOTE(review): tf.tile repeats the whole batch, while the action reshape
  # below groups proposals per batch element; verify the pairing for
  # original_batch > 1.
  initial_state = tools.nested.map(lambda tensor: tf.tile(
      tensor, [amount] + [1] * (tensor.shape.ndims - 1)), state)
  extended_batch = tools.shape(tools.nested.flatten(initial_state)[0])[0]
  # Planning never feeds real observations into the model.
  use_obs = tf.zeros([extended_batch, horizon, 1], tf.bool)
  obs = tf.zeros((extended_batch, horizon) + obs_shape)

  def iteration(mean_and_stddev, _):
    # One CEM iteration: sample, evaluate, refit the belief.
    mean, stddev = mean_and_stddev
    # Sample action proposals from belief.
    normal = tf.random_normal((original_batch, amount, horizon) + action_shape)
    action = normal * stddev[:, None] + mean[:, None]
    action = tf.clip_by_value(action, min_action, max_action)
    # Evaluate proposal actions.
    action = tf.reshape(
        action, (extended_batch, horizon) + action_shape)
    (_, state), _ = tf.nn.dynamic_rnn(
        cell, (0 * obs, action, use_obs), initial_state=initial_state)
    return_ = objective_fn(state)
    return_ = tf.reshape(return_, (original_batch, amount))
    # Re-fit belief to the best ones.
    _, indices = tf.nn.top_k(return_, topk, sorted=False)
    # Offset per-batch proposal indices into the flattened proposal batch.
    indices += tf.range(original_batch)[:, None] * amount
    best_actions = tf.gather(action, indices)
    mean, variance = tf.nn.moments(best_actions, 1)
    stddev = tf.sqrt(variance + 1e-6)
    return mean, stddev

  # Start from a standard normal belief over action sequences.
  mean = tf.zeros((original_batch, horizon) + action_shape)
  stddev = tf.ones((original_batch, horizon) + action_shape)
  if iterations < 1:
    return mean
  mean, stddev = tf.scan(
      iteration, tf.range(iterations), (mean, stddev), back_prop=False)
  mean, stddev = mean[-1], stddev[-1]  # Select belief at last iterations.
  return mean
64 |
--------------------------------------------------------------------------------
/planet/control/random_episodes.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | from planet.control import wrappers
20 |
21 |
def random_episodes(env_ctor, num_episodes, outdir=None):
  """Collect episodes by acting with a uniform random policy.

  Args:
    env_ctor: Callable creating the environment.
    num_episodes: Number of episodes to simulate.
    outdir: Directory the wrapper stores episodes in; if None, episodes are
      collected in memory and returned instead.

  Returns:
    List of collected episodes when outdir is None, otherwise None (the
    episodes are written to disk by the wrapper).
  """
  env = env_ctor()
  env = wrappers.CollectGymDataset(env, outdir)
  # Bug fix: the in-memory list is needed exactly when no output directory
  # is given. The previous condition was inverted, so calling with
  # outdir=None appended to None and crashed.
  episodes = [] if outdir is None else None
  # The policy is loop-invariant, so create it once.
  policy = lambda env, obs: env.action_space.sample()
  for _ in range(num_episodes):
    done = False
    obs = env.reset()
    while not done:
      action = policy(env, obs)
      obs, _, done, info = env.step(action)
      if outdir is None:
        episodes.append(info['episode'])
  try:
    # Not every environment implements close().
    env.close()
  except AttributeError:
    pass
  return episodes
40 |
--------------------------------------------------------------------------------
/planet/control/simulate.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """In-graph simulation step of a vectorized algorithm with environments."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import functools
22 |
23 | import tensorflow as tf
24 |
25 | from planet import tools
26 | from planet.control import batch_env
27 | from planet.control import in_graph_batch_env
28 | from planet.control import mpc_agent
29 | from planet.control import wrappers
30 | from planet.tools import streaming_mean
31 |
32 |
def simulate(
    step, env_ctor, duration, num_agents, agent_config,
    isolate_envs='none', expensive_summaries=False,
    gif_summary=True, name='simulate'):
  """Collect in-graph rollouts and build summaries describing them.

  Args:
    step: Trainer step tensor, forwarded to the agent.
    env_ctor: Constructor for the environments.
    duration: Number of time steps to simulate.
    num_agents: Number of environments simulated in parallel.
    agent_config: Configuration object for the MPC agent.
    isolate_envs: Environment isolation: 'none', 'thread', or 'process'.
    expensive_summaries: Whether to add histogram and image summaries.
    gif_summary: Whether to add an animated GIF summary.
    name: Variable scope name for the created operations.

  Returns:
    Tuple of merged summary tensor, mean return, and a cleanup function.
  """
  with tf.variable_scope(name):
    return_, image, action, reward, cleanup = collect_rollouts(
        step=step,
        env_ctor=env_ctor,
        duration=duration,
        num_agents=num_agents,
        agent_config=agent_config,
        isolate_envs=isolate_envs)
    return_mean = tf.reduce_mean(return_)
    summaries = [tf.summary.scalar('return', return_mean)]
    if expensive_summaries:
      summaries += [
          tf.summary.histogram('return_hist', return_),
          tf.summary.histogram('reward_hist', reward),
          tf.summary.histogram('action_hist', action),
          tools.image_strip_summary('image', image, max_length=duration),
      ]
    if gif_summary:
      summaries.append(tools.gif_summary(
          'animation', image, max_outputs=1, fps=20))
    summary = tf.summary.merge(summaries)
  return summary, return_mean, cleanup
59 |
60 |
def collect_rollouts(
    step, env_ctor, duration, num_agents, agent_config, isolate_envs):
  """Simulate fixed-length rollouts of an MPC agent inside the graph.

  Args:
    step: Trainer step tensor, forwarded to the agent.
    env_ctor: Constructor for the environments.
    duration: Number of time steps to simulate.
    num_agents: Number of environments simulated in parallel.
    agent_config: Configuration object for the MPC agent.
    isolate_envs: Environment isolation: 'none', 'thread', or 'process'.

  Returns:
    Tuple of scores of finished episodes, batch-major image, action, and
    reward tensors, and a cleanup function that closes the environments.
  """
  batch_env = define_batch_env(env_ctor, num_agents, isolate_envs)
  agent = mpc_agent.MPCAgent(batch_env, step, False, False, agent_config)
  cleanup = lambda: batch_env.close()

  def simulate_fn(unused_last, step):
    # One environment step; all environments reset on the first iteration.
    done, score, unused_summary = simulate_step(
        batch_env, agent,
        log=False,
        reset=tf.equal(step, 0))
    with tf.control_dependencies([done, score]):
      # Read the environment variables only after the step has completed.
      image = batch_env.observ
      batch_action = batch_env.action
      batch_reward = batch_env.reward
    return done, score, image, batch_action, batch_reward

  initializer = (
      tf.zeros([num_agents], tf.bool),
      tf.zeros([num_agents], tf.float32),
      0 * batch_env.observ,
      0 * batch_env.action,
      tf.zeros([num_agents], tf.float32))
  # parallel_iterations=1 keeps the stateful environment steps sequential.
  done, score, image, action, reward = tf.scan(
      simulate_fn, tf.range(duration),
      initializer, parallel_iterations=1)
  # Keep only scores at time steps where an episode actually finished.
  score = tf.boolean_mask(score, done)
  # Transpose time-major scan outputs into batch-major layout.
  image = tf.transpose(image, [1, 0, 2, 3, 4])
  action = tf.transpose(action, [1, 0, 2])
  reward = tf.transpose(reward)
  return score, image, action, reward, cleanup
92 |
93 |
def define_batch_env(env_ctor, num_agents, isolate_envs):
  """Create a batch of environments, optionally isolated from each other.

  Args:
    env_ctor: Constructor for the environments.
    num_agents: Number of environments to create.
    isolate_envs: One of 'none', 'thread', or 'process'.

  Raises:
    NotImplementedError: For unknown isolation strategies.

  Returns:
    In-graph batch environment wrapping the created environments.
  """
  with tf.variable_scope('environments'):
    if isolate_envs == 'none':
      blocking = True
      factory = lambda ctor: ctor()
    elif isolate_envs == 'thread':
      blocking = False
      factory = functools.partial(wrappers.Async, strategy='thread')
    elif isolate_envs == 'process':
      blocking = False
      factory = functools.partial(wrappers.Async, strategy='process')
    else:
      raise NotImplementedError(isolate_envs)
    envs = [factory(env_ctor) for _ in range(num_agents)]
    return in_graph_batch_env.InGraphBatchEnv(
        batch_env.BatchEnv(envs, blocking))
111 |
112 |
def simulate_step(batch_env, algo, log=True, reset=False):
  """Simulation step of a vectorized algorithm with in-graph environments.

  Integrates the operations implemented by the algorithm and the environments
  into a combined operation.

  Args:
    batch_env: In-graph batch environment.
    algo: Algorithm instance implementing required operations.
    log: Tensor indicating whether to compute and return summaries.
    reset: Tensor causing all environments to reset.

  Returns:
    Tuple of tensors containing done flags for the current episodes, possibly
    intermediate scores for the episodes, and a summary tensor.
  """

  def _define_begin_episode(agent_indices):
    """Reset environments, intermediate scores and durations for new episodes.

    Args:
      agent_indices: Tensor containing batch indices starting an episode.

    Returns:
      Summary tensor, new score tensor, and new length tensor.
    """
    assert agent_indices.shape.ndims == 1
    zero_scores = tf.zeros_like(agent_indices, tf.float32)
    zero_durations = tf.zeros_like(agent_indices)
    update_score = tf.scatter_update(score_var, agent_indices, zero_scores)
    update_length = tf.scatter_update(
        length_var, agent_indices, zero_durations)
    reset_ops = [
        batch_env.reset(agent_indices), update_score, update_length]
    # Notify the algorithm only after environments and counters were reset.
    with tf.control_dependencies(reset_ops):
      return algo.begin_episode(agent_indices), update_score, update_length

  def _define_step():
    """Request actions from the algorithm and apply them to the environments.

    Increments the lengths of all episodes and increases their scores by the
    current reward. After stepping the environments, provides the full
    transition tuple to the algorithm.

    Returns:
      Summary tensor, new score tensor, and new length tensor.
    """
    prevob = batch_env.observ + 0  # Ensure a copy of the variable value.
    agent_indices = tf.range(len(batch_env))
    action, step_summary = algo.perform(agent_indices, prevob)
    action.set_shape(batch_env.action.shape)
    with tf.control_dependencies([batch_env.step(action)]):
      add_score = score_var.assign_add(batch_env.reward)
      inc_length = length_var.assign_add(tf.ones(len(batch_env), tf.int32))
    # Hand the full transition to the algorithm after the counters updated.
    with tf.control_dependencies([add_score, inc_length]):
      agent_indices = tf.range(len(batch_env))
      experience_summary = algo.experience(
          agent_indices, prevob,
          batch_env.action,
          batch_env.reward,
          batch_env.done,
          batch_env.observ)
    summary = tf.summary.merge([step_summary, experience_summary])
    return summary, add_score, inc_length

  def _define_end_episode(agent_indices):
    """Notify the algorithm of ending episodes.

    Also updates the mean score and length counters used for summaries.

    Args:
      agent_indices: Tensor holding batch indices that end their episodes.

    Returns:
      Summary tensor.
    """
    assert agent_indices.shape.ndims == 1
    submit_score = mean_score.submit(tf.gather(score, agent_indices))
    submit_length = mean_length.submit(
        tf.cast(tf.gather(length, agent_indices), tf.float32))
    with tf.control_dependencies([submit_score, submit_length]):
      return algo.end_episode(agent_indices)

  def _define_summaries():
    """Reset the average score and duration, and return them as summary.

    Returns:
      Summary string.
    """
    # Only emit the summaries when logging is on and data was submitted;
    # clearing the streaming means also resets them for the next episode.
    score_summary = tf.cond(
        tf.logical_and(log, tf.cast(mean_score.count, tf.bool)),
        lambda: tf.summary.scalar('mean_score', mean_score.clear()), str)
    length_summary = tf.cond(
        tf.logical_and(log, tf.cast(mean_length.count, tf.bool)),
        lambda: tf.summary.scalar('mean_length', mean_length.clear()), str)
    return tf.summary.merge([score_summary, length_summary])

  with tf.name_scope('simulate'):
    log = tf.convert_to_tensor(log)
    reset = tf.convert_to_tensor(reset)
    with tf.variable_scope('simulate_temporary'):
      # Local variables hold per-episode accumulators across session runs.
      score_var = tf.get_variable(
          'score', (len(batch_env),), tf.float32,
          tf.constant_initializer(0),
          trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
      length_var = tf.get_variable(
          'length', (len(batch_env),), tf.int32,
          tf.constant_initializer(0),
          trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
    mean_score = streaming_mean.StreamingMean((), tf.float32, 'mean_score')
    mean_length = streaming_mean.StreamingMean((), tf.float32, 'mean_length')
    # On reset, begin new episodes everywhere; otherwise only where episodes
    # finished during the previous step.
    agent_indices = tf.cond(
        reset,
        lambda: tf.range(len(batch_env)),
        lambda: tf.cast(tf.where(batch_env.done)[:, 0], tf.int32))
    # str() evaluates to '', serving as an empty summary in the false case.
    begin_episode, score, length = tf.cond(
        tf.cast(tf.shape(agent_indices)[0], tf.bool),
        lambda: _define_begin_episode(agent_indices),
        lambda: (str(), score_var, length_var))
    with tf.control_dependencies([begin_episode]):
      step, score, length = _define_step()
    with tf.control_dependencies([step]):
      agent_indices = tf.cast(tf.where(batch_env.done)[:, 0], tf.int32)
      end_episode = tf.cond(
          tf.cast(tf.shape(agent_indices)[0], tf.bool),
          lambda: _define_end_episode(agent_indices), str)
    with tf.control_dependencies([end_episode]):
      summary = tf.summary.merge([
          _define_summaries(), begin_episode, step, end_episode])
    with tf.control_dependencies([summary]):
      # Adding zero forces the score to be read after the summary ops ran.
      score = 0.0 + score
      done = batch_env.done
    return done, score, summary
246 |
--------------------------------------------------------------------------------
/planet/control/temporal_difference.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Copmute discounted return."""
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import tensorflow as tf
21 |
22 |
def discounted_return(reward, discount, bootstrap, axis, stop_gradient=True):
  """Discounted Monte Carlo return.

  Args:
    reward: Tensor of rewards; `axis` is the time dimension.
    discount: Scalar discount factor.
    bootstrap: Value estimate appended after the last step, or None.
    axis: Non-negative dimension to aggregate over.
    stop_gradient: Whether to stop gradients through the computation.

  Returns:
    Tensor of returns. NOTE(review): the discount == 1 shortcuts reduce over
    the axis, while the general case keeps per-step returns of the same
    shape as reward — confirm callers expect this.
  """
  if discount == 1 and bootstrap is None:
    return tf.reduce_sum(reward, axis)
  if discount == 1:
    return tf.reduce_sum(reward, axis) + bootstrap
  # Bring the aggregation dimension front by swapping it with dimension 0.
  # The previous index arithmetic produced an invalid permutation with a
  # duplicated entry for axis == 0; an explicit swap handles all axes and is
  # identical for axis >= 1.
  dims = list(range(reward.shape.ndims))
  dims[0], dims[axis] = dims[axis], dims[0]
  reward = tf.transpose(reward, dims)
  if bootstrap is None:
    bootstrap = tf.zeros_like(reward[-1])
  # Backward recursion: return[t] = reward[t] + discount * return[t + 1].
  return_ = tf.scan(
      fn=lambda agg, cur: cur + discount * agg,
      elems=reward,
      initializer=bootstrap,
      back_prop=not stop_gradient,
      reverse=True)
  # The swap permutation is its own inverse, so it restores the layout.
  return_ = tf.transpose(return_, dims)
  if stop_gradient:
    return_ = tf.stop_gradient(return_)
  return return_
45 |
46 |
def lambda_return(
    reward, value, bootstrap, discount, lambda_, axis, stop_gradient=True):
  """Average of different multi-step returns (TD-lambda).

  Setting lambda=1 gives a discounted Monte Carlo return.
  Setting lambda=0 gives a fixed 1-step return.

  Args:
    reward: Tensor of rewards; `axis` is the time dimension.
    value: Tensor of state values, same rank as reward.
    bootstrap: Value estimate after the last step, or None for zeros.
    discount: Scalar discount factor.
    lambda_: Scalar mixing factor between the multi-step returns.
    axis: Non-negative dimension to aggregate over.
    stop_gradient: Whether to stop gradients through the computation.

  Returns:
    Tensor of lambda returns with the same shape as reward.
  """
  assert reward.shape.ndims == value.shape.ndims, (reward.shape, value.shape)
  # Bring the aggregation dimension front by swapping it with dimension 0.
  # The previous index arithmetic produced an invalid permutation with a
  # duplicated entry for axis == 0; an explicit swap handles all axes and is
  # identical for axis >= 1.
  dims = list(range(reward.shape.ndims))
  dims[0], dims[axis] = dims[axis], dims[0]
  reward = tf.transpose(reward, dims)
  value = tf.transpose(value, dims)
  if bootstrap is None:
    bootstrap = tf.zeros_like(value[-1])
  next_values = tf.concat([value[1:], bootstrap[None]], 0)
  # One-step targets are mixed into the recursion with weight (1 - lambda_).
  inputs = reward + discount * next_values * (1 - lambda_)
  return_ = tf.scan(
      fn=lambda agg, cur: cur + discount * lambda_ * agg,
      elems=inputs,
      initializer=bootstrap,
      back_prop=not stop_gradient,
      reverse=True)
  # The swap permutation is its own inverse, so it restores the layout.
  return_ = tf.transpose(return_, dims)
  if stop_gradient:
    return_ = tf.stop_gradient(return_)
  return return_
74 |
75 |
def fixed_step_return(
    reward, value, discount, steps, axis, stop_gradient=True):
  """Discounted N-step returns for fixed-length sequences.

  Args:
    reward: Tensor of rewards; `axis` is the time dimension.
    value: Tensor of state values to bootstrap from, or None.
    discount: Scalar discount factor.
    steps: Number of reward steps summed into each return.
    axis: Non-negative dimension to aggregate over.
    stop_gradient: Whether to stop gradients through the computation.

  Returns:
    Tensor of N-step returns; the aggregation dimension is shorter than the
    input by `steps` entries.
  """
  # Bring the aggregation dimension front by swapping it with dimension 0.
  # The previous index arithmetic produced an invalid permutation with a
  # duplicated entry for axis == 0; an explicit swap handles all axes and is
  # identical for axis >= 1.
  dims = list(range(reward.shape.ndims))
  dims[0], dims[axis] = dims[axis], dims[0]
  reward = tf.transpose(reward, dims)
  length = tf.shape(reward)[0]
  # Accumulate shifted reward slices: sum_i discount**i * reward[t + i].
  _, return_ = tf.while_loop(
      cond=lambda i, p: i < steps + 1,
      body=lambda i, p: (i + 1, reward[steps - i: length - i] + discount * p),
      loop_vars=[tf.constant(1), tf.zeros_like(reward[steps:])],
      back_prop=not stop_gradient)
  if value is not None:
    # Bootstrap from the value estimate N steps ahead.
    return_ += discount ** steps * tf.transpose(value, dims)[steps:]
  # The swap permutation is its own inverse, so it restores the layout.
  return_ = tf.transpose(return_, dims)
  if stop_gradient:
    return_ = tf.stop_gradient(return_)
  return return_
95 |
--------------------------------------------------------------------------------
/planet/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | from .base import Base
20 | from .drnn import DRNN
21 | from .rssm import RSSM
22 | from .ssm import SSM
23 |
--------------------------------------------------------------------------------
/planet/models/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 |
21 | from planet import tools
22 |
23 |
class Base(tf.nn.rnn_cell.RNNCell):
  """Base class for latent state-space models implemented as RNN cells.

  Subclasses provide templates computing the transition (prior) and the
  posterior belief; this class wires them into the RNNCell interface so the
  model can be unrolled with tf.nn.dynamic_rnn.
  """

  def __init__(self, transition_tpl, posterior_tpl, reuse=None):
    """Create the cell from transition and posterior templates.

    Args:
      transition_tpl: Template computing the prior from previous state,
        previous action, and a zeroed observation.
      posterior_tpl: Template computing the posterior from previous state,
        previous action, and the current observation.
      reuse: Whether to reuse variables of the underlying cell scope.
    """
    super(Base, self).__init__(_reuse=reuse)
    self._posterior_tpl = posterior_tpl
    self._transition_tpl = transition_tpl
    # When enabled, call() asserts that use_obs is constant over the batch.
    self._debug = False

  @property
  def state_size(self):
    # Must be provided by subclasses.
    raise NotImplementedError

  @property
  def updates(self):
    """Extra update operations of the model; none by default."""
    return []

  @property
  def losses(self):
    """Extra loss terms of the model; none by default."""
    return []

  @property
  def output_size(self):
    """The cell outputs a (prior, posterior) pair of states."""
    return (self.state_size, self.state_size)

  def zero_state(self, batch_size, dtype):
    """Nested all-zeros initial state matching state_size."""
    return tools.nested.map(
        lambda size: tf.zeros([batch_size, size], dtype),
        self.state_size)

  def call(self, inputs, prev_state):
    """Compute prior and posterior beliefs for one time step.

    Args:
      inputs: Tuple of (observation, previous action, use_obs flag); the
        flag selects whether the posterior may use the observation.
      prev_state: Previous nested model state.

    Returns:
      Tuple of outputs (prior, posterior) and the posterior as new state.
    """
    obs, prev_action, use_obs = inputs
    if self._debug:
      with tf.control_dependencies([tf.assert_equal(use_obs, use_obs[0, 0])]):
        use_obs = tf.identity(use_obs)
    # The flag is reduced to a single scalar; the debug assert above checks
    # that it is indeed constant across the batch.
    use_obs = use_obs[0, 0]
    zero_obs = tools.nested.map(tf.zeros_like, obs)
    # The prior predicts forward without access to the observation.
    prior = self._transition_tpl(prev_state, prev_action, zero_obs)
    posterior = tf.cond(
        use_obs,
        lambda: self._posterior_tpl(prev_state, prev_action, obs),
        lambda: prior)
    return (prior, posterior), posterior
66 |
--------------------------------------------------------------------------------
/planet/models/drnn.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 | from tensorflow_probability import distributions as tfd
21 |
22 | from planet import tools
23 | from planet.models import base
24 |
25 |
26 | class DRNN(base.Base):
27 | r"""Doubly recurrent state-space model.
28 |
29 | Prior: Posterior:
30 |
31 | (a) (a) (a,o) (a,o)
32 | | | : : : :
33 | v v v v v v
34 | [e]--->[e] [e]...>[e]
35 | | | : :
36 | v v v v
37 | (s)--->(s) (s)--->(s)
38 | | | | |
39 | v v v v
40 | [d]--->[d] [d]--->[d]
41 | | | | |
42 | v v v v
43 | (o) (o) (o) (o)
44 | """
45 |
  def __init__(
      self, state_size, belief_size, embed_size,
      mean_only=False, min_stddev=1e-1, activation=tf.nn.elu,
      encoder_to_decoder=False, sample_to_sample=True,
      sample_to_encoder=True, decoder_to_encoder=False,
      decoder_to_sample=True, action_to_decoder=False):
    """Create a doubly recurrent state-space model.

    Args:
      state_size: Size of the stochastic latent state.
      belief_size: Size of the GRU encoder and decoder states.
      embed_size: Size of the hidden layer of the sampling network.
      mean_only: Use the mean instead of sampling the latent state.
      min_stddev: Lower bound added to the predicted standard deviation.
      activation: Not referenced in this body; the sampling layer below
        hard-codes relu. NOTE(review): confirm whether this is intended.
      encoder_to_decoder: Feed the encoder state into the decoder.
      sample_to_sample: Condition the sample on the previous sample.
      sample_to_encoder: Feed the previous sample into the encoder.
      decoder_to_encoder: Feed the decoder state into the encoder.
      decoder_to_sample: Condition the sample on the decoder state.
      action_to_decoder: Feed the previous action into the decoder.
    """
    self._state_size = state_size
    self._belief_size = belief_size
    self._embed_size = embed_size
    self._encoder_cell = tf.contrib.rnn.GRUBlockCell(self._belief_size)
    self._decoder_cell = tf.contrib.rnn.GRUBlockCell(self._belief_size)
    self._kwargs = dict(units=self._embed_size, activation=tf.nn.relu)
    self._mean_only = mean_only
    self._min_stddev = min_stddev
    self._encoder_to_decoder = encoder_to_decoder
    self._sample_to_sample = sample_to_sample
    self._sample_to_encoder = sample_to_encoder
    self._decoder_to_encoder = decoder_to_encoder
    self._decoder_to_sample = decoder_to_sample
    self._action_to_decoder = action_to_decoder
    # The same template serves as transition and posterior; the base class
    # feeds zero observations when computing the prior.
    posterior_tpl = tf.make_template('posterior', self._posterior)
    super(DRNN, self).__init__(posterior_tpl, posterior_tpl)
68 |
69 | @property
70 | def state_size(self):
71 | return {
72 | 'encoder_state': self._encoder_cell.state_size,
73 | 'decoder_state': self._decoder_cell.state_size,
74 | 'mean': self._state_size,
75 | 'stddev': self._state_size,
76 | 'sample': self._state_size,
77 | }
78 |
79 | def dist_from_state(self, state, mask=None):
80 | """Extract the latent distribution from a prior or posterior state."""
81 | if mask is not None:
82 | stddev = tools.mask(state['stddev'], mask, value=1)
83 | else:
84 | stddev = state['stddev']
85 | dist = tfd.MultivariateNormalDiag(state['mean'], stddev)
86 | return dist
87 |
88 | def features_from_state(self, state):
89 | """Extract features for the decoder network from a prior or posterior."""
90 | return state['decoder_state']
91 |
92 | def divergence_from_states(self, lhs, rhs, mask=None):
93 | """Compute the divergence measure between two states."""
94 | lhs = self.dist_from_state(lhs, mask)
95 | rhs = self.dist_from_state(rhs, mask)
96 | divergence = tfd.kl_divergence(lhs, rhs)
97 | if mask is not None:
98 | divergence = tools.mask(divergence, mask)
99 | return divergence
100 |
101 | def _posterior(self, prev_state, prev_action, obs):
102 | """Compute posterior state from previous state and current observation."""
103 |
104 | # Recurrent encoder.
105 | encoder_inputs = [obs, prev_action]
106 | if self._sample_to_encoder:
107 | encoder_inputs.append(prev_state['sample'])
108 | if self._decoder_to_encoder:
109 | encoder_inputs.append(prev_state['decoder_state'])
110 | encoded, encoder_state = self._encoder_cell(
111 | tf.concat(encoder_inputs, -1), prev_state['encoder_state'])
112 |
113 | # Sample sequence.
114 | sample_inputs = [encoded]
115 | if self._sample_to_sample:
116 | sample_inputs.append(prev_state['sample'])
117 | if self._decoder_to_sample:
118 | sample_inputs.append(prev_state['decoder_state'])
119 | hidden = tf.layers.dense(
120 | tf.concat(sample_inputs, -1), **self._kwargs)
121 | mean = tf.layers.dense(hidden, self._state_size, None)
122 | stddev = tf.layers.dense(hidden, self._state_size, tf.nn.softplus)
123 | stddev += self._min_stddev
124 | if self._mean_only:
125 | sample = mean
126 | else:
127 | sample = tfd.MultivariateNormalDiag(mean, stddev).sample()
128 |
129 | # Recurrent decoder.
130 | decoder_inputs = [sample]
131 | if self._encoder_to_decoder:
132 | decoder_inputs.append(prev_state['encoder_state'])
133 | if self._action_to_decoder:
134 | decoder_inputs.append(prev_action)
135 | decoded, decoder_state = self._decoder_cell(
136 | tf.concat(decoder_inputs, -1), prev_state['decoder_state'])
137 |
138 | return {
139 | 'encoder_state': encoder_state,
140 | 'decoder_state': decoder_state,
141 | 'mean': mean,
142 | 'stddev': stddev,
143 | 'sample': sample,
144 | }
145 |
--------------------------------------------------------------------------------
/planet/models/rssm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 | from tensorflow_probability import distributions as tfd
21 |
22 | from planet import tools
23 | from planet.models import base
24 |
25 |
class RSSM(base.Base):
  r"""Deterministic and stochastic state model.

  The stochastic latent is computed from the hidden state at the same time
  step. If an observation is present, the posterior latent is computed from
  both the hidden state and the observation.

  Prior:    Posterior:

  (a)       (a)
     \         \
      v         v
  [h]->[h]  [h]->[h]
    ^ |       ^ :
   /  v      /  v
  (s)  (s)  (s)  (s)
                  ^
                  :
                 (o)
  """

  def __init__(
      self, state_size, belief_size, embed_size,
      future_rnn=True, mean_only=False, min_stddev=0.1, activation=tf.nn.elu,
      num_layers=1):
    """Create the model; see the class docstring for the structure."""
    self._state_size = state_size
    self._belief_size = belief_size
    self._embed_size = embed_size
    self._future_rnn = future_rnn
    self._cell = tf.contrib.rnn.GRUBlockCell(self._belief_size)
    self._kwargs = dict(units=self._embed_size, activation=activation)
    self._mean_only = mean_only
    self._min_stddev = min_stddev
    self._num_layers = num_layers
    super(RSSM, self).__init__(
        tf.make_template('transition', self._transition),
        tf.make_template('posterior', self._posterior))

  @property
  def state_size(self):
    """Sizes of the tensors that make up a model state."""
    return dict(
        mean=self._state_size,
        stddev=self._state_size,
        sample=self._state_size,
        belief=self._belief_size,
        rnn_state=self._belief_size)

  def dist_from_state(self, state, mask=None):
    """Extract the latent distribution from a prior or posterior state."""
    stddev = state['stddev']
    if mask is not None:
      # Unit stddev on masked-out steps keeps the distribution well defined.
      stddev = tools.mask(stddev, mask, value=1)
    return tfd.MultivariateNormalDiag(state['mean'], stddev)

  def features_from_state(self, state):
    """Extract features for the decoder network from a prior or posterior."""
    return tf.concat([state['sample'], state['belief']], -1)

  def divergence_from_states(self, lhs, rhs, mask=None):
    """Compute the divergence measure between two states."""
    divergence = tfd.kl_divergence(
        self.dist_from_state(lhs, mask), self.dist_from_state(rhs, mask))
    return divergence if mask is None else tools.mask(divergence, mask)

  def _stack(self, hidden):
    """Apply the configured number of hidden dense layers."""
    for _ in range(self._num_layers):
      hidden = tf.layers.dense(hidden, **self._kwargs)
    return hidden

  def _sufficient_statistics(self, hidden):
    """Project features to the latent mean, stddev, and sample."""
    mean = tf.layers.dense(hidden, self._state_size, None)
    stddev = tf.layers.dense(hidden, self._state_size, tf.nn.softplus)
    stddev += self._min_stddev
    if self._mean_only:
      sample = mean
    else:
      sample = tfd.MultivariateNormalDiag(mean, stddev).sample()
    return mean, stddev, sample

  def _transition(self, prev_state, prev_action, zero_obs):
    """Compute prior next state by applying the transition dynamics."""
    inputs = tf.concat([prev_state['sample'], prev_action], -1)
    hidden = self._stack(inputs)
    belief, rnn_state = self._cell(hidden, prev_state['rnn_state'])
    # Optionally condition the latent on the updated belief instead of the
    # pre-RNN features.
    hidden = self._stack(belief if self._future_rnn else hidden)
    mean, stddev, sample = self._sufficient_statistics(hidden)
    return {
        'mean': mean,
        'stddev': stddev,
        'sample': sample,
        'belief': belief,
        'rnn_state': rnn_state,
    }

  def _posterior(self, prev_state, prev_action, obs):
    """Compute posterior state from previous state and current observation."""
    prior = self._transition_tpl(prev_state, prev_action, tf.zeros_like(obs))
    hidden = self._stack(tf.concat([prior['belief'], obs], -1))
    mean, stddev, sample = self._sufficient_statistics(hidden)
    return {
        'mean': mean,
        'stddev': stddev,
        'sample': sample,
        'belief': prior['belief'],
        'rnn_state': prior['rnn_state'],
    }
141 |
--------------------------------------------------------------------------------
/planet/models/ssm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 | from tensorflow_probability import distributions as tfd
21 |
22 | from planet import tools
23 | from planet.models import base
24 |
25 |
class SSM(base.Base):
  r"""Gaussian state space model.

  Implements the transition function and encoder using feed forward networks.

  Prior:    Posterior:

  (a)       (a)
     \         \
      v         v
  (s)->(s)  (s)->(s)
              ^
              :
             (o)
  """

  def __init__(
      self, state_size, embed_size,
      mean_only=False, activation=tf.nn.elu, min_stddev=1e-5):
    """Create the model; the networks are built lazily on first use."""
    self._state_size = state_size
    self._embed_size = embed_size
    self._mean_only = mean_only
    self._min_stddev = min_stddev
    self._kwargs = dict(units=self._embed_size, activation=activation)
    super(SSM, self).__init__(
        tf.make_template('transition', self._transition),
        tf.make_template('posterior', self._posterior))

  @property
  def state_size(self):
    """Sizes of the tensors that make up a model state."""
    return dict(
        mean=self._state_size,
        stddev=self._state_size,
        sample=self._state_size)

  def dist_from_state(self, state, mask=None):
    """Extract the latent distribution from a prior or posterior state."""
    stddev = state['stddev']
    if mask is not None:
      # Unit stddev on masked-out steps keeps the distribution well defined.
      stddev = tools.mask(stddev, mask, value=1)
    return tfd.MultivariateNormalDiag(state['mean'], stddev)

  def features_from_state(self, state):
    """Extract features for the decoder network from a prior or posterior."""
    return state['sample']

  def divergence_from_states(self, lhs, rhs, mask=None):
    """Compute the divergence measure between two states."""
    divergence = tfd.kl_divergence(
        self.dist_from_state(lhs, mask), self.dist_from_state(rhs, mask))
    return divergence if mask is None else tools.mask(divergence, mask)

  def _dist_params(self, inputs):
    """One hidden layer followed by mean, stddev, and sample heads."""
    hidden = tf.layers.dense(inputs, **self._kwargs)
    mean = tf.layers.dense(hidden, self._state_size, None)
    stddev = tf.layers.dense(hidden, self._state_size, tf.nn.softplus)
    stddev += self._min_stddev
    if self._mean_only:
      sample = mean
    else:
      sample = tfd.MultivariateNormalDiag(mean, stddev).sample()
    return {'mean': mean, 'stddev': stddev, 'sample': sample}

  def _transition(self, prev_state, prev_action, zero_obs):
    """Compute prior next state by applying the transition dynamics."""
    return self._dist_params(
        tf.concat([prev_state['sample'], prev_action], -1))

  def _posterior(self, prev_state, prev_action, obs):
    """Compute posterior state from previous state and current observation."""
    prior = self._transition_tpl(prev_state, prev_action, tf.zeros_like(obs))
    return self._dist_params(
        tf.concat([prior['mean'], prior['stddev'], obs], -1))
118 |
--------------------------------------------------------------------------------
/planet/networks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | from . import conv_ha
20 | from .basic import feed_forward
21 |
--------------------------------------------------------------------------------
/planet/networks/basic.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import numpy as np
20 | import tensorflow as tf
21 | import tensorflow_probability as tfp
22 | from tensorflow_probability import distributions as tfd
23 |
24 | from planet import tools
25 |
26 |
def feed_forward(
    state, data_shape, num_layers=2, activation=tf.nn.relu,
    mean_activation=None, stop_gradient=False, trainable=True, units=100,
    std=1.0, low=-1.0, high=1.0, dist='normal'):
  """Create a model returning unnormalized MSE distribution."""
  hidden = tf.stop_gradient(state) if stop_gradient else state
  for _ in range(num_layers):
    hidden = tf.layers.dense(hidden, units, activation)
  flat_size = int(np.prod(data_shape))
  batch_shape = tools.shape(state)[:-1]
  mean = tf.layers.dense(hidden, flat_size, mean_activation, trainable=trainable)
  mean = tf.reshape(mean, batch_shape + data_shape)
  if std == 'learned':
    raw_std = tf.layers.dense(hidden, flat_size, None, trainable=trainable)
    std = tf.nn.softplus(raw_std + 0.55) + 0.01
    std = tf.reshape(std, batch_shape + data_shape)
  if dist == 'normal':
    output = tfd.Normal(mean, std)
  elif dist == 'truncated_normal':
    # https://www.desmos.com/calculator/3o96eyqxib
    output = tfd.TruncatedNormal(mean, std, low, high)
  elif dist == 'tanh_normal':
    # https://www.desmos.com/calculator/sxpp7ectjv
    output = tfd.TransformedDistribution(
        tfd.Normal(mean, std), tfp.bijectors.Tanh())
  elif dist == 'deterministic':
    output = tfd.Deterministic(mean)
  else:
    raise NotImplementedError(dist)
  return tfd.Independent(output, len(data_shape))
60 |
--------------------------------------------------------------------------------
/planet/networks/conv_ha.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import numpy as np
20 | import tensorflow as tf
21 | from tensorflow_probability import distributions as tfd
22 |
23 | from planet import tools
24 |
25 |
def encoder(obs):
  """Extract deterministic features from an observation."""
  images = obs['image']
  kwargs = dict(strides=2, activation=tf.nn.relu)
  # Fold batch and time dimensions together for the convolution stack.
  hidden = tf.reshape(images, [-1] + images.shape[2:].as_list())
  for filters, kernel in ((32, 4), (64, 4), (128, 4), (256, 4)):
    hidden = tf.layers.conv2d(hidden, filters, kernel, **kwargs)
  hidden = tf.layers.flatten(hidden)
  assert hidden.shape[1:].as_list() == [1024], hidden.shape.as_list()
  # Restore the separate batch and time dimensions.
  return tf.reshape(hidden, tools.shape(images)[:2] + [
      np.prod(hidden.shape[1:].as_list())])
39 |
40 |
def decoder(state, data_shape):
  """Compute the data distribution of an observation from its state."""
  kwargs = dict(strides=2, activation=tf.nn.relu)
  hidden = tf.layers.dense(state, 1024, None)
  # Treat the feature vector as a 1x1 image for the transposed convolutions.
  hidden = tf.reshape(hidden, [-1, 1, 1, hidden.shape[-1].value])
  for filters, kernel in ((128, 5), (64, 5), (32, 6)):
    hidden = tf.layers.conv2d_transpose(hidden, filters, kernel, **kwargs)
  mean = tf.layers.conv2d_transpose(hidden, 3, 6, strides=2)
  assert mean.shape[1:].as_list() == [64, 64, 3], mean.shape
  mean = tf.reshape(mean, tools.shape(state)[:-1] + data_shape)
  return tfd.Independent(tfd.Normal(mean, 1.0), len(data_shape))
55 |
--------------------------------------------------------------------------------
/planet/scripts/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | from . import configs
20 | from . import tasks
21 | from . import train
22 |
--------------------------------------------------------------------------------
/planet/scripts/create_video.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import argparse
20 | import os
21 | import shutil
22 |
23 | from matplotlib import animation
24 | import matplotlib as mpl
25 | import matplotlib.pyplot as plt
26 | import scipy.misc
27 |
28 |
def create_animation(frames, size, fps=10, **kwargs):
  """Build a matplotlib animation cycling through the given frames."""
  fig = plt.figure(figsize=size, frameon=False)
  ax = fig.add_axes([0, 0, 1, 1])
  img = ax.imshow(frames[0], **kwargs)
  ax.set_xticks([])
  ax.set_yticks([])
  def update(frame):
    # Blitting requires returning the artists that changed.
    img.set_data(frame)
    return (img,)
  return animation.FuncAnimation(
      fig, update, frames=frames, interval=1000 / fps, blit=True)
39 |
40 |
def save_animation(filepath, anim, fps=10, overwrite=False):
  """Render an animation to an .mp4 (ffmpeg) or .gif (imagemagick) file."""
  if filepath.endswith('.mp4'):
    if not shutil.which('ffmpeg'):
      raise RuntimeError("The 'ffmpeg' executable must be in PATH.")
    mpl.rcParams['animation.ffmpeg_path'] = 'ffmpeg'
    save_kwargs = dict(
        fps=fps, writer='ffmpeg', extra_args=['-vcodec', 'libx264'])
  elif filepath.endswith('.gif'):
    save_kwargs = dict(fps=fps, writer='imagemagick')
  else:
    message = "Unknown video format; filename should end in '.mp4' or '.gif'."
    raise NotImplementedError(message)
  if os.path.exists(filepath) and not overwrite:
    print('Skip rendering animation because {} already exists.'.format(
        filepath))
    return
  print('Render animation to {}.'.format(filepath))
  anim.save(filepath, **save_kwargs)
58 |
59 |
def unpack_image_strip(image, tile_width, tile_height):
  """Split a tiled image into a (rows, cols, height, width, channels) grid."""
  rows = image.shape[0] // tile_height
  cols = image.shape[1] // tile_width
  channels = image.shape[2]
  grid = image.reshape((rows, tile_height, cols, tile_width, channels))
  # Bring the two tile-index axes to the front.
  return grid.transpose((0, 2, 1, 3, 4))
70 |
71 |
def pack_animation_frames(image):
  """Turn a (rows, cols, height, width, channels) tile grid into frames.

  Each grid column becomes one frame; the tiles of that column are laid out
  side by side along the width axis.
  """
  moved = image.transpose((1, 2, 0, 3, 4))
  frames, height, count, width, channels = moved.shape
  return moved.reshape((frames, height, count * width, channels))
81 |
82 |
def main(args):
  """Load a tiled image strip and render it as a video file."""
  # Bug fix: scipy.misc.imread was removed in SciPy 1.2, which broke this
  # script. matplotlib's imread (already imported by this module) is a
  # drop-in replacement here; note it returns floats in [0, 1] for PNGs,
  # which imshow renders identically to uint8 data.
  image = plt.imread(args.image_path)
  image = unpack_image_strip(image, args.tile_width, args.tile_height)
  frames = pack_animation_frames(image)
  size = frames.shape[2] / args.dpi, frames.shape[1] / args.dpi
  anim = create_animation(frames, size, args.fps)
  save_animation(args.animation_path, anim, args.fps, args.overwrite)
90 |
91 |
if __name__ == '__main__':
  # Command line interface: converts a tiled image strip (as written by the
  # training image summaries) into an .mp4 or .gif animation.
  parser = argparse.ArgumentParser()
  parser.add_argument('-i', '--image-path', required=True)
  parser.add_argument('-o', '--animation-path', required=True)
  parser.add_argument('-f', '--overwrite', action='store_true', default=False)
  parser.add_argument('-x', '--tile-width', type=int, default=64)
  parser.add_argument('-y', '--tile-height', type=int, default=64)
  parser.add_argument('-r', '--fps', type=int, default=10)
  parser.add_argument('-d', '--dpi', type=float, default=50)
  args = parser.parse_args()
  main(args)
103 |
--------------------------------------------------------------------------------
/planet/scripts/fetch_events.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import argparse
20 | import csv
21 | import fnmatch
22 | import functools
23 | import multiprocessing.dummy as multiprocessing
24 | import os
25 | import re
26 | import sys
27 | import traceback
28 |
29 | # import imageio
30 | import numpy as np
31 | import skimage.io
32 | import tensorflow as tf
33 | from tensorboard.backend.event_processing import (
34 | plugin_event_multiplexer as event_multiplexer)
35 |
36 |
37 | lock = multiprocessing.Lock()
38 |
39 |
def safe_print(*args, **kwargs):
  """Print under the global lock so worker output does not interleave."""
  with lock:
    print(*args, **kwargs)
43 |
44 |
def create_reader(logdir):
  """Load all events of a single log directory under the run name 'run'."""
  multiplexer = event_multiplexer.EventMultiplexer()
  multiplexer.AddRun(logdir, 'run')
  multiplexer.Reload()
  return multiplexer
50 |
51 |
def extract_values(reader, tag):
  """Return parallel lists of steps, wall times, and values for a tag."""
  events = reader.Tensors('run', tag)
  steps = [e.step for e in events]
  times = [e.wall_time for e in events]
  values = [tf.make_ndarray(e.tensor_proto) for e in events]
  return steps, times, values
58 |
59 |
def export_scalar(basename, steps, times, values):
  """Write scalar events as (wall_time, step, value) rows to a CSV file."""
  filename = basename + '.csv'
  safe_print('Writing', filename)
  scalars = [value.item() for value in values]
  # tf.gfile supports remote filesystems such as GCS in addition to local paths.
  with tf.gfile.Open(filename, 'w') as outfile:
    writer = csv.writer(outfile)
    writer.writerow(('wall_time', 'step', 'value'))
    writer.writerows(zip(times, steps, scalars))
68 |
69 |
def export_image(basename, steps, times, values):
  """Decode image summary events and save each as a .png and a .npy file.

  Each value is expected to be a sequence holding (width, height,
  encoded_image_string), as produced by TensorFlow image summaries.
  """
  # Build a tiny decoding graph once and reuse it for every event.
  tf.reset_default_graph()
  tf_string = tf.placeholder(tf.string)
  tf_tensor = tf.image.decode_image(tf_string)
  with tf.Session() as sess:
    for step, time_, value in zip(steps, times, values):
      filename = '{}-{}-{}.png'.format(basename, step, time_)
      # Width and height are not needed; decode_image recovers the size
      # from the encoded string itself.
      width, height, string = value[0], value[1], value[2]
      del width
      del height
      tensor = sess.run(tf_tensor, {tf_string: string})
      # imageio.imsave(filename, tensor)
      skimage.io.imsave(filename, tensor)
      filename = '{}-{}-{}.npy'.format(basename, step, time_)
      np.save(filename, tensor)
85 |
86 |
def process_logdir(logdir, args):
  """Export all matching scalar tags of one log directory to CSV files.

  Skips the directory when output files already exist unless --force is
  given. Exceptions are reported rather than raised so that one bad logdir
  does not kill the worker pool.
  """
  clean = lambda text: re.sub('[^A-Za-z0-9_]', '_', text)
  basename = os.path.join(args.outdir, clean(logdir))
  if len(tf.gfile.Glob(basename + '*')) > 0 and not args.force:
    safe_print('Exists', logdir)
    return
  try:
    safe_print('Start', logdir)
    reader = create_reader(logdir)
    for tag in reader.Runs()['run']['tensors']:  # tensors -> scalars
      if fnmatch.fnmatch(tag, args.tags):
        steps, times, values = extract_values(reader, tag)
        filename = '{}___{}'.format(basename, clean(tag))
        export_scalar(filename, steps, times, values)
    # for tag in tags['images']:
    #   if fnmatch.fnmatch(tag, args.tags):
    #     steps, times, values = extract_values(reader, tag)
    #     filename = '{}___{}'.format(basename, clean(tag))
    #     export_image(filename, steps, times, values)
    del reader
    safe_print('Done', logdir)
  except Exception:
    safe_print('Exception', logdir)
    # Bug fix: traceback.print_exc() writes to stderr (outside the lock) and
    # returns None, so the original code printed the literal 'None'. Format
    # the traceback and print it under the lock instead.
    safe_print(traceback.format_exc())
111 |
112 |
def main(args):
  """Fan out over all matching log directories using a thread pool."""
  logdirs = tf.gfile.Glob(args.logdirs)
  print(len(logdirs), 'logdirs.')
  assert logdirs
  tf.gfile.MakeDirs(args.outdir)
  # Shuffle so concurrently running invocations are less likely to collide
  # on the same directories.
  np.random.shuffle(logdirs)
  pool = multiprocessing.Pool(args.workers)
  pool.map(functools.partial(process_logdir, args=args), logdirs)
122 |
123 |
if __name__ == '__main__':
  # Parse 'False'/'True' strings into 0/1 for argparse flags.
  boolean = lambda x: ['False', 'True'].index(x)
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--logdirs', required=True,
      help='glob for log directories to fetch')
  parser.add_argument(
      '--tags', default='trainer*score',
      help='glob for tags to save')
  parser.add_argument(
      '--outdir', required=True,
      help='output directory to store values')
  parser.add_argument(
      '--force', type=boolean, default=False,
      help='overwrite existing files')
  parser.add_argument(
      '--workers', type=int, default=10,
      help='number of worker threads')
  args_, remaining = parser.parse_known_args()
  args_.logdirs = os.path.expanduser(args_.logdirs)
  args_.outdir = os.path.expanduser(args_.outdir)
  # Forward unparsed flags to TensorFlow; tf.app.run expects argv[0] first.
  remaining.insert(0, sys.argv[0])
  tf.app.run(lambda _: main(args_), remaining)
147 |
--------------------------------------------------------------------------------
/planet/scripts/objectives.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 |
21 |
def reward(state, graph, params):
  """Sum the predicted mean rewards of a state sequence over time."""
  features = graph.cell.features_from_state(state)
  predicted = graph.heads.reward(features).mean()
  return tf.reduce_sum(predicted, 1)
26 |
--------------------------------------------------------------------------------
/planet/scripts/sync.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 |
3 | # Copyright 2019 The PlaNet Authors. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import argparse
18 | import glob
19 | import os
20 | import shutil
21 |
22 |
def find_source_files(directory):
  """Return all Python files below `directory`, including nested ones.

  Bug fix: the previous implementation globbed `directory + '*.py'`, which
  only worked when `directory` ended with a path separator, and used a
  '**' pattern without `recursive=True`, which glob treats like '*' and
  therefore missed files nested more than one level deep. A single
  recursive glob covers the top level and all sub-directories.
  """
  pattern = os.path.join(directory, '**', '*.py')
  # Sorting makes the traversal order deterministic across platforms.
  return sorted(glob.glob(pattern, recursive=True))
27 |
28 |
def copy_source_tree(source_dir, target_dir):
  """Mirror the Python files of source_dir into target_dir.

  Copies every source file (reporting whether it adds or overrides), then
  deletes target files that no longer have a counterpart in the source.
  """
  for source in find_source_files(source_dir):
    relative = os.path.relpath(source, source_dir)
    target = os.path.join(target_dir, relative)
    os.makedirs(os.path.dirname(target), exist_ok=True)
    verb = 'Override' if os.path.exists(target) else 'Add'
    print(verb, os.path.relpath(target, target_dir))
    shutil.copy(source, target)
  for target in find_source_files(target_dir):
    relative = os.path.relpath(target, target_dir)
    if not os.path.exists(os.path.join(source_dir, relative)):
      print('Remove', relative)
      os.remove(target)
43 |
44 |
def infer_headers(directory):
  """Read the leading comment header from the first source file found."""
  try:
    filename = find_source_files(directory)[0]
  except IndexError:
    raise RuntimeError('No code files found in {}.'.format(directory))
  header = []
  with open(filename, 'r') as f:
    for index, line in enumerate(f):
      is_comment = line.startswith('#')
      # A file that does not start with a comment has no header at all.
      if index == 0 and not is_comment:
        break
      # The header ends at the first non-comment line with content; blank
      # lines inside the header are kept.
      if not is_comment and line.strip(' \n'):
        break
      header.append(line)
  return header
59 |
60 |
def add_headers(directory, header):
  """Prepend the header lines to every source file in the directory."""
  prefix = ''.join(header)
  for filename in find_source_files(directory):
    with open(filename, 'r') as f:
      body = f.read()
    with open(filename, 'w') as f:
      f.write(prefix + body)
67 |
68 |
def remove_headers(directory, header):
  """Strip the header lines from every code file that starts with them."""
  for path in find_source_files(directory):
    with open(path, 'r') as handle:
      lines = handle.readlines()
    if lines[:len(header)] == header:
      lines = lines[len(header):]
    with open(path, 'w') as handle:
      handle.write(''.join(lines))
77 |
78 |
def main(args):
  """Synchronize code trees and transfer the license header convention."""
  print('Inferring headers.\n')
  divider = '-' * 32
  ruler = '-' * 79
  source_header = infer_headers(args.source)
  print('{} Source header {}\n{}{}\n'.format(
      divider, divider, ''.join(source_header), ruler))
  target_header = infer_headers(args.target)
  print('{} Target header {}\n{}{}\n'.format(
      divider, divider, ''.join(target_header), ruler))
  print('Synchronizing directories.')
  copy_source_tree(args.source, args.target)
  # Converge the target onto its own header convention: add the target's
  # header when the source has none, and drop the source's when the target
  # has none.
  if target_header and not source_header:
    print('Adding headers.')
    add_headers(args.target, target_header)
  if source_header and not target_header:
    print('Removing headers.')
    remove_headers(args.target, source_header)
  print('Done.')
96 |
97 |
if __name__ == '__main__':
  # Command-line entry point; both paths support '~' expansion via the
  # argparse `type` hook.
  parser = argparse.ArgumentParser()
  parser.add_argument('--source', type=os.path.expanduser, required=True)
  parser.add_argument('--target', type=os.path.expanduser, required=True)
  main(parser.parse_args())
103 |
--------------------------------------------------------------------------------
/planet/scripts/tasks.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import collections
20 |
21 | import numpy as np
22 |
23 | from planet import control
24 | from planet import tools
25 |
26 |
# A task bundles a display name, a zero-argument environment constructor,
# the episode length in (action-repeated) steps, and the keys of the ground
# truth state components the task exposes (e.g. 'reward', 'position').
Task = collections.namedtuple(
    'Task', 'name, env_ctor, max_length, state_components')
29 |
30 |
def dummy(config, params):
  """Light-weight debug task based on the dummy environment.

  Args:
    config: Global config AttrDict (unused here).
    params: Task parameter AttrDict; reads 'action_repeat'.

  Returns:
    Task tuple with name, constructor, episode length, and state keys.
  """
  action_repeat = params.get('action_repeat', 1)
  max_length = 1000 // action_repeat
  state_components = ['reward']
  # Instantiate the environment before wrapping it: ActionRepeat delegates
  # to an environment instance, so passing the DummyEnv class itself (as the
  # previous code did) would fail at the first step call.
  env_ctor = lambda: control.wrappers.ActionRepeat(
      control.DummyEnv(), action_repeat)
  return Task('dummy', env_ctor, max_length, state_components)
38 |
39 |
def cartpole_balance(config, params):
  """Cartpole balancing task from the DeepMind Control Suite."""
  repeat = params.get('action_repeat', 8)
  episode_length = 1000 // repeat
  env_ctor = tools.bind(
      _dm_control_env, repeat, episode_length, 'cartpole', 'balance', params)
  return Task(
      'cartpole_balance', env_ctor, episode_length,
      ['reward', 'position', 'velocity'])
48 |
49 |
def cartpole_swingup(config, params):
  """Cartpole swing-up task from the DeepMind Control Suite."""
  repeat = params.get('action_repeat', 8)
  episode_length = 1000 // repeat
  env_ctor = tools.bind(
      _dm_control_env, repeat, episode_length, 'cartpole', 'swingup', params)
  return Task(
      'cartpole_swingup', env_ctor, episode_length,
      ['reward', 'position', 'velocity'])
58 |
59 |
def finger_spin(config, params):
  """Finger spinning task from the DeepMind Control Suite."""
  repeat = params.get('action_repeat', 2)
  episode_length = 1000 // repeat
  env_ctor = tools.bind(
      _dm_control_env, repeat, episode_length, 'finger', 'spin', params)
  return Task(
      'finger_spin', env_ctor, episode_length,
      ['reward', 'position', 'velocity', 'touch'])
67 |
68 |
def cheetah_run(config, params):
  """Cheetah running task from the DeepMind Control Suite."""
  repeat = params.get('action_repeat', 4)
  episode_length = 1000 // repeat
  env_ctor = tools.bind(
      _dm_control_env, repeat, episode_length, 'cheetah', 'run', params)
  return Task(
      'cheetah_run', env_ctor, episode_length,
      ['reward', 'position', 'velocity'])
76 |
77 |
def cup_catch(config, params):
  """Ball-in-cup catching task from the DeepMind Control Suite."""
  repeat = params.get('action_repeat', 4)
  episode_length = 1000 // repeat
  env_ctor = tools.bind(
      _dm_control_env, repeat, episode_length, 'ball_in_cup', 'catch', params)
  return Task(
      'cup_catch', env_ctor, episode_length,
      ['reward', 'position', 'velocity'])
86 |
87 |
def walker_walk(config, params):
  """Planar walker task from the DeepMind Control Suite."""
  repeat = params.get('action_repeat', 2)
  episode_length = 1000 // repeat
  env_ctor = tools.bind(
      _dm_control_env, repeat, episode_length, 'walker', 'walk', params)
  return Task(
      'walker_walk', env_ctor, episode_length,
      ['reward', 'height', 'orientations', 'velocity'])
95 |
96 |
def reacher_easy(config, params):
  """Easy reacher task from the DeepMind Control Suite."""
  repeat = params.get('action_repeat', 4)
  episode_length = 1000 // repeat
  env_ctor = tools.bind(
      _dm_control_env, repeat, episode_length, 'reacher', 'easy', params)
  return Task(
      'reacher_easy', env_ctor, episode_length,
      ['reward', 'position', 'velocity', 'to_target'])
104 |
105 |
def gym_cheetah(config, params):
  """Half-cheetah task from the OpenAI Gym.

  Works with `isolate_envs: process`.
  """
  repeat = params.get('action_repeat', 1)
  episode_length = 1000 // repeat
  env_ctor = tools.bind(
      _gym_env, repeat, config.batch_shape[1], episode_length,
      'HalfCheetah-v3')
  return Task('gym_cheetah', env_ctor, episode_length, ['reward', 'state'])
115 |
116 |
def gym_racecar(config, params):
  """Car racing task from the OpenAI Gym with image observations.

  Works with `isolate_envs: thread`.
  """
  repeat = params.get('action_repeat', 1)
  episode_length = 1000 // repeat
  env_ctor = tools.bind(
      _gym_env, repeat, config.batch_shape[1], episode_length,
      'CarRacing-v0', obs_is_image=True)
  # NOTE(review): The task name 'gym_racing' differs from the function name
  # 'gym_racecar'; kept as-is since other code may look up this string.
  return Task('gym_racing', env_ctor, episode_length, ['reward'])
126 |
127 |
def _dm_control_env(
    action_repeat, max_length, domain, task, params, normalize=False,
    camera_id=None):
  """Construct a wrapped DeepMind Control Suite environment.

  Args:
    action_repeat: Number of times each action is applied per agent step.
    max_length: Maximum episode length in agent steps.
    domain: Suite domain name, or a callable returning a raw environment
        (in which case `task` must be None).
    task: Suite task name, or None when `domain` is a callable.
    params: AttrDict of parameters; reads 'camera_id' when the `camera_id`
        argument is not given.
    normalize: Whether to rescale the action space to [-1, 1].
    camera_id: Camera used for rendering; overrides params.

  Returns:
    The fully wrapped environment.
  """
  if isinstance(domain, str):
    # Imported lazily so dm_control is only required when actually used.
    from dm_control import suite
    env = suite.load(domain, task)
  else:
    assert task is None
    env = domain()
  if camera_id is None:
    camera_id = int(params.get('camera_id', 0))
  # The order of the wrappers below is behavior-relevant; pixel observations
  # are rendered after action repeat and duration limiting are applied.
  env = control.wrappers.DeepMindWrapper(env, (64, 64), camera_id=camera_id)
  env = control.wrappers.ActionRepeat(env, action_repeat)
  if normalize:
    env = control.wrappers.NormalizeActions(env)
  env = control.wrappers.MaximumDuration(env, max_length)
  env = control.wrappers.PixelObservations(env, (64, 64), np.uint8, 'image')
  env = control.wrappers.ConvertTo32Bit(env)
  return env
147 |
148 |
def _gym_env(action_repeat, min_length, max_length, name, obs_is_image=False):
  """Construct a wrapped OpenAI Gym environment.

  Args:
    action_repeat: Number of times each action is applied per agent step.
    min_length: Minimum episode length in agent steps.
    max_length: Maximum episode length in agent steps.
    name: Gym environment id passed to gym.make.
    obs_is_image: Whether the environment's native observation is already an
        image; otherwise the state is kept and pixels are rendered on top.

  Returns:
    The fully wrapped environment.
  """
  # Imported lazily so gym is only required when actually used.
  import gym
  env = gym.make(name)
  env = control.wrappers.ActionRepeat(env, action_repeat)
  env = control.wrappers.NormalizeActions(env)
  env = control.wrappers.MinimumDuration(env, min_length)
  env = control.wrappers.MaximumDuration(env, max_length)
  if obs_is_image:
    env = control.wrappers.ObservationDict(env, 'image')
    env = control.wrappers.ObservationToRender(env)
  else:
    env = control.wrappers.ObservationDict(env, 'state')
  env = control.wrappers.PixelObservations(env, (64, 64), np.uint8, 'image')
  env = control.wrappers.ConvertTo32Bit(env)
  return env
164 |
--------------------------------------------------------------------------------
/planet/scripts/test_planet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import os
20 | import sys
21 |
22 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(
23 | os.path.abspath(__file__)))))
24 |
25 | import tensorflow as tf
26 |
27 | from planet import tools
28 | from planet.scripts import train
29 |
30 |
class PlanetTest(tf.test.TestCase):
  """End-to-end smoke tests for the training entry point.

  Each test runs the 'debug' config for a few steps on one task, once per
  environment isolation mode ('none', 'thread', 'process'). The previous
  version repeated the same body six times; the shared logic now lives in
  the `_run_debug_config` helper.
  """

  def test_dummy_isolate_none(self):
    self._run_debug_config('dummy', 'none')

  def test_dummy_isolate_thread(self):
    self._run_debug_config('dummy', 'thread')

  def test_dummy_isolate_process(self):
    self._run_debug_config('dummy', 'process')

  def test_dm_control_isolate_none(self):
    self._run_debug_config('cup_catch', 'none')

  def test_dm_control_isolate_thread(self):
    self._run_debug_config('cup_catch', 'thread')

  def test_dm_control_isolate_process(self):
    self._run_debug_config('cup_catch', 'process')

  def _run_debug_config(self, task, isolate_envs):
    """Run a short debug training for the task and isolation mode."""
    args = tools.AttrDict(
        logdir=self.get_temp_dir(),
        num_runs=1,
        config='debug',
        params=tools.AttrDict(
            task=task,
            isolate_envs=isolate_envs,
            max_steps=30),
        ping_every=0,
        resume_runs=False)
    try:
      # tf.app.run calls sys.exit when main returns; swallow it so the test
      # runner keeps going.
      tf.app.run(lambda _: train.main(args), [sys.argv[0]])
    except SystemExit:
      pass
128 |
129 |
130 | if __name__ == '__main__':
131 | tf.test.main()
132 |
--------------------------------------------------------------------------------
/planet/scripts/train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""Train a Deep Planning Network agent.
16 |
17 | Full training run:
18 |
19 | python3 -m planet.scripts.train \
20 | --logdir /path/to/logdir \
21 | --config default \
22 | --params '{tasks: [cheetah_run]}'
23 |
24 | For debugging:
25 |
26 | python3 -m planet.scripts.train \
27 | --logdir /path/to/logdir \
28 | --resume_runs False \
29 | --num_runs 1000 \
30 | --config debug \
31 | --params '{tasks: [cheetah_run]}'
32 | """
33 |
34 | from __future__ import absolute_import
35 | from __future__ import division
36 | from __future__ import print_function
37 |
38 | import argparse
39 | import functools
40 | import os
41 | import sys
42 |
43 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(
44 | os.path.abspath(__file__)))))
45 |
46 | # Need offline backend to render summaries from within tf.py_func.
47 | import matplotlib
48 | matplotlib.use('Agg')
49 |
50 | import ruamel.yaml as yaml
51 | import tensorflow as tf
52 |
53 | from planet import tools
54 | from planet import training
55 | from planet.scripts import configs
56 |
57 |
def process(logdir, args):
  """Run one training run inside the given log directory.

  Args:
    logdir: Directory for this run's episodes, checkpoints, and summaries.
    args: Parsed command line arguments; `args.params` is mutated to record
        the log directory, and `args.config` names a function in configs.py.

  Yields:
    Score values produced by the training loop.
  """
  with args.params.unlocked:
    args.params.logdir = logdir
  config = tools.AttrDict()
  with config.unlocked:
    # The config function both fills in and returns the config object.
    config = getattr(configs, args.config)(config, args.params)
  training.utility.collect_initial_episodes(config)
  # Start from a clean graph for every run.
  tf.reset_default_graph()
  dataset = tools.numpy_episodes.numpy_episodes(
      config.train_dir, config.test_dir, config.batch_shape,
      reader=config.data_reader,
      loader=config.data_loader,
      num_chunks=config.num_chunks,
      preprocess_fn=config.preprocess_fn)
  for score in training.utility.train(
      training.define_model, dataset, logdir, config):
    yield score
75 |
76 |
def main(args):
  """Set up logging and execute every run of the experiment to completion."""
  training.utility.set_up_logging()
  runs = training.Experiment(
      args.logdir,
      process_fn=functools.partial(process, args=args),
      num_runs=args.num_runs,
      ping_every=args.ping_every,
      resume_runs=args.resume_runs)
  for run in runs:
    # Drain the generator; scores are handled inside the training loop.
    for _ in run:
      pass
88 |
89 |
if __name__ == '__main__':
  # Strict boolean flag parser: maps exactly 'False'/'True' to False/True
  # and raises ValueError (index not found) for anything else.
  boolean = lambda x: bool(['False', 'True'].index(x))
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--logdir', required=True)
  parser.add_argument(
      '--num_runs', type=int, default=1)
  parser.add_argument(
      '--config', default='default',
      help='Select a configuration function from scripts/configs.py.')
  parser.add_argument(
      '--params', default='{}',
      help='YAML formatted dictionary to be used by the config.')
  parser.add_argument(
      '--ping_every', type=int, default=0,
      help='Used to prevent conflicts between multiple workers; 0 to disable.')
  parser.add_argument(
      '--resume_runs', type=boolean, default=True,
      help='Whether to resume unfinished runs in the log directory.')
  args_, remaining = parser.parse_known_args()
  # '#' characters in --params are rewritten to ',' before YAML parsing,
  # presumably so commas can be passed without shell escaping — confirm.
  args_.params = tools.AttrDict(yaml.safe_load(args_.params.replace('#', ',')))
  args_.logdir = args_.logdir and os.path.expanduser(args_.logdir)
  # tf.app.run expects argv[0] to be the program name; forward the flags
  # that argparse did not consume.
  remaining.insert(0, sys.argv[0])
  tf.app.run(lambda _: main(args_), remaining)
114 |
--------------------------------------------------------------------------------
/planet/tools/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | from . import nested
20 | from . import numpy_episodes
21 | from . import preprocess
22 | from . import schedule
23 | from . import summary
24 | from . import unroll
25 | from .attr_dict import AttrDict
26 | from .bind import bind
27 | from .bound_action import bound_action
28 | from .copy_weights import soft_copy_weights
29 | from .count_dataset import count_dataset
30 | from .count_weights import count_weights
31 | from .custom_optimizer import CustomOptimizer
32 | from .filter_variables_lib import filter_variables
33 | from .gif_summary import gif_summary
34 | from .image_strip_summary import image_strip_summary
35 | from .mask import mask
36 | from .overshooting import overshooting
37 | from .reshape_as import reshape_as
38 | from .shape import shape
39 | from .streaming_mean import StreamingMean
40 | from .target_network import track_network
41 |
--------------------------------------------------------------------------------
/planet/tools/attr_dict.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import collections
20 | import contextlib
21 | import os
22 |
23 | import numpy as np
24 | import ruamel.yaml as yaml
25 |
26 |
class AttrDict(dict):
  """Dictionary whose keys are also accessible as attributes.

  Features beyond a plain dict:
  - Locking: a locked instance rejects writes until unlocked; an instance
    starts unlocked only when constructed without any items.
  - Access tracking: every key read is recorded, so `untouched` can report
    keys that were never used.
  - Fallback defaults: missing keys can resolve through `_defaults`.
  """

  def __init__(self, *args, **kwargs):
    unlocked = kwargs.pop('_unlocked', not (args or kwargs))
    defaults = kwargs.pop('_defaults', {})
    touched = kwargs.pop('_touched', set())
    # Temporarily unlock with empty bookkeeping so the base constructor can
    # insert the initial items, then install the requested state.
    super(AttrDict, self).__setattr__('_unlocked', True)
    super(AttrDict, self).__setattr__('_touched', set())
    super(AttrDict, self).__setattr__('_defaults', {})
    super(AttrDict, self).__init__(*args, **kwargs)
    super(AttrDict, self).__setattr__('_unlocked', unlocked)
    super(AttrDict, self).__setattr__('_defaults', defaults)
    super(AttrDict, self).__setattr__('_touched', touched)

  def __getattr__(self, name):
    """Resolve attribute access through the dictionary."""
    try:
      return self[name]
    except KeyError:
      raise AttributeError(name)

  def __setattr__(self, name, value):
    """Route attribute assignment through item assignment (lock-checked)."""
    self[name] = value

  def __getitem__(self, name):
    """Look up a key, falling back to defaults; records the access."""
    self._touched.add(name)
    if name in self:
      return super(AttrDict, self).__getitem__(name)
    if name in self._defaults:
      return self._defaults[name]
    # AttributeError rather than KeyError so attribute-style access reports
    # missing keys consistently.
    raise AttributeError(name)

  def __setitem__(self, name, value):
    if name.startswith('__'):
      raise AttributeError("Cannot set magic attribute '{}'".format(name))
    if not self._unlocked:
      message = 'Use obj.unlock() before setting {}'
      raise RuntimeError(message.format(name))
    super(AttrDict, self).__setitem__(name, value)

  def __repr__(self):
    items = []
    for key, value in self.items():
      items.append('{}: {}'.format(key, self._format_value(value)))
    return '{' + ', '.join(items) + '}'

  def get(self, key, default=None):
    """Like dict.get, but records the access for `untouched`."""
    self._touched.add(key)
    if key not in self:
      return default
    return self[key]

  @property
  def untouched(self):
    """Sorted list of keys that were never read."""
    return sorted(set(self.keys()) - self._touched)

  @property
  @contextlib.contextmanager
  def unlocked(self):
    """Context manager that unlocks the dict for its duration."""
    self.unlock()
    yield
    self.lock()

  def lock(self):
    """Recursively forbid modifications."""
    super(AttrDict, self).__setattr__('_unlocked', False)
    for value in self.values():
      if isinstance(value, AttrDict):
        value.lock()

  def unlock(self):
    """Recursively allow modifications."""
    super(AttrDict, self).__setattr__('_unlocked', True)
    for value in self.values():
      if isinstance(value, AttrDict):
        value.unlock()

  def summarize(self):
    """Return a multi-line human readable description of the content."""
    items = []
    for key, value in self.items():
      items.append('{}: {}'.format(key, self._format_value(value)))
    return '\n'.join(items)

  def update(self, mapping):
    if not self._unlocked:
      message = 'Use obj.unlock() before updating'
      raise RuntimeError(message)
    super(AttrDict, self).update(mapping)
    return self

  def copy(self, _unlocked=False):
    # NOTE(review): The copy does not carry over _defaults or _touched.
    return type(self)(super(AttrDict, self).copy(), _unlocked=_unlocked)

  def save(self, filename):
    """Store the content as a YAML file."""
    assert str(filename).endswith('.yaml')
    directory = os.path.dirname(str(filename))
    os.makedirs(directory, exist_ok=True)
    with open(filename, 'w') as f:
      yaml.dump(collections.OrderedDict(self), f)

  def load(cls, filename):
    """Restore an instance from a YAML file."""
    assert str(filename).endswith('.yaml')
    with open(filename, 'r') as f:
      return cls(yaml.load(f, Loader=yaml.Loader))
  load = classmethod(load)

  def _format_value(self, value):
    """Return a short human readable description of a single value."""
    if isinstance(value, np.ndarray):
      # Bug fix: the template used to be the empty string, which rendered
      # every array as '' and discarded the computed statistics.
      template = '<array shape={} dtype={} min={} mean={} max={}>'
      min_ = self._format_value(value.min())
      mean = self._format_value(value.mean())
      max_ = self._format_value(value.max())
      return template.format(value.shape, value.dtype, min_, mean, max_)
    if isinstance(value, float) and 1e-3 < abs(value) < 1e6:
      return '{:.3f}'.format(value)
    if isinstance(value, float):
      return '{:4.1e}'.format(value)
    if hasattr(value, '__name__'):
      return value.__name__
    return str(value)
150 |
--------------------------------------------------------------------------------
/planet/tools/bind.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 |
class bind(object):
  """Partial function application, similar to functools.partial.

  Stores positional and keyword arguments and applies them when the wrapped
  callable is invoked: bound positionals come first, and call-site keyword
  arguments override bound ones.
  """

  def __init__(self, fn, *args, **kwargs):
    self._fn = fn
    self._args = args
    self._kwargs = kwargs

  def __call__(self, *args, **kwargs):
    args_ = self._args + args
    kwargs_ = self._kwargs.copy()
    kwargs_.update(kwargs)
    return self._fn(*args_, **kwargs_)

  def __repr__(self):
    # Not every callable defines __name__ (e.g. nested bind instances or
    # functools.partial objects); fall back to repr instead of crashing.
    name = getattr(self._fn, '__name__', repr(self._fn))
    return 'bind({})'.format(name)
35 |
--------------------------------------------------------------------------------
/planet/tools/bound_action.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import tensorflow as tf
16 |
17 |
def bound_action(action, strategy):
  """Constrain an action tensor to the valid range [-1, 1].

  Args:
    action: Tensor of possibly out-of-range actions.
    strategy: One of 'none' (pass through), 'clip' (straight-through
        clipping), or 'tanh' (smooth squashing).

  Returns:
    Tensor of bounded actions.

  Raises:
    NotImplementedError: If the strategy is unknown.
  """
  if strategy == 'none':
    pass
  elif strategy == 'clip':
    # Straight-through estimator: the forward value is the clipped action,
    # while the gradient bypasses the clipping and flows through the
    # original action unchanged.
    forward = tf.stop_gradient(tf.clip_by_value(action, -1.0, 1.0))
    action = action - tf.stop_gradient(action) + forward
  elif strategy == 'tanh':
    action = tf.tanh(action)
  else:
    raise NotImplementedError(strategy)
  return action
29 |
--------------------------------------------------------------------------------
/planet/tools/chunk_sequence.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Chunk sequences into fixed lengths."""
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import tensorflow as tf
21 |
22 | from planet.tools import nested
23 |
24 |
def chunk_sequence(sequence, chunk_length, randomize=True, num_chunks=None):
  """Split a nested dict of sequence tensors into a batch of chunks.

  This function does not expect a batch of sequences, but a single sequence. A
  `length` key is added if it did not exist already. When `randomize` is set,
  up to `chunk_length - 1` initial frames will be discarded. Final frames that
  do not fit into a chunk are always discarded.

  Args:
    sequence: Nested dict of tensors with time dimension.
    chunk_length: Size of chunks the sequence will be split into.
    randomize: Start chunking from a random offset in the sequence,
      enforcing that at least one chunk is generated.
    num_chunks: Optionally specify the exact number of chunks to be extracted
      from the sequence. Requires input to be long enough.

  Returns:
    Nested dict of sequence tensors with chunk dimension.
  """
  # NOTE(review): pinned to CPU, presumably to keep input pipeline ops off
  # the accelerator — confirm.
  with tf.device('/cpu:0'):
    # Use the provided length if available, otherwise infer it from the time
    # dimension of the first tensor in the nested structure.
    if 'length' in sequence:
      length = sequence.pop('length')
    else:
      length = tf.shape(nested.flatten(sequence)[0])[0]
    if randomize:
      if num_chunks is None:
        # Reserve roughly one chunk of slack so a random offset is possible.
        num_chunks = tf.maximum(1, length // chunk_length - 1)
      else:
        # Adding `0 * length` converts the Python int into a tensor of the
        # same dtype as `length`.
        num_chunks = num_chunks + 0 * length
      used_length = num_chunks * chunk_length
      max_offset = length - used_length
      # Random start offset in [0, max_offset]; this is what discards up to
      # `chunk_length - 1` initial frames.
      offset = tf.random_uniform((), 0, max_offset + 1, dtype=tf.int32)
    else:
      if num_chunks is None:
        num_chunks = length // chunk_length
      else:
        num_chunks = num_chunks + 0 * length
      used_length = num_chunks * chunk_length
      max_offset = 0
      offset = 0
    # Crop to a whole number of chunks, then fold the time dimension into
    # [num_chunks, chunk_length, ...].
    clipped = nested.map(
        lambda tensor: tensor[offset: offset + used_length],
        sequence)
    chunks = nested.map(
        lambda tensor: tf.reshape(
            tensor, [num_chunks, chunk_length] + tensor.shape[1:].as_list()),
        clipped)
    # Every chunk is full by construction, so all lengths equal chunk_length.
    chunks['length'] = chunk_length * tf.ones((num_chunks,), dtype=tf.int32)
    return chunks
74 |
75 |
def _pad_tensor(tensor, length, value):
  """Append `length` constant-valued frames along the first dimension.

  NOTE(review): Not referenced by chunk_sequence above; presumably used
  elsewhere or kept for future use — confirm before removing.
  """
  # Tile one frame of the constant value to the requested padding length;
  # `0 * tensor[:1] + value` inherits shape and dtype from the input.
  tiling = [length] + ([1] * (tensor.shape.ndims - 1))
  padding = tf.tile(0 * tensor[:1] + value, tiling)
  padded = tf.concat([tensor, padding], 0)
  return padded
81 |
--------------------------------------------------------------------------------
/planet/tools/copy_weights.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 |
21 | from planet.tools import filter_variables_lib
22 |
23 |
def soft_copy_weights(source_pattern, target_pattern, amount):
  """Build an op that moves target variables toward source variables.

  Args:
    source_pattern: Pattern selecting the source variables.
    target_pattern: Pattern selecting the target variables.
    amount: Interpolation factor in (0, 1]; 1 performs a hard copy.

  Returns:
    A grouped op that applies all variable updates.
  """
  assert 0 < amount <= 1
  source_vars = filter_variables_lib.filter_variables(include=source_pattern)
  target_vars = filter_variables_lib.filter_variables(include=target_pattern)
  # Sort both collections by name so corresponding variables pair up.
  source_vars = sorted(source_vars, key=lambda x: x.name)
  target_vars = sorted(target_vars, key=lambda x: x.name)
  assert len(source_vars) == len(target_vars)
  updates = []
  for source, target in zip(source_vars, target_vars):
    # The two patterns must select distinct variables.
    assert source.name != target.name
    if amount == 1.0:
      # Hard copy; avoids floating point error from the interpolation below.
      updates.append(target.assign(source))
    else:
      updates.append(target.assign((1 - amount) * target + amount * source))
  return tf.group(*updates)
39 |
--------------------------------------------------------------------------------
/planet/tools/count_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import os
20 |
21 | import numpy as np
22 | import tensorflow as tf
23 |
24 |
def count_dataset(directory, key='reward'):
  """Return an int32 tensor holding the number of episode files on disk.

  The count runs inside a py_func, so it reflects the directory content each
  time the tensor is evaluated, not at graph construction time.

  Args:
    directory: Directory containing one '*.npz' file per episode.
    key: Unused by this implementation; kept for signature compatibility.

  Raises:
    ValueError: If the directory does not exist at graph construction time.
  """
  directory = os.path.expanduser(directory)
  if not tf.gfile.Exists(directory):
    message = "Data set directory '{}' does not exist."
    raise ValueError(message.format(directory))
  pattern = os.path.join(directory, '*.npz')
  def func():
    # Evaluated lazily on every fetch of the returned tensor.
    filenames = tf.gfile.Glob(pattern)
    episodes = len(filenames)
    episodes = np.array(episodes, dtype=np.int32)
    return episodes
  return tf.py_func(func, [], tf.int32)
37 |
--------------------------------------------------------------------------------
/planet/tools/count_weights.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import re
20 |
21 | import numpy as np
22 | import tensorflow as tf
23 |
24 |
def count_weights(scope=None, exclude=None, graph=None):
  """Count learnable parameters.

  Args:
    scope: Restrict the count to a variable scope.
    exclude: Regex to match variable names to exclude.
    graph: Operate on a graph other than the current default graph.

  Returns:
    Number of learnable parameters as integer.

  Raises:
    ValueError: If a trainable variable has a not fully defined shape.
  """
  graph = graph or tf.get_default_graph()
  variables = graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
  if scope:
    # Append a trailing slash so 'foo' does not also match 'foobar/...'.
    prefix = scope if scope.endswith('/') else scope + '/'
    variables = [var for var in variables if var.name.startswith(prefix)]
  if exclude:
    pattern = re.compile(exclude)
    variables = [var for var in variables if not pattern.match(var.name)]
  total = 0
  for var in variables:
    if not var.shape.is_fully_defined():
      raise ValueError(
          "Trainable variable '{}' has undefined shape '{}'.".format(
              var.name, var.shape))
    total += int(np.prod(var.shape.as_list()))
  return total
52 |
--------------------------------------------------------------------------------
/planet/tools/custom_optimizer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 |
21 | from planet.tools import filter_variables_lib
22 |
23 |
class CustomOptimizer(object):
  """Wraps a TensorFlow optimizer with variable filtering, gradient clipping,
  conditional updates, and summaries."""

  def __init__(
      self, optimizer_cls, step, log, learning_rate,
      include=None, exclude=None, clipping=None, schedule=None,
      debug=False, name='custom_optimizer'):
    """Create the optimizer wrapper.

    Args:
      optimizer_cls: Constructor of the underlying TensorFlow optimizer.
      step: Global step tensor, passed to the schedule.
      log: Boolean tensor controlling whether summaries are computed.
      learning_rate: Base learning rate.
      include: Regex or list of regexes of variable names to train.
      exclude: Regex or list of regexes of variable names to skip.
      clipping: Optional global gradient norm to clip to.
      schedule: Optional callable mapping step to a learning rate factor.
      debug: Enable numeric checks and zero-gradient assertions.
      name: Name scope used for ops and summaries.
    """
    if schedule:
      learning_rate *= schedule(step)
    self._step = step
    self._log = log
    self._learning_rate = learning_rate
    self._variables = filter_variables_lib.filter_variables(include, exclude)
    # Bug fix: float(None) raises TypeError, which made the default
    # clipping=None unusable. Only convert when a value was provided.
    self._clipping = float(clipping) if clipping is not None else None
    self._debug = debug
    self._name = name
    self._optimizer = optimizer_cls(learning_rate, name=name)

  def maybe_minimize(self, condition, loss):
    """Apply a training step only when condition holds.

    Returns:
      Tuple of summary string tensor and gradient norm (0.0 when skipped).
    """
    # loss = tf.cond(condition, lambda: loss, float)
    update_op, grad_norm = tf.cond(
        condition,
        lambda: self.minimize(loss),
        lambda: (tf.no_op(), 0.0))
    with tf.control_dependencies([update_op]):
      # Only build summaries when both training and logging are requested.
      summary = tf.cond(
          tf.logical_and(condition, self._log),
          lambda: self.summarize(grad_norm), str)
    if self._debug:
      # print_op = tf.print('{}_grad_norm='.format(self._name), grad_norm)
      message = 'Zero gradient norm in {} optimizer.'.format(self._name)
      assertion = lambda: tf.assert_greater(grad_norm, 0.0, message=message)
      assert_op = tf.cond(condition, assertion, tf.no_op)
      with tf.control_dependencies([assert_op]):
        summary = tf.identity(summary)
    return summary, grad_norm

  def minimize(self, loss):
    """Compute and apply gradients; returns the update op and gradient norm."""
    with tf.name_scope('optimizer_{}'.format(self._name)):
      if self._debug:
        loss = tf.check_numerics(loss, '{}_loss'.format(self._name))
      gradients, variables = zip(*self._optimizer.compute_gradients(
          loss, self._variables, colocate_gradients_with_ops=True))
      # Norm is computed before clipping so summaries show the raw value.
      grad_norm = tf.global_norm(gradients)
      if self._clipping:
        gradients, _ = tf.clip_by_global_norm(
            gradients, self._clipping, grad_norm)
      optimize = self._optimizer.apply_gradients(zip(gradients, variables))
      return optimize, grad_norm

  def summarize(self, grad_norm):
    """Create a merged summary of learning rate and gradient norms."""
    summaries = []
    with tf.name_scope('optimizer_{}'.format(self._name)):
      summaries.append(tf.summary.scalar('learning_rate', self._learning_rate))
      summaries.append(tf.summary.scalar('grad_norm', grad_norm))
      if self._clipping:
        clipped = tf.minimum(grad_norm, self._clipping)
        summaries.append(tf.summary.scalar('clipped_gradient_norm', clipped))
      summary = tf.summary.merge(summaries)
    return summary
83 |
--------------------------------------------------------------------------------
/planet/tools/filter_variables_lib.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import re
20 |
21 | import tensorflow as tf
22 |
23 |
def filter_variables(include=None, exclude=None):
  """Select global variables by include and exclude regex patterns.

  Args:
    include: Regex or list of regexes; defaults to matching everything.
    exclude: Regex or list of regexes; defaults to excluding nothing.

  Returns:
    List of variables matching at least one include and no exclude pattern.

  Raises:
    RuntimeError: If the graph has no variables, an include pattern matches
      nothing, or no variables remain after filtering.
  """
  # Normalize both arguments to sequences of compiled regexes.
  if include is None:
    include = (r'.*',)
  if exclude is None:
    exclude = ()
  if not isinstance(include, (tuple, list)):
    include = (include,)
  if not isinstance(exclude, (tuple, list)):
    exclude = (exclude,)
  include = [re.compile(regex) for regex in include]
  exclude = [re.compile(regex) for regex in exclude]
  variables = tf.global_variables()
  if not variables:
    raise RuntimeError('There are no variables to filter.')
  # Fail early on include patterns that match no variable at all.
  for regex in include:
    if not any(regex.match(variable.name) for variable in variables):
      message = "Regex r'{}' does not match any variables in the graph.\n"
      message += 'All variables:\n'
      message += '\n'.join('- {}'.format(var.name) for var in variables)
      raise RuntimeError(message.format(regex.pattern))
  filtered = [
      variable for variable in variables
      if any(regex.match(variable.name) for regex in include)
      and not any(regex.match(variable.name) for regex in exclude)]
  if not filtered:
    message = 'No variables left after filtering.'
    message += '\nIncludes:\n' + '\n'.join(regex.pattern for regex in include)
    message += '\nExcludes:\n' + '\n'.join(regex.pattern for regex in exclude)
    raise RuntimeError(message)
  return filtered
62 |
--------------------------------------------------------------------------------
/planet/tools/gif_summary.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import numpy as np
20 | import tensorflow as tf
21 | from tensorflow.python.ops import summary_op_util
22 |
23 |
def encode_gif(images, fps):
  """Encodes numpy images into gif string.

  Args:
    images: A 5-D `uint8` `np.array` (or a list of 4-D images) of shape
      `[batch_size, time, height, width, channels]` where `channels` is 1 or 3.
    fps: frames per second of the animation

  Returns:
    The encoded gif string.

  Raises:
    IOError: If the ffmpeg command returns an error.
  """
  from subprocess import Popen, PIPE
  h, w, c = images[0].shape
  cmd = [
      'ffmpeg', '-y',
      '-f', 'rawvideo',
      '-vcodec', 'rawvideo',
      '-r', '%.02f' % fps,
      '-s', '%dx%d' % (w, h),
      '-pix_fmt', {1: 'gray', 3: 'rgb24'}[c],
      '-i', '-',
      '-filter_complex',
      '[0:v]split[x][z];[z]palettegen[y];[x]fifo[x];[x][y]paletteuse',
      '-r', '%.02f' % fps,
      '-f', 'gif',
      '-']
  proc = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE)
  for image in images:
    # tobytes() replaces the deprecated ndarray.tostring(), which was removed
    # in NumPy 2.0; the produced raw bytes are identical.
    proc.stdin.write(image.tobytes())
  out, err = proc.communicate()
  if proc.returncode:
    err = '\n'.join([' '.join(cmd), err.decode('utf8')])
    raise IOError(err)
  del proc
  return out
62 |
63 |
def py_gif_summary(tag, images, max_outputs, fps):
  """Outputs a `Summary` protocol buffer with gif animations.

  Args:
    tag: Name of the summary.
    images: A 5-D `uint8` `np.array` of shape `[batch_size, time, height,
      width, channels]` where `channels` is 1 or 3.
    max_outputs: Max number of batch elements to generate gifs for.
    fps: frames per second of the animation

  Returns:
    The serialized `Summary` protocol buffer.

  Raises:
    ValueError: If `images` is not a 5-D `uint8` array with 1 or 3 channels.
  """
  # When invoked through tf.py_func the tag arrives as bytes; remember this
  # so the empty-string fallback below can match the incoming type.
  is_bytes = isinstance(tag, bytes)
  if is_bytes:
    tag = tag.decode("utf-8")
  images = np.asarray(images)
  if images.dtype != np.uint8:
    raise ValueError("Tensor must have dtype uint8 for gif summary.")
  if images.ndim != 5:
    raise ValueError("Tensor must be 5-D for gif summary.")
  batch_size, _, height, width, channels = images.shape
  if channels not in (1, 3):
    raise ValueError("Tensors must have 1 or 3 channels for gif summary.")
  summ = tf.Summary()
  num_outputs = min(batch_size, max_outputs)
  for i in range(num_outputs):
    image_summ = tf.Summary.Image()
    image_summ.height = height
    image_summ.width = width
    image_summ.colorspace = channels # 1: grayscale, 3: RGB
    try:
      # Preferred path: encode the whole sequence as an animated gif.
      image_summ.encoded_image_string = encode_gif(images[i], fps)
    except (IOError, OSError) as e:
      tf.logging.warning(
          "Unable to encode images to a gif string because either ffmpeg is "
          "not installed or ffmpeg returned an error: %s. Falling back to an "
          "image summary of the first frame in the sequence.", e)
      try:
        # Fallback: encode only the first frame as a still PNG via PIL.
        from PIL import Image  # pylint: disable=g-import-not-at-top
        import io  # pylint: disable=g-import-not-at-top
        with io.BytesIO() as output:
          Image.fromarray(images[i][0]).save(output, "PNG")
          image_summ.encoded_image_string = output.getvalue()
      except Exception:
        tf.logging.warning(
            "Gif summaries requires ffmpeg or PIL to be installed: %s", e)
        # Last resort: attach an empty image string of the matching type.
        image_summ.encoded_image_string = (
            "".encode('utf-8') if is_bytes else "")
    # Use a per-element suffix only when several gifs are emitted.
    if num_outputs == 1:
      summ_tag = "{}/gif".format(tag)
    else:
      summ_tag = "{}/gif/{}".format(tag, i)
    summ.value.add(tag=summ_tag, image=image_summ)
  summ_str = summ.SerializeToString()
  return summ_str
123 |
124 |
def gif_summary(name, tensor, max_outputs, fps, collections=None, family=None):
  """Outputs a `Summary` protocol buffer with gif animations.

  Args:
    name: Name of the summary.
    tensor: A 5-D `uint8` `Tensor` of shape `[batch_size, time, height, width,
      channels]` where `channels` is 1 or 3.
    max_outputs: Max number of batch elements to generate gifs for.
    fps: frames per second of the animation
    collections: Optional list of tf.GraphKeys. The collections to add the
      summary to. Defaults to [tf.GraphKeys.SUMMARIES]
    family: Optional; if provided, used as the prefix of the summary tag name,
      which controls the tab name used for display on Tensorboard.

  Returns:
    A scalar `Tensor` of type `string`. The serialized `Summary` protocol
    buffer.
  """
  tensor = tf.convert_to_tensor(tensor)
  # Float inputs are assumed to lie in [0, 1] and are rescaled to uint8.
  if tensor.dtype in (tf.float32, tf.float64):
    tensor = tf.cast(255.0 * tensor, tf.uint8)
  scope_manager = summary_op_util.summary_scope(name, family, values=[tensor])
  with scope_manager as (tag, scope):
    # Delegate the actual encoding to the numpy implementation.
    val = tf.py_func(
        py_gif_summary, [tag, tensor, max_outputs, fps], tf.string,
        stateful=False, name=scope)
    summary_op_util.collect(val, collections, [tf.GraphKeys.SUMMARIES])
  return val
156 |
--------------------------------------------------------------------------------
/planet/tools/image_strip_summary.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 |
21 |
def image_strip_summary(name, images, max_length=100, max_batch=10):
  """Create an image summary that places frames of a video tensor side by side.

  Args:
    name: Name tag of the summary.
    images: Tensor with the dimensions batch, time, height, width, channels.
    max_length: Maximum number of frames per sequence to include.
    max_batch: Maximum number of sequences to include.

  Returns:
    Summary string tensor.
  """
  if max_batch:
    images = images[:max_batch]
  if max_length:
    images = images[:, :max_length]
  if images.dtype == tf.uint8:
    images = tf.to_float(images) / 255.0
  shape = tf.shape(images)
  length, width, channels = shape[1], shape[3], shape[4]
  # Move the time axis next to width so frames line up horizontally.
  images = tf.transpose(images, [0, 2, 1, 3, 4])
  # Use the dynamic channel count instead of hard-coding RGB, so grayscale
  # (single-channel) videos are supported as well.
  images = tf.reshape(images, [1, -1, length * width, channels])
  images = tf.clip_by_value(images, 0., 1.)
  return tf.summary.image(name, images)
45 |
--------------------------------------------------------------------------------
/planet/tools/mask.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 |
21 |
def mask(tensor, mask=None, length=None, value=0, debug=False):
  """Set padding elements of a batch of sequences to a constant.

  Useful for setting padding elements to zero before summing along the time
  dimension, or for preventing infinite results in padding elements. Exactly
  one of mask or length must be provided.

  Args:
    tensor: Tensor of sequences.
    mask: Boolean mask of valid indices.
    length: Batch of sequence lengths.
    value: Value to write into padding elements.
    debug: Test for infinite values; slows down performance.

  Raises:
    KeyError: If both or none of `mask` and `length` are provided.

  Returns:
    Masked sequences.
  """
  provided = [x for x in (mask, length) if x is not None]
  if len(provided) != 1:
    raise KeyError('Exactly one of mask and length must be provided.')
  with tf.name_scope('mask'):
    if mask is None:
      # Derive the boolean mask from the per-sequence lengths.
      steps = tf.range(tensor.shape[1].value)
      mask = steps[None, :] < length[:, None]
    batch_dims = mask.shape.ndims
    # Expand and tile the mask over the trailing feature dimensions.
    while tensor.shape.ndims > mask.shape.ndims:
      mask = mask[..., None]
    multiples = [1] * batch_dims + tensor.shape[batch_dims:].as_list()
    mask = tf.tile(mask, multiples)
    masked = tf.where(mask, tensor, value * tf.ones_like(tensor))
    if debug:
      masked = tf.check_numerics(masked, 'masked')
    return masked
57 |
--------------------------------------------------------------------------------
/planet/tools/nested.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tools for manipulating nested tuples, list, and dictionaries."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 |
20 | _builtin_zip = zip
21 | _builtin_map = map
22 | _builtin_filter = filter
23 |
24 |
def zip_(*structures, **kwargs):
  """Combine corresponding elements in multiple nested structure to tuples.

  The nested structures can consist of any combination of lists, tuples, and
  dicts. All provided structures must have the same nesting.

  Args:
    *structures: Nested structures.
    flatten: Whether to flatten the resulting structure into a tuple. Keys of
      dictionaries will be discarded.

  Returns:
    Nested structure.
  """
  # Named keyword arguments are not allowed after *args in Python 2.
  flatten = kwargs.pop('flatten', False)
  assert not kwargs, 'zip() got unexpected keyword arguments.'
  def combine(*leaves):
    # A single structure keeps its leaves unchanged; several structures
    # produce tuples of corresponding leaves.
    return leaves[0] if len(leaves) == 1 else leaves
  return map(combine, *structures, flatten=flatten)
46 |
47 |
def map_(function, *structures, **kwargs):
  """Apply a function to every element in a nested structure.

  If multiple structures are provided as input, their structure must match and
  the function will be applied to corresponding groups of elements. The nested
  structure can consist of any combination of lists, tuples, and dicts.

  Args:
    function: The function to apply to the elements of the structure. Receives
      one argument for every structure that is provided.
    *structures: One of more nested structures.
    flatten: Whether to flatten the resulting structure into a tuple. Keys of
      dictionaries will be discarded.

  Returns:
    Nested structure.
  """
  # Named keyword arguments are not allowed after *args in Python 2.
  flatten = kwargs.pop('flatten', False)
  assert not kwargs, 'map() got unexpected keyword arguments.'

  def recurse(*nodes):
    if not nodes:
      return nodes
    first = nodes[0]
    if all(isinstance(node, (tuple, list)) for node in nodes):
      if len(set(len(node) for node in nodes)) > 1:
        raise ValueError('Cannot merge tuples or lists of different length.')
      mapped = tuple(recurse(*group) for group in _builtin_zip(*nodes))
      if hasattr(first, '_fields'):  # namedtuple
        return type(first)(*mapped)
      return type(first)(mapped)  # tuple, list
    if all(isinstance(node, dict) for node in nodes):
      if len(set(frozenset(node.keys()) for node in nodes)) > 1:
        raise ValueError('Cannot merge dicts with different keys.')
      mapped = {
          key: recurse(*(node[key] for node in nodes))
          for key in first}
      return type(first)(mapped)
    # Leaf: apply the function to corresponding elements.
    return function(*nodes)

  result = recurse(*structures)
  if flatten:
    result = flatten_(result)
  return result
93 |
94 |
def flatten_(structure):
  """Combine all leaves of a nested structure into a tuple.

  The nested structure can consist of any combination of tuples, lists, and
  dicts. Dictionary keys will be discarded but values will be ordered by the
  sorting of the keys.

  Args:
    structure: Nested structure.

  Returns:
    Flat tuple.
  """
  if isinstance(structure, dict):
    # Order values by sorted key so the result is deterministic.
    ordered = [structure[key] for key in sorted(structure)]
    return sum((flatten_(value) for value in ordered), ())
  if isinstance(structure, (tuple, list)):
    return sum((flatten_(element) for element in structure), ())
  return (structure,)
119 |
120 |
def filter_(predicate, *structures, **kwargs):
  """Select elements of a nested structure based on a predicate function.

  If multiple structures are provided as input, their structure must match and
  the function will be applied to corresponding groups of elements. The nested
  structure can consist of any combination of lists, tuples, and dicts.

  Args:
    predicate: The function to determine whether an element should be kept.
      Receives one argument for every structure that is provided.
    *structures: One of more nested structures.
    flatten: Whether to flatten the resulting structure into a tuple. Keys of
      dictionaries will be discarded.

  Returns:
    Nested structure.
  """
  # Named keyword arguments are not allowed after *args in Python 2.
  flatten = kwargs.pop('flatten', False)
  assert not kwargs, 'filter() got unexpected keyword arguments.'

  def recurse(*nodes):
    if not nodes:
      return nodes
    first = nodes[0]
    if all(isinstance(node, (tuple, list)) for node in nodes):
      if len(set(len(node) for node in nodes)) > 1:
        raise ValueError('Cannot merge tuples or lists of different length.')
      # Only group corresponding elements when several structures are given.
      if len(nodes) > 1:
        kept = [recurse(*group) for group in _builtin_zip(*nodes)]
      else:
        kept = [recurse(element) for element in first]
      if hasattr(first, '_fields'):  # namedtuple
        # Namedtuples keep their arity; dropped leaves become None.
        return type(first)(*(item if item != () else None for item in kept))
      # Plain tuples and lists drop empty sub-containers.
      kept = [
          item for item in kept
          if not isinstance(item, (tuple, list, dict)) or item]
      return type(first)(kept)
    if all(isinstance(node, dict) for node in nodes):
      if len(set(frozenset(node.keys()) for node in nodes)) > 1:
        raise ValueError('Cannot merge dicts with different keys.')
      if len(nodes) > 1:
        kept = {key: recurse(*(node[key] for node in nodes)) for key in first}
      else:
        kept = {key: recurse(value) for key, value in first.items()}
      # Drop entries whose values became empty containers.
      kept = {
          key: value for key, value in kept.items()
          if not isinstance(value, (tuple, list, dict)) or value}
      return type(first)(kept)
    # Leaf: keep the element(s) only if the predicate holds.
    if len(nodes) > 1:
      return nodes if predicate(*nodes) else ()
    return first if predicate(first) else ()

  result = recurse(*structures)
  if flatten:
    result = flatten_(result)
  return result
185 |
186 |
# Re-export under the builtin names: the module is meant to be used as a
# namespace (nested.zip, nested.map, ...), so the builtins are only shadowed
# inside this file, where the _builtin_* aliases above preserve access.
zip = zip_
map = map_
flatten = flatten_
filter = filter_
191 |
--------------------------------------------------------------------------------
/planet/tools/numpy_episodes.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Load tensors from a directory of numpy files."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import functools
22 | import os
23 | import random
24 |
25 | from scipy.ndimage import interpolation
26 | import numpy as np
27 | import tensorflow as tf
28 |
29 | from planet.tools import attr_dict
30 | from planet.tools import chunk_sequence
31 |
32 |
def numpy_episodes(
    train_dir, test_dir, shape, reader=None, loader=None,
    num_chunks=None, preprocess_fn=None):
  """Read sequences stored as compressed Numpy files as a TensorFlow dataset.

  Args:
    train_dir: Directory containing NPZ files of the training dataset.
    test_dir: Directory containing NPZ files of the testing dataset.
    shape: Tuple of batch size and chunk length for the datasets.
    reader: Callable that reads an episode from a NPZ filename.
    loader: Generator that yields episodes.
    num_chunks: Number of chunks to sample from each episode.
    preprocess_fn: Optional function applied to the image key of each batch.

  Returns:
    Structured data from numpy episodes as Tensors.
  """
  reader = reader or episode_reader
  loader = loader or cache_loader
  try:
    dtypes, shapes = _read_spec(reader, train_dir)
  except ZeroDivisionError:
    # Fall back to the test directory when the training directory is empty.
    dtypes, shapes = _read_spec(reader, test_dir)

  def build(directory):
    # Stream episodes from the loader, split them into fixed-length chunks,
    # batch, and optionally preprocess the image observations.
    dataset = tf.data.Dataset.from_generator(
        functools.partial(loader, reader, directory, shape[0]),
        dtypes, shapes)
    dataset = dataset.flat_map(
        lambda episode: tf.data.Dataset.from_tensor_slices(
            chunk_sequence.chunk_sequence(
                episode, shape[1], True, num_chunks)))
    dataset = dataset.batch(shape[0], drop_remainder=True)
    def preprocess(sequence):
      if preprocess_fn:
        sequence['image'] = preprocess_fn(sequence['image'])
      return sequence
    return dataset.map(preprocess, 10).prefetch(10)

  return attr_dict.AttrDict(train=build(train_dir), test=build(test_dir))
73 |
74 |
def cache_loader(reader, directory, batch_size, every):
  """Yield episodes from an in-memory cache that grows with the directory.

  After yielding `every` episodes in permuted order, newly written NPZ files
  are read and added to the cache.
  """
  cache = {}
  while True:
    for episode in _permuted(_sample(cache.values(), every), every):
      yield episode
    # Pick up episodes written to disk since the last refresh.
    for filename in tf.gfile.Glob(os.path.join(directory, '*.npz')):
      if filename not in cache:
        cache[filename] = reader(filename)
85 |
86 |
def recent_loader(reader, directory, batch_size, every):
  """Yield episodes, drawing half from recent files and half from older ones."""
  recent = {}
  cache = {}
  while True:
    # Bias the stream towards freshly collected experience.
    episodes = _sample(recent.values(), every // 2)
    episodes += _sample(cache.values(), every // 2)
    for episode in _permuted(episodes, every):
      yield episode
    # Age out the recent episodes, then scan for newly written files.
    cache.update(recent)
    recent = {}
    for filename in tf.gfile.Glob(os.path.join(directory, '*.npz')):
      if filename not in cache:
        recent[filename] = reader(filename)
102 |
103 |
def reload_loader(reader, directory, batch_size):
  """Yield episodes forever, re-reading the directory on every pass."""
  pattern = os.path.join(os.path.expanduser(directory), '*.npz')
  while True:
    # Glob anew on each pass so freshly written episodes are picked up.
    files = tf.gfile.Glob(pattern)
    random.shuffle(files)
    for path in files:
      yield reader(path)
111 |
112 |
def dummy_loader(reader, directory, batch_size):
  """Yield deterministic random episodes matching the real dataset's spec."""
  # Seeded generator; also avoids shadowing the module-level `random` import.
  rng = np.random.RandomState(seed=0)
  dtypes, shapes, length = _read_spec(reader, directory, True, True)
  while True:
    episode = {}
    for key, dtype in dtypes.items():
      shape = (length,) + shapes[key][1:]
      if dtype in (np.float32, np.float64):
        episode[key] = rng.uniform(0, 1, shape).astype(dtype)
      elif dtype in (np.int32, np.int64, np.uint8):
        episode[key] = rng.uniform(0, 255, shape).astype(dtype)
      else:
        raise NotImplementedError('Unsupported dtype {}.'.format(dtype))
    yield episode
127 |
128 |
def episode_reader(filename, resize=None, max_length=None, action_noise=None):
  """Load one episode from an NPZ file and post-process it.

  Args:
    filename: Path of the NPZ episode file.
    resize: Optional zoom factor applied to the image observations.
    max_length: Optional maximum number of steps to keep.
    action_noise: Optional stddev of Gaussian noise added to the actions,
      seeded deterministically from the filename.

  Returns:
    Dict of numpy arrays, including a cumulative 'return' key.
  """
  with tf.gfile.Open(filename, 'rb') as file_:
    episode = np.load(file_)
    # Materialize the lazy NPZ entries (and downcast 64-bit arrays) while the
    # file is still open.
    episode = {key: _convert_type(episode[key]) for key in episode.keys()}
  episode['return'] = np.cumsum(episode['reward'])
  if max_length:
    episode = {key: value[:max_length] for key, value in episode.items()}
  if resize and resize != 1:
    factors = (1, resize, resize, 1)
    episode['image'] = interpolation.zoom(episode['image'], factors)
  if action_noise:
    # Bug fix: np.fromstring() was deprecated for binary input and removed in
    # NumPy 2.0; frombuffer over the encoded filename yields the same bytes.
    seed = np.frombuffer(filename.encode('utf-8'), dtype=np.uint8)
    episode['action'] += np.random.RandomState(seed).normal(
        0, action_noise, episode['action'].shape)
  return episode
144 |
145 |
def _read_spec(
    reader, directory, return_length=False, numpy_types=False):
  """Infer dtypes and shapes of the dataset from a single episode.

  The leading (time) dimension of every shape is reported as None, since
  episodes may differ in length.
  """
  loader = reload_loader(reader, directory, batch_size=1)
  episode = next(loader)
  loader.close()
  if numpy_types:
    dtypes = {key: value.dtype for key, value in episode.items()}
  else:
    dtypes = {
        key: tf.as_dtype(value.dtype) for key, value in episode.items()}
  shapes = {
      key: (None,) + value.shape[1:] for key, value in episode.items()}
  if not return_length:
    return dtypes, shapes
  length = len(next(iter(episode.values())))
  return dtypes, shapes, length
161 |
162 |
163 | def _convert_type(array):
164 | if array.dtype == np.float64:
165 | return array.astype(np.float32)
166 | if array.dtype == np.int64:
167 | return array.astype(np.int32)
168 | return array
169 |
170 |
171 | def _sample(sequence, amount):
172 | sequence = list(sequence)
173 | amount = min(amount, len(sequence))
174 | return random.sample(sequence, amount)
175 |
176 |
177 | def _permuted(sequence, amount):
178 | sequence = list(sequence)
179 | if not sequence:
180 | return
181 | index = 0
182 | while True:
183 | for element in np.random.permutation(sequence):
184 | if index >= amount:
185 | return
186 | yield element
187 | index += 1
188 |
--------------------------------------------------------------------------------
/planet/tools/overshooting.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import functools
20 |
21 | import numpy as np
22 | import tensorflow as tf
23 |
24 | from planet.tools import nested
25 | from planet.tools import shape
26 |
27 |
def overshooting(
    cell, target, embedded, prev_action, length, amount, posterior=None,
    ignore_input=False):
  """Perform open loop rollouts from the posteriors at every step.

  First, we apply the encoder to embed raw inputs and apply the model to obtain
  posterior states for every time step. Then, we perform `amount` long open
  loop rollouts from these posteriors.

  Note that the actions should be those leading to the current time step. So
  under common convention, it contains the last actions while observations are
  the current ones.

  Input:

    target, embedded:
      [A B C D E F] [A B C D E ]

    prev_action:
      [0 A B C D E] [0 A B C D ]

    length:
      [6 5]

    amount:
      3

  Output:

    prior, posterior, target:
      [A B C D E F] [A B C D E ]
      [B C D E F ] [B C D E ]
      [C D E F ] [C D E ]
      [D E F ] [D E ]

    mask:
      [1 1 1 1 1 1] [1 1 1 1 1 0]
      [1 1 1 1 1 0] [1 1 1 1 0 0]
      [1 1 1 1 0 0] [1 1 1 0 0 0]
      [1 1 1 0 0 0] [1 1 0 0 0 0]

  """
  # Closed loop unroll to get posterior states, which are the starting points
  # for open loop unrolls. We don't need the last time step, since we have no
  # targets for unrolls from it.
  if posterior is None:
    # use_obs has shape (batch, time, 1); True means the cell conditions on
    # the observation at that step.
    use_obs = tf.ones(tf.shape(
        nested.flatten(embedded)[0][:, :, :1])[:3], tf.bool)
    use_obs = tf.cond(
        tf.convert_to_tensor(ignore_input),
        lambda: tf.zeros_like(use_obs, tf.bool),
        lambda: use_obs)
    (_, posterior), _ = tf.nn.dynamic_rnn(
        cell, (embedded, prev_action, use_obs), length, dtype=tf.float32,
        swap_memory=True)

  # Arrange inputs for every iteration in the open loop unroll. Every loop
  # iteration below corresponds to one row in the docstring illustration.
  max_length = shape.shape(nested.flatten(embedded)[0])[1]
  first_output = {
      # 'observ': embedded,
      'prev_action': prev_action,
      'posterior': posterior,
      'target': target,
      'mask': tf.sequence_mask(length, max_length, tf.int32),
  }

  # Shift each tensor one step to the left along time, zero-padding the end;
  # scanning this `amount` times produces the shifted rows of the
  # illustration above.
  progress_fn = lambda tensor: tf.concat([tensor[:, 1:], 0 * tensor[:, :1]], 1)
  other_outputs = tf.scan(
      lambda past_output, _: nested.map(progress_fn, past_output),
      tf.range(amount), first_output)
  sequences = nested.map(
      lambda lhs, rhs: tf.concat([lhs[None], rhs], 0),
      first_output, other_outputs)

  # Merge batch and time dimensions of steps to compute unrolls from every
  # time step as one batch. The time dimension becomes the number of
  # overshooting distances.
  sequences = nested.map(
      lambda tensor: _merge_dims(tensor, [1, 2]),
      sequences)
  sequences = nested.map(
      lambda tensor: tf.transpose(
          tensor, [1, 0] + list(range(2, tensor.shape.ndims))),
      sequences)
  merged_length = tf.reduce_sum(sequences['mask'], 1)

  # Mask out padding frames; unnecessary if the input is already masked.
  sequences = nested.map(
      lambda tensor: tensor * tf.cast(
          _pad_dims(sequences['mask'], tensor.shape.ndims),
          tensor.dtype),
      sequences)

  # Compute open loop rollouts. Observations are zeroed and use_obs is all
  # False, so the cell runs purely on its prior transitions.
  use_obs = tf.zeros(tf.shape(sequences['mask']), tf.bool)[..., None]
  embed_size = nested.flatten(embedded)[0].shape[2].value
  obs = tf.zeros(shape.shape(sequences['mask']) + [embed_size])
  # Initial state for each unroll is the posterior from the previous step
  # (zeros for the first step).
  prev_state = nested.map(
      lambda tensor: tf.concat([0 * tensor[:, :1], tensor[:, :-1]], 1),
      posterior)
  prev_state = nested.map(
      lambda tensor: _merge_dims(tensor, [0, 1]), prev_state)
  (priors, _), _ = tf.nn.dynamic_rnn(
      cell, (obs, sequences['prev_action'], use_obs),
      merged_length,
      prev_state)

  # Restore batch dimension.
  target, prior, posterior, mask = nested.map(
      functools.partial(_restore_batch_dim, batch_size=shape.shape(length)[0]),
      (sequences['target'], priors, sequences['posterior'], sequences['mask']))

  mask = tf.cast(mask, tf.bool)
  return target, prior, posterior, mask
143 |
144 |
def _merge_dims(tensor, dims):
  """Flatten consecutive axes of a tensor trying to preserve static shapes.

  Args:
    tensor: Tensor, or nested structure of tensors, to reshape.
    dims: Consecutive ascending axis indices to merge into one axis.

  Returns:
    Tensor with the given axes merged; static shape is kept where known.

  Raises:
    ValueError: If the axes in `dims` are not consecutive.
  """
  if isinstance(tensor, (list, tuple, dict)):
    return nested.map(tensor, lambda x: _merge_dims(x, dims))
  tensor = tf.convert_to_tensor(tensor)
  # Bug fix: the consecutiveness check must trigger if ANY offset mismatches,
  # not only when all of them do. Previously `.all()` let non-consecutive
  # inputs such as dims=[1, 3] slip through silently.
  if (np.array(dims) - min(dims) != np.arange(len(dims))).any():
    raise ValueError('Dimensions to merge must all follow each other.')
  start, end = dims[0], dims[-1]
  output = tf.reshape(tensor, tf.concat([
      tf.shape(tensor)[:start],
      [tf.reduce_prod(tf.shape(tensor)[start: end + 1])],
      tf.shape(tensor)[end + 1:]], axis=0))
  # Preserve the static shape when every merged dimension is known.
  merged = tensor.shape[start: end + 1].as_list()
  output.set_shape(
      tensor.shape[:start].as_list() +
      [None if None in merged else np.prod(merged)] +
      tensor.shape[end + 1:].as_list())
  return output
163 |
164 |
def _pad_dims(tensor, rank):
  """Append empty dimensions to the tensor until it is of the given rank."""
  while tensor.shape.ndims < rank:
    tensor = tensor[..., None]
  return tensor
170 |
171 |
def _restore_batch_dim(tensor, batch_size):
  """Split batch dimension out of the first dimension of a tensor."""
  merged = shape.shape(tensor)
  target_shape = [batch_size, merged[0] // batch_size] + merged[1:]
  return tf.reshape(tensor, target_shape)
177 |
--------------------------------------------------------------------------------
/planet/tools/preprocess.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 |
21 |
def preprocess(image, bits):
  """Quantize images to the given bit depth, rescale, and add uniform noise.

  The noise spans one quantization bin, and the result is shifted by -0.5.
  """
  bins = 2 ** bits
  values = tf.to_float(image)
  if bits < 8:
    # Drop the least significant bits to reach the target depth.
    values = tf.floor(values / 2 ** (8 - bits))
  values = values / bins
  values = values + tf.random_uniform(tf.shape(values), 0, 1.0 / bins)
  return values - 0.5
31 |
32 |
def postprocess(image, bits, dtype=tf.float32):
  """Map model outputs back to quantized images, inverting `preprocess`.

  Args:
    image: Tensor of values as produced by the model.
    bits: Bit depth used during preprocessing.
    dtype: tf.float32 for quantized floats, or tf.uint8 for pixel values.

  Raises:
    NotImplementedError: For any other dtype.
  """
  bins = 2 ** bits
  if dtype == tf.float32:
    return tf.floor(bins * (image + 0.5)) / bins
  if dtype == tf.uint8:
    quantized = tf.floor(bins * (image + 0.5))
    scaled = quantized * (256.0 / bins)
    return tf.cast(tf.clip_by_value(scaled, 0, 255), tf.uint8)
  raise NotImplementedError(dtype)
45 |
--------------------------------------------------------------------------------
/planet/tools/reshape_as.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 |
21 | from planet.tools import nested
22 |
23 |
def reshape_as(tensor, reference):
  """Reshape a tensor (or nested structure) to the shape of a reference tensor.

  Static dimensions of the reference are used where known; unknown ones are
  read dynamically from the graph.
  """
  if isinstance(tensor, (list, tuple, dict)):
    return nested.map(tensor, lambda x: reshape_as(x, reference))
  tensor = tf.convert_to_tensor(tensor)
  reference = tf.convert_to_tensor(reference)
  dynamic_shape = tf.shape(reference)
  target = []
  for index, static in enumerate(reference.shape.as_list()):
    target.append(static if static is not None else dynamic_shape[index])
  return tf.reshape(tensor, target)
35 |
--------------------------------------------------------------------------------
/planet/tools/schedule.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import numpy as np
20 | import tensorflow as tf
21 |
22 |
def binary(step, batch_size, after, every, until):
  """Boolean schedule that is true for one batch window every `every` steps.

  Active from step `after` on; disabled once `until` is reached (pass a
  non-positive `until` for no upper bound).
  """
  # https://www.desmos.com/calculator/csbhr5cjzz
  in_phase = tf.less((step - after) % every, batch_size)
  started = tf.greater_equal(step, after)
  if until > 0:
    started = tf.logical_and(started, tf.less(step, until))
  result = tf.logical_and(in_phase, started)
  result.set_shape(tf.TensorShape([]))
  return result
33 |
34 |
def linear(step, ramp, min=None, max=None):
  """Linearly ramp a scalar between 0 and 1 over `ramp` steps.

  Args:
    step: Current training step.
    ramp: Number of ramp steps; 0 gives a constant 1, a negative value
        ramps downward from 1 to 0 instead.
    min: Optional lower clip for the result, must be in [0, 1].
    max: Optional upper clip for the result, must be in [0, 1].

  Returns:
    Scalar float32 tensor.
  """
  # https://www.desmos.com/calculator/nrumhgvxql
  if ramp == 0:
    result = tf.constant(1, tf.float32)
  if ramp > 0:
    result = tf.minimum(tf.to_float(step) / tf.to_float(ramp), 1)
  if ramp < 0:
    result = 1 - linear(step, abs(ramp))
  if min is not None and max is not None:
    assert min <= max
  if min is not None:
    assert 0 <= min <= 1
    result = tf.maximum(min, result)
  if max is not None:
    # Bug fix: this branch previously asserted `0 <= min <= 1`, which raised
    # a TypeError when only `max` was given and never validated `max`.
    assert 0 <= max <= 1
    result = tf.minimum(result, max)
  result.set_shape(tf.TensorShape([]))
  return result
53 |
54 |
def linear_reset(step, ramp, after, every):
  """Linear ramp that restarts from zero every `every` steps once past `after`.

  Before `after` the ramp progresses with the raw step; afterwards it
  progresses with the step offset modulo `every`.
  """
  # https://www.desmos.com/calculator/motbnqhacw
  assert every > ramp, (every, ramp)  # Would never reach max value.
  assert not (every != np.inf and after == np.inf), (every, after)
  step, ramp, after, every = [
      tf.to_float(x) for x in (step, ramp, after, every)]
  progress_before = tf.to_float(tf.less(step, after)) * step
  progress_after = (
      tf.to_float(tf.greater_equal(step, after)) * ((step - after) % every))
  result = tf.minimum((progress_before + progress_after) / ramp, 1)
  result.set_shape(tf.TensorShape([]))
  return result
66 |
--------------------------------------------------------------------------------
/planet/tools/shape.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 |
21 |
def shape(tensor):
  """Return the tensor shape as a list, preferring static sizes.

  Statically known dimensions come back as Python ints; unknown ones come
  back as scalar int tensors read from the graph at runtime.
  """
  static = tensor.get_shape().as_list()
  dynamic = tf.unstack(tf.shape(tensor))
  assert len(static) == len(dynamic)
  return [
      size if size is not None else dynamic_size
      for size, dynamic_size in zip(static, dynamic)]
28 |
--------------------------------------------------------------------------------
/planet/tools/streaming_mean.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 |
21 |
class StreamingMean(object):
  """Compute a streaming estimation of the mean of submitted tensors."""

  def __init__(self, shape, dtype, name):
    """Specify the shape and dtype of the mean to be estimated.

    Note that a float mean to zero submitted elements is NaN, while computing
    the integer mean of zero elements raises a division by zero error.

    Args:
      shape: Shape of the mean to compute.
      dtype: Data type of the mean to compute.
      name: Variable scope under which the tracking variables are created.
    """
    self._dtype = dtype
    with tf.variable_scope(name):
      # Running sum over all submitted values.
      self._sum = tf.get_variable(
          'sum', shape, dtype,
          tf.constant_initializer(0),
          trainable=False)
      # Number of submitted samples.
      self._count = tf.get_variable(
          'count', (), tf.int32,
          tf.constant_initializer(0),
          trainable=False)

  @property
  def value(self):
    """The current value of the mean."""
    return self._sum / tf.cast(self._count, self._dtype)

  @property
  def count(self):
    """The number of submitted samples."""
    return self._count

  def submit(self, value):
    """Submit a single or batch tensor to refine the streaming mean."""
    value = tf.convert_to_tensor(value)
    # Add a batch dimension if necessary.
    if value.shape.ndims == self._sum.shape.ndims:
      value = value[None, ...]
    if str(value.shape[1:]) != str(self._sum.shape):
      message = 'Value shape ({}) does not fit tracked tensor ({}).'
      raise ValueError(message.format(value.shape[1:], self._sum.shape))
    def assign():
      # Accumulate the batch sum and the number of elements in the batch.
      return tf.group(
          self._sum.assign_add(tf.reduce_sum(value, 0)),
          self._count.assign_add(tf.shape(value)[0]))
    # Skip the update entirely for empty batches (any zero-sized dimension).
    not_empty = tf.cast(tf.reduce_prod(tf.shape(value)), tf.bool)
    return tf.cond(not_empty, assign, tf.no_op)

  def clear(self):
    """Return the mean estimate and reset the streaming statistics."""
    value = self._sum / tf.cast(self._count, self._dtype)
    # Order matters: read the mean first, then reset, then return the read.
    with tf.control_dependencies([value]):
      reset_value = self._sum.assign(tf.zeros_like(self._sum))
      reset_count = self._count.assign(0)
      with tf.control_dependencies([reset_value, reset_count]):
        return tf.identity(value)
80 |
--------------------------------------------------------------------------------
/planet/tools/summary.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import matplotlib
20 | matplotlib.use('Agg')
21 | import matplotlib.pyplot as plt
22 | import numpy as np
23 | import tensorflow as tf
24 |
25 | from planet.tools import count_dataset
26 | from planet.tools import gif_summary
27 | from planet.tools import image_strip_summary
28 | from planet.tools import mask as masklib
29 | from planet.tools import shape as shapelib
30 |
31 |
def plot_summary(titles, lines, labels, name):
  """Plot lines using matplotlib and create a TensorFlow summary from it.

  Note that only one instance of this summary can be computed at the same time.
  This is because matplotlib uses global state. A workaround is to make earlier
  plot summaries control dependences of later ones.

  Args:
    titles: List of titles for the subplots.
    lines: Nested list of tensors. Each list contains the lines of another
        subplot in the figure.
    labels: Nested list of strings. Each list contains the names for the lines
        of another subplot in the figure. Can be None for any of the sub plots.
    name: Name of the summary.

  Returns:
    Summary tensor.
  """

  def body_fn(lines):
    # Runs at session time inside tf.py_func; `lines` arrives as Numpy data.
    fig, axes = plt.subplots(
        nrows=len(titles), ncols=1, sharex=True, sharey=False,
        squeeze=False, figsize=(6, 3 * len(lines)))
    axes = axes[:, 0]
    for index, ax in enumerate(axes):
      ax.set_title(titles[index])
      for line, label in zip(lines[index], labels[index]):
        ax.plot(line, label=label)
      if any(labels[index]):
        ax.legend(frameon=False)
    fig.tight_layout()
    fig.canvas.draw()
    # Convert the rendered Agg canvas into an RGB uint8 image array.
    image = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close(fig)
    return image

  image = tf.py_func(body_fn, (lines,), tf.uint8)
  # Add a batch dimension of one, as expected by tf.summary.image.
  image = image[None]
  summary = tf.summary.image(name, image)
  return summary
73 |
74 |
def data_summaries(data, postprocess_fn, histograms=False, name='data'):
  """Summarize a data batch: optional histograms plus an image strip."""
  summaries = []
  with tf.variable_scope(name):
    if histograms:
      for key, value in data.items():
        if key not in ('image',):
          summaries.append(tf.summary.histogram(key, value))
    image = data['image']
    if postprocess_fn:
      image = postprocess_fn(image)
    summaries.append(image_strip_summary.image_strip_summary('image', image))
  return summaries
88 |
89 |
def dataset_summaries(directory, name='dataset'):
  """Summarize the number of episodes stored in the dataset directory."""
  with tf.variable_scope(name):
    episode_count = count_dataset.count_dataset(directory)
    return [tf.summary.scalar('episodes', episode_count)]
96 |
97 |
def state_summaries(
    cell, prior, posterior, mask, histograms=False, name='state'):
  """Summarize entropy, stddev, and divergence of prior and posterior states."""
  nan_to_num = lambda x: tf.where(tf.is_nan(x), tf.zeros_like(x), x)
  divergence = cell.divergence_from_states(posterior, prior, mask)
  prior = cell.dist_from_state(prior, mask)
  posterior = cell.dist_from_state(posterior, mask)
  prior_entropy = prior.entropy()
  posterior_entropy = posterior.entropy()
  summaries = []
  with tf.variable_scope(name):
    for label, dist, entropy in (
        ('prior', prior, prior_entropy),
        ('posterior', posterior, posterior_entropy)):
      if histograms:
        summaries.append(tf.summary.histogram(
            label + '_entropy_hist', nan_to_num(entropy)))
      summaries.append(tf.summary.scalar(
          label + '_entropy', tf.reduce_mean(entropy)))
      summaries.append(tf.summary.scalar(
          label + '_std', tf.reduce_mean(dist.stddev())))
    summaries.append(tf.summary.scalar(
        'divergence', tf.reduce_mean(divergence)))
  return summaries
125 |
126 |
def dist_summaries(dists, obs, mask, name='dist_summaries'):
  """Summarize moments of predicted distributions and errors against targets.

  Args:
    dists: Dict mapping modality names to distribution objects.
    obs: Dict of ground-truth tensors; log-prob and error summaries are only
        added for modalities present here.
    mask: Unused here; kept for a consistent summary-function signature.
    name: Variable scope for the summaries.
  """
  summaries = []
  with tf.variable_scope(name):
    # Fix: the loop variable previously shadowed the `name` parameter.
    for key, dist in dists.items():
      mode = dist.mode()
      mode_mean, mode_var = tf.nn.moments(mode, list(range(mode.shape.ndims)))
      mode_std = tf.sqrt(mode_var)
      summaries.append(tf.summary.scalar(key + '_mode_mean', mode_mean))
      summaries.append(tf.summary.scalar(key + '_mode_std', mode_std))
      std = dist.stddev()
      std_mean, std_var = tf.nn.moments(std, list(range(std.shape.ndims)))
      std_std = tf.sqrt(std_var)
      summaries.append(tf.summary.scalar(key + '_std_mean', std_mean))
      summaries.append(tf.summary.scalar(key + '_std_std', std_std))
      if key in obs:
        log_prob = tf.reduce_mean(dist.log_prob(obs[key]))
        summaries.append(tf.summary.scalar(key + '_log_prob', log_prob))
        # Reuse the already-computed mode instead of calling dist.mode() again.
        abs_error = tf.reduce_mean(tf.abs(mode - obs[key]))
        summaries.append(tf.summary.scalar(key + '_abs_error', abs_error))
  return summaries
147 |
148 |
def image_summaries(dist, target, name='image', max_batch=10):
  """Create image strips and an animation comparing predictions to targets."""
  summaries = []
  with tf.variable_scope(name):
    empty_frame = 0 * target[:max_batch, :1]
    image = dist.mode()[:max_batch]
    target = target[:max_batch]
    change = tf.concat([empty_frame, image[:, 1:] - image[:, :-1]], 1)
    error = image - target
    strips = [
        ('prediction', image),
        ('change', (change + 1) / 2),
        ('error', (error + 1) / 2),
    ]
    for strip_name, strip_frames in strips:
      summaries.append(
          image_strip_summary.image_strip_summary(strip_name, strip_frames))
    # Concat prediction and target vertically.
    frames = tf.concat([target, image], 2)
    # Stack batch entries horizontally.
    frames = tf.transpose(frames, [1, 2, 0, 3, 4])
    s = shapelib.shape(frames)
    frames = tf.reshape(frames, [s[0], s[1], s[2] * s[3], s[4]])
    summaries.append(gif_summary.gif_summary(
        'animation', frames[None], max_outputs=1, fps=20))
  return summaries
172 |
173 |
def objective_summaries(objectives, name='objectives'):
  """Add a scalar summary for the value of every training objective."""
  with tf.variable_scope(name):
    return [
        tf.summary.scalar(objective.name, objective.value)
        for objective in objectives]
180 |
181 |
def prediction_summaries(dists, data, state, name='state'):
  """Plot per-modality trajectories and log probabilities as image summaries.

  Only the first example of the batch is visualized. Image modalities are
  skipped here; they are handled by `image_summaries`.

  NOTE(review): assumes at least one non-image key is present in both
  `dists` and `data`; otherwise `zip(*log_probs)` below raises — confirm
  callers guarantee this.
  """
  summaries = []
  with tf.variable_scope(name):
    # Predictions.
    log_probs = {}
    for key, dist in dists.items():
      if key in ('image',):
        continue
      if key not in data:
        continue
      # We only look at the first example in the batch.
      log_prob = dist.log_prob(data[key])[0]
      prediction = dist.mode()[0]
      truth = data[key][0]
      plot_name = key
      # Ensure that there is a feature dimension.
      if prediction.shape.ndims == 1:
        prediction = prediction[:, None]
        truth = truth[:, None]
      # One subplot per feature dimension, each comparing the prediction
      # against the ground truth over time.
      prediction = tf.unstack(tf.transpose(prediction, (1, 0)))
      truth = tf.unstack(tf.transpose(truth, (1, 0)))
      lines = list(zip(prediction, truth))
      titles = ['{} {}'.format(key.title(), i) for i in range(len(lines))]
      labels = [['Prediction', 'Truth']] * len(lines)
      plot_name = '{}_trajectory'.format(key)
      # The control dependencies are needed because rendering in matplotlib
      # uses global state, so rendering two plots in parallel interferes.
      with tf.control_dependencies(summaries):
        summaries.append(plot_summary(titles, lines, labels, plot_name))
      log_probs[key] = log_prob
    # One figure with a subplot of the log probability per modality,
    # sorted by key for a deterministic layout.
    log_probs = sorted(log_probs.items(), key=lambda x: x[0])
    titles, lines = zip(*log_probs)
    titles = [title.title() for title in titles]
    lines = [[line] for line in lines]
    labels = [[None]] * len(titles)
    plot_name = 'logprobs'
    with tf.control_dependencies(summaries):
      summaries.append(plot_summary(titles, lines, labels, plot_name))
  return summaries
221 |
--------------------------------------------------------------------------------
/planet/tools/target_network.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 |
21 | from planet.tools import schedule as schedule_lib
22 | from planet.tools import copy_weights
23 |
24 |
def track_network(
    trainer, batch_size, source_pattern, target_pattern, every, amount):
  """Periodically soft-copy weights from a source network to a target network.

  On global step zero, the target is initialized to an exact copy of the
  source (factor 1.0). Afterwards, during the 'train' phase, the target is
  moved toward the source by `amount` on the schedule defined by `every`.
  """
  # Full copy at step zero so the target starts identical to the source.
  init_op = tf.cond(
      tf.equal(trainer.global_step, 0),
      lambda: copy_weights.soft_copy_weights(
          source_pattern, target_pattern, 1.0),
      tf.no_op)
  # NOTE(review): the init check reads trainer.global_step while the schedule
  # reads trainer.step — presumably both refer to the same counter; confirm.
  schedule = schedule_lib.binary(trainer.step, batch_size, 0, every, -1)
  with tf.control_dependencies([init_op]):
    return tf.cond(
        tf.logical_and(tf.equal(trainer.phase, 'train'), schedule),
        lambda: copy_weights.soft_copy_weights(
            source_pattern, target_pattern, amount),
        tf.no_op)
39 |
--------------------------------------------------------------------------------
/planet/tools/test_nested.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import collections
20 |
21 | import tensorflow as tf
22 |
23 | from planet.tools import nested
24 |
25 |
class ZipTest(tf.test.TestCase):
  """Tests for nested.zip across scalars, sequences, dicts, and namedtuples."""

  def test_scalar(self):
    self.assertEqual(42, nested.zip(42))
    self.assertEqual((13, 42), nested.zip(13, 42))

  def test_empty(self):
    self.assertEqual({}, nested.zip({}, {}))

  def test_base_case(self):
    self.assertEqual((1, 2, 3), nested.zip(1, 2, 3))

  def test_shallow_list(self):
    a = [1, 2, 3]
    b = [4, 5, 6]
    c = [7, 8, 9]
    result = nested.zip(a, b, c)
    self.assertEqual([(1, 4, 7), (2, 5, 8), (3, 6, 9)], result)

  def test_shallow_tuple(self):
    a = (1, 2, 3)
    b = (4, 5, 6)
    c = (7, 8, 9)
    result = nested.zip(a, b, c)
    self.assertEqual(((1, 4, 7), (2, 5, 8), (3, 6, 9)), result)

  def test_shallow_dict(self):
    a = {'a': 1, 'b': 2, 'c': 3}
    b = {'a': 4, 'b': 5, 'c': 6}
    c = {'a': 7, 'b': 8, 'c': 9}
    result = nested.zip(a, b, c)
    self.assertEqual({'a': (1, 4, 7), 'b': (2, 5, 8), 'c': (3, 6, 9)}, result)

  def test_single(self):
    # Zipping a single structure returns it unchanged.
    a = [[1, 2], 3]
    result = nested.zip(a)
    self.assertEqual(a, result)

  def test_mixed_structures(self):
    a = [(1, 2), 3, {'foo': [4]}]
    b = [(5, 6), 7, {'foo': [8]}]
    result = nested.zip(a, b)
    self.assertEqual([((1, 5), (2, 6)), (3, 7), {'foo': [(4, 8)]}], result)

  def test_different_types(self):
    a = [1, 2, 3]
    b = 'a b c'.split()
    result = nested.zip(a, b)
    self.assertEqual([(1, 'a'), (2, 'b'), (3, 'c')], result)

  def test_use_type_of_first(self):
    # The container type of the first argument determines the result type.
    a = (1, 2, 3)
    b = [4, 5, 6]
    c = [7, 8, 9]
    result = nested.zip(a, b, c)
    self.assertEqual(((1, 4, 7), (2, 5, 8), (3, 6, 9)), result)

  def test_namedtuple(self):
    Foo = collections.namedtuple('Foo', 'value')
    foo, bar = Foo(42), Foo(13)
    self.assertEqual(Foo((42, 13)), nested.zip(foo, bar))
87 |
88 |
class MapTest(tf.test.TestCase):
  """Verify nested.map applies a function to every leaf of a structure."""

  def test_scalar(self):
    self.assertEqual(42, nested.map(lambda x: x, 42))

  def test_empty(self):
    self.assertEqual({}, nested.map(lambda x: x, {}))

  def test_shallow_list(self):
    double = lambda x: 2 * x
    self.assertEqual([2, 4, 6], nested.map(double, [1, 2, 3]))

  def test_shallow_dict(self):
    identity = lambda x: x
    data = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
    self.assertEqual(data, nested.map(identity, data))

  def test_mixed_structure(self):
    double = lambda x: 2 * x
    result = nested.map(double, [(1, 2), 3, {'foo': [4]}])
    self.assertEqual([(2, 4), 6, {'foo': [8]}], result)

  def test_mixed_types(self):
    # Heterogeneous leaves are fine; 'foo' * 2 is string repetition.
    result = nested.map(lambda x: x * 2, [7, 'foo'])
    self.assertEqual([14, 'foofoo'], result)

  def test_multiple_lists(self):
    add3 = lambda x, y, z: x + y + z
    result = nested.map(add3, [1, 2, 3], [4, 5, 6], [7, 8, 9])
    self.assertEqual([12, 15, 18], result)

  def test_namedtuple(self):
    # Namedtuples are treated as structure; the type is preserved in the
    # result and multiple structures are passed leaf-wise to the function.
    Foo = collections.namedtuple('Foo', 'value')
    foo, bar = [Foo(42)], [Foo(13)]
    self.assertEqual(
        [Foo((13, 42))], nested.map(lambda x, y: (y, x), foo, bar))
    self.assertEqual([Foo(55)], nested.map(lambda x, y: x + y, foo, bar))
126 |
127 |
class FlattenTest(tf.test.TestCase):
  """Verify nested.flatten collects all leaves into a flat tuple."""

  def test_scalar(self):
    self.assertEqual((42,), nested.flatten(42))

  def test_empty(self):
    self.assertEqual((), nested.flatten({}))

  def test_base_case(self):
    self.assertEqual((1,), nested.flatten(1))

  def test_convert_type(self):
    # The result container is always a tuple, even for list input.
    self.assertEqual((1, 2, 3), nested.flatten([1, 2, 3]))

  def test_mixed_structure(self):
    structure = [(1, 2), 3, {'foo': [4]}]
    self.assertEqual((1, 2, 3, 4), nested.flatten(structure))

  def test_value_ordering(self):
    # Dict values come out in a deterministic order (here: keys a, b, c).
    self.assertEqual((1, 2, 3), nested.flatten({'a': 1, 'b': 2, 'c': 3}))
147 |
148 |
class FilterTest(tf.test.TestCase):
  """Verify nested.filter keeps only leaves that satisfy a predicate."""

  def test_empty(self):
    self.assertEqual({}, nested.filter(lambda x: True, {}))
    self.assertEqual({}, nested.filter(lambda x: False, {}))

  def test_base_case(self):
    # A rejected scalar leaf yields an empty tuple.
    self.assertEqual((), nested.filter(lambda x: False, 1))

  def test_single_dict(self):
    is_even = lambda x: x % 2 == 0
    data = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
    self.assertEqual({'b': 2, 'd': 4}, nested.filter(is_even, data))

  def test_multiple_lists(self):
    any_multiple_of_four = lambda *values: any(x % 4 == 0 for x in values)
    result = nested.filter(
        any_multiple_of_four, [1, 2, 3], [4, 5, 6], [7, 8, 9])
    self.assertEqual([(1, 4, 7), (2, 5, 8)], result)

  def test_multiple_dicts(self):
    any_multiple_of_four = lambda *values: any(x % 4 == 0 for x in values)
    result = nested.filter(
        any_multiple_of_four,
        {'a': 1, 'b': 2, 'c': 3},
        {'a': 4, 'b': 5, 'c': 6},
        {'a': 7, 'b': 8, 'c': 9})
    self.assertEqual({'a': (1, 4, 7), 'b': (2, 5, 8)}, result)

  def test_mixed_structure(self):
    is_even = lambda x: x % 2 == 0
    structure = [(1, 2), 3, {'foo': [4]}]
    self.assertEqual([(2,), {'foo': [4]}], nested.filter(is_even, structure))

  def test_remove_empty_containers(self):
    # Containers whose leaves are all rejected are dropped entirely.
    structure = [(1, 2, 3), 4, {'foo': [5, 6], 'bar': 7}]
    self.assertEqual([], nested.filter(lambda x: False, structure))

  def test_namedtuple(self):
    # Rejected namedtuple fields become None rather than disappearing, so
    # the namedtuple type can be preserved.
    Foo = collections.namedtuple('Foo', 'value1, value2')
    self.assertEqual(Foo(1, None), nested.filter(lambda x: x == 1, Foo(1, 2)))

  def test_namedtuple_multiple(self):
    Foo = collections.namedtuple('Foo', 'value1, value2')
    result = nested.filter(lambda x, y: x + y > 3, Foo(1, 2), Foo(2, 3))
    self.assertEqual(Foo(None, (2, 3)), result)

  def test_namedtuple_nested(self):
    Foo = collections.namedtuple('Foo', 'value1, value2')
    result = nested.filter(lambda x: x > 1, Foo(1, [1, 2, 3]))
    self.assertEqual(Foo(None, [2, 3]), result)
203 |
--------------------------------------------------------------------------------
/planet/tools/test_overshooting.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 |
21 | from planet import models
22 | from planet.tools import overshooting
23 |
24 |
class _MockCell(models.Base):
  """Mock state space model.

  The transition adds the previous action to the previous observation state,
  and the posterior simply returns the ground-truth observation. When actions
  or observations arrive as collections, only their first element is used.
  """

  def __init__(self, obs_size):
    self._obs_size = obs_size
    super(_MockCell, self).__init__(
        tf.make_template('transition', self._transition),
        tf.make_template('posterior', self._posterior))

  @property
  def state_size(self):
    # The state is a dict with a single observation entry.
    return {'obs': self._obs_size}

  def _transition(self, prev_state, prev_action, zero_obs):
    action = (
        prev_action[0] if isinstance(prev_action, (tuple, list))
        else prev_action)
    return {'obs': prev_state['obs'] + action}

  def _posterior(self, prev_state, prev_action, obs):
    observation = obs[0] if isinstance(obs, (tuple, list)) else obs
    return {'obs': observation}
52 |
53 |
class OvershootingTest(tf.test.TestCase):
  """Tests for the overshooting() helper, using the additive _MockCell."""

  def test_example(self):
    # Batch of two sequences of common padded length 6; the second sequence
    # is only 2 valid steps long. The trailing [:, :, None] adds a singleton
    # feature dimension.
    obs = tf.constant([
        [10, 20, 30, 40, 50, 60],
        [70, 80, 0, 0, 0, 0],
    ], dtype=tf.float32)[:, :, None]
    prev_action = tf.constant([
        [0.0, 0.1, 0.2, 0.3, 0.4, 0.5],
        [9.0, 0.7, 0, 0, 0, 0],
    ], dtype=tf.float32)[:, :, None]
    length = tf.constant([6, 2], dtype=tf.int32)
    cell = _MockCell(1)
    # Overshoot up to 3 steps past each posterior state; obs is passed both
    # as the observation and as its embedding.
    _, prior, posterior, mask = overshooting(
        cell, obs, obs, prev_action, length, 3)
    # Drop the singleton feature dimension for easier comparison below.
    prior = tf.squeeze(prior['obs'], 3)
    posterior = tf.squeeze(posterior['obs'], 3)
    mask = tf.to_int32(mask)
    with self.test_session():
      # Each column corresponds to a different state step, and each row
      # corresponds to a different overshooting distance from there.
      self.assertAllEqual([
          [1, 1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1, 0],
          [1, 1, 1, 1, 0, 0],
          [1, 1, 1, 0, 0, 0],
      ], mask.eval()[0].T)
      # The second sequence only has 2 valid steps, so most targets are
      # masked out.
      self.assertAllEqual([
          [1, 1, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0],
      ], mask.eval()[1].T)
      # Priors accumulate actions on top of the posterior (see _MockCell).
      self.assertAllClose([
          [0.0, 10.1, 20.2, 30.3, 40.4, 50.5],
          [0.1, 10.3, 20.5, 30.7, 40.9, 0],
          [0.3, 10.6, 20.9, 31.2, 0, 0],
          [0.6, 11.0, 21.4, 0, 0, 0],
      ], prior.eval()[0].T)
      # Posteriors are the ground-truth observations shifted by distance.
      self.assertAllClose([
          [10, 20, 30, 40, 50, 60],
          [20, 30, 40, 50, 60, 0],
          [30, 40, 50, 60, 0, 0],
          [40, 50, 60, 0, 0, 0],
      ], posterior.eval()[0].T)
      self.assertAllClose([
          [9.0, 70.7, 0, 0, 0, 0],
          [9.7, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0],
      ], prior.eval()[1].T)
      self.assertAllClose([
          [70, 80, 0, 0, 0, 0],
          [80, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0],
      ], posterior.eval()[1].T)

  def test_nested(self):
    # Smoke test: nested (tuple) observations and actions should not crash;
    # only graph construction is exercised, nothing is evaluated.
    obs = (tf.ones((3, 50, 1)), tf.ones((3, 50, 2)), tf.ones((3, 50, 3)))
    prev_action = (tf.ones((3, 50, 1)), tf.ones((3, 50, 2)))
    length = tf.constant([49, 50, 3], dtype=tf.int32)
    cell = _MockCell(1)
    overshooting(cell, obs, obs, prev_action, length, 3)
118 |
119 |
# Run the test suite when executed as a script.
if __name__ == '__main__':
  tf.test.main()
122 |
--------------------------------------------------------------------------------
/planet/tools/unroll.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 |
21 | from planet.tools import nested
22 | from planet.tools import shape
23 |
24 |
def closed_loop(cell, embedded, prev_action, debug=False):
  """Unroll the cell over the full sequence, conditioning on observations.

  Args:
    cell: Recurrent state space model cell.
    embedded: Encoded observations; indexed as (batch, time, ...).
    prev_action: Action preceding each step; indexed as (batch, time, ...).
    debug: If true, assert the unrolled length matches the input length.

  Returns:
    Tuple of (prior, posterior) state structures from the cell.
  """
  # Observations are available at every step of the closed loop.
  obs_available = tf.ones(tf.shape(embedded[:, :, :1])[:3], tf.bool)
  (prior, posterior), _ = tf.nn.dynamic_rnn(
      cell, (embedded, prev_action, obs_available), dtype=tf.float32)
  if not debug:
    return prior, posterior
  length_check = tf.assert_equal(
      tf.shape(nested.flatten(posterior)[0])[1], tf.shape(embedded)[1])
  with tf.control_dependencies([length_check]):
    prior = nested.map(tf.identity, prior)
    posterior = nested.map(tf.identity, posterior)
  return prior, posterior
35 |
36 |
def open_loop(cell, embedded, prev_action, context=1, debug=False):
  """Condition on the first `context` steps, then predict without obs.

  Args:
    cell: Recurrent state space model cell.
    embedded: Encoded observations; indexed as (batch, time, ...).
    prev_action: Action preceding each step; indexed as (batch, time, ...).
    context: Number of initial steps that condition on observations.
    debug: If true, assert the combined length matches the input length.

  Returns:
    State structure covering the full sequence (context then open loop),
    concatenated along the time axis.
  """
  # Closed-loop phase: observations available for the first `context` steps.
  obs_on = tf.ones(tf.shape(embedded[:, :context, :1])[:3], tf.bool)
  (_, context_state), last_state = tf.nn.dynamic_rnn(
      cell, (embedded[:, :context], prev_action[:, :context], obs_on),
      dtype=tf.float32)
  # Open-loop phase: zeroed embeddings and disabled observation flags.
  obs_off = tf.zeros(tf.shape(embedded[:, context:, :1])[:3], tf.bool)
  (_, predicted_state), _ = tf.nn.dynamic_rnn(
      cell, (0 * embedded[:, context:], prev_action[:, context:], obs_off),
      initial_state=last_state)
  concat_time = lambda x, y: tf.concat([x, y], 1)
  state = nested.map(concat_time, context_state, predicted_state)
  if debug:
    length_check = tf.assert_equal(
        tf.shape(nested.flatten(state)[0])[1], tf.shape(embedded)[1])
    with tf.control_dependencies([length_check]):
      state = nested.map(tf.identity, state)
  return state
54 |
55 |
def planned(
    cell, objective_fn, embedded, prev_action, planner, context=1, length=20,
    amount=1000, debug=False):
  """Filter the first `context` steps, then plan `length` steps ahead.

  Args:
    cell: Recurrent state space model cell.
    objective_fn: Objective function handed through to the planner.
    embedded: Encoded observations; only the first `context` steps are used.
    prev_action: Action preceding each step.
    planner: Callable invoked as planner(cell, objective_fn, state,
      obs_shape=..., action_shape=..., horizon=..., amount=...), returning a
      3-tuple whose last two elements are the planned states and the return.
    context: Number of initial steps conditioned on observations.
    length: Planning horizon passed to the planner.
    amount: Passed to the planner as `amount`; presumably the number of
      candidate plans considered -- confirm against the planner used.
    debug: If true, assert the result covers context + length steps.

  Returns:
    Tuple (state, return_) of the concatenated state trajectory and the
    planner's predicted return.
  """
  # Closed-loop filtering of the context steps to obtain the start state.
  use_obs = tf.ones(tf.shape(embedded[:, :context, :1])[:3], tf.bool)
  (_, closed_state), last_state = tf.nn.dynamic_rnn(
      cell, (embedded[:, :context], prev_action[:, :context], use_obs),
      dtype=tf.float32)
  _, plan_state, return_ = planner(
      cell, objective_fn, last_state,
      obs_shape=shape.shape(embedded)[2:],
      action_shape=shape.shape(prev_action)[2:],
      horizon=length, amount=amount)
  # Concatenate filtered context states and planned states along time.
  state = nested.map(
      lambda x, y: tf.concat([x, y], 1),
      closed_state, plan_state)
  if debug:
    with tf.control_dependencies([tf.assert_equal(
        tf.shape(nested.flatten(state)[0])[1], context + length)]):
      state = nested.map(tf.identity, state)
      return_ = tf.identity(return_)
  return state, return_
77 |
--------------------------------------------------------------------------------
/planet/training/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | from . import utility
20 | from .define_model import define_model
21 | from .define_summaries import define_summaries
22 | from .running import Experiment
23 | from .trainer import Trainer
24 |
--------------------------------------------------------------------------------
/planet/training/define_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import functools
20 |
21 | import tensorflow as tf
22 |
23 | from planet import tools
24 | from planet.training import define_summaries
25 | from planet.training import utility
26 |
27 |
def define_model(data, trainer, config):
  """Build the full training graph: model, losses, collection, summaries.

  Args:
    data: Dict of batched sequence tensors; must contain 'action' plus the
      keys referenced by config.heads.
    trainer: Trainer providing step, global_step, phase, and log tensors.
    config: AttrDict of hyper parameters and network constructors.

  Returns:
    Tuple (score, summaries, cleanups): mean simulation return when logging
    (empty tensor otherwise), merged summary tensor, and a list of cleanup
    callables collected by the simulation code.

  NOTE: locals() is captured into an AttrDict twice below; downstream code
  reads graph attributes by local-variable name (graph.cell, graph.embedded,
  graph.objectives, ...), so renaming locals here would break callees.
  """
  tf.logging.info('Build TensorFlow compute graph.')
  dependencies = []
  cleanups = []
  step = trainer.step
  global_step = trainer.global_step
  phase = trainer.phase

  # Instantiate network blocks.
  cell = config.cell()
  kwargs = dict(create_scope_now_=True)
  encoder = tf.make_template('encoder', config.encoder, **kwargs)
  heads = tools.AttrDict(_unlocked=True)
  # Features from a batch-1 zero state are used once to build head weights.
  dummy_features = cell.features_from_state(cell.zero_state(1, tf.float32))
  for key, head in config.heads.items():
    name = 'head_{}'.format(key)
    kwargs = dict(create_scope_now_=True)
    if key in data:
      kwargs['data_shape'] = data[key].shape[2:].as_list()
    elif key == 'action_target':
      # No 'action_target' tensor in the data; mirror the action shape.
      kwargs['data_shape'] = data['action'].shape[2:].as_list()
    heads[key] = tf.make_template(name, head, **kwargs)
    heads[key](dummy_features)  # Initialize weights.

  # Apply and optimize model.
  embedded = encoder(data)
  with tf.control_dependencies(dependencies):
    embedded = tf.identity(embedded)
  # First capture of locals for compute_objectives and episode collection.
  graph = tools.AttrDict(locals())
  prior, posterior = tools.unroll.closed_loop(
      cell, embedded, data['action'], config.debug)
  objectives = utility.compute_objectives(
      posterior, prior, data, graph, config)
  summaries, grad_norms = utility.apply_optimizers(
      objectives, trainer, config)

  # Active data collection.
  with tf.variable_scope('collection'):
    with tf.control_dependencies(summaries):  # Make sure to train first.
      for name, params in config.train_collects.items():
        # Collect episodes on a fixed schedule of training steps only.
        schedule = tools.schedule.binary(
            step, config.batch_shape[0],
            params.steps_after, params.steps_every, params.steps_until)
        summary, _ = tf.cond(
            tf.logical_and(tf.equal(trainer.phase, 'train'), schedule),
            functools.partial(
                utility.simulate_episodes, config, params, graph, cleanups,
                expensive_summaries=False, gif_summary=False, name=name),
            lambda: (tf.constant(''), tf.constant(0.0)),
            name='should_collect_' + name)
        summaries.append(summary)

  # Compute summaries.
  # Re-capture locals so the summary graph also sees objectives, prior, etc.
  graph = tools.AttrDict(locals())
  summary, score = tf.cond(
      trainer.log,
      lambda: define_summaries.define_summaries(graph, config, cleanups),
      lambda: (tf.constant(''), tf.zeros((0,), tf.float32)),
      name='summaries')
  summaries = tf.summary.merge([summaries, summary])
  # Periodically print objective values and gradient norms to the log.
  dependencies.append(utility.print_metrics(
      {ob.name: ob.value for ob in objectives},
      step, config.print_metrics_every, 'objectives'))
  dependencies.append(utility.print_metrics(
      grad_norms, step, config.print_metrics_every, 'grad_norms'))
  with tf.control_dependencies(dependencies):
    score = tf.identity(score)
  return score, summaries, cleanups
96 |
--------------------------------------------------------------------------------
/planet/training/define_summaries.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import tensorflow as tf
20 | from tensorflow_probability import distributions as tfd
21 |
22 | from planet import tools
23 | from planet.training import utility
24 | from planet.tools import summary
25 |
26 |
def define_summaries(graph, config, cleanups):
  """Build the expensive summary subgraph executed on logging steps.

  Args:
    graph: AttrDict of locals captured in define_model (cell, heads,
      embedded, data, objectives, step, global_step, phase, ...).
    config: AttrDict of hyper parameters.
    cleanups: List that simulation code may append cleanup callables to.

  Returns:
    Tuple (summaries, score): merged summary tensor and the mean simulation
    return with a leading singleton dimension.
  """
  summaries = []
  plot_summaries = []  # Control dependencies for non thread-safe matplot.
  length = graph.data['length']
  # Mask out padded time steps beyond each sequence's true length.
  mask = tf.range(graph.embedded.shape[1].value)[None, :] < length[:, None]
  heads = graph.heads.copy()
  # Track wall-clock time and global step since the last logging call, for
  # the steps_per_second / seconds_per_step summaries below.
  last_time = tf.Variable(lambda: tf.timestamp(), trainable=False)
  last_step = tf.Variable(lambda: 0.0, trainable=False, dtype=tf.float64)

  def transform(dist):
    # Post-process the predicted image mean, clip to valid range, and wrap
    # it in a unit-variance Normal so dist_summaries stay comparable.
    mean = config.postprocess_fn(dist.mean())
    mean = tf.clip_by_value(mean, 0.0, 1.0)
    return tfd.Independent(tfd.Normal(mean, 1.0), len(dist.event_shape))
  heads.unlock()
  heads['image'] = lambda features: transform(graph.heads['image'](features))
  heads.lock()

  with tf.variable_scope('general'):
    summaries += summary.data_summaries(graph.data, config.postprocess_fn)
    summaries += summary.dataset_summaries(config.train_dir)
    summaries += summary.objective_summaries(graph.objectives)
    summaries.append(tf.summary.scalar('step', graph.step))
    new_time, new_step = tf.timestamp(), tf.cast(graph.global_step, tf.float64)
    delta_time, delta_step = new_time - last_time, new_step - last_step
    # Read the deltas before updating the bookkeeping variables.
    with tf.control_dependencies([delta_time, delta_step]):
      assign_ops = [last_time.assign(new_time), last_step.assign(new_step)]
      with tf.control_dependencies(assign_ops):
        summaries.append(tf.summary.scalar(
            'steps_per_second', delta_step / delta_time))
        summaries.append(tf.summary.scalar(
            'seconds_per_step', delta_time / delta_step))

  with tf.variable_scope('closedloop'):
    # Unroll conditioned on all observations and summarize predictions of
    # every head under both the prior and the posterior states.
    prior, posterior = tools.unroll.closed_loop(
        graph.cell, graph.embedded, graph.data['action'], config.debug)
    summaries += summary.state_summaries(graph.cell, prior, posterior, mask)
    with tf.variable_scope('prior'):
      prior_features = graph.cell.features_from_state(prior)
      prior_dists = {
          name: head(prior_features)
          for name, head in heads.items()}
      summaries += summary.dist_summaries(prior_dists, graph.data, mask)
      summaries += summary.image_summaries(
          prior_dists['image'], config.postprocess_fn(graph.data['image']))
    with tf.variable_scope('posterior'):
      posterior_features = graph.cell.features_from_state(posterior)
      posterior_dists = {
          name: head(posterior_features)
          for name, head in heads.items()}
      summaries += summary.dist_summaries(
          posterior_dists, graph.data, mask)
      summaries += summary.image_summaries(
          posterior_dists['image'],
          config.postprocess_fn(graph.data['image']))

  with tf.variable_scope('openloop'):
    # Unroll with observations only for the first open_loop_context steps.
    state = tools.unroll.open_loop(
        graph.cell, graph.embedded, graph.data['action'],
        config.open_loop_context, config.debug)
    state_features = graph.cell.features_from_state(state)
    state_dists = {name: head(state_features) for name, head in heads.items()}
    summaries += summary.dist_summaries(state_dists, graph.data, mask)
    summaries += summary.image_summaries(
        state_dists['image'], config.postprocess_fn(graph.data['image']))
    summaries += summary.state_summaries(graph.cell, state, posterior, mask)
    # Serialize matplotlib-based plots via plot_summaries dependencies.
    with tf.control_dependencies(plot_summaries):
      plot_summary = summary.prediction_summaries(
          state_dists, graph.data, state)
      plot_summaries += plot_summary
      summaries += plot_summary

  with tf.variable_scope('simulation'):
    sim_returns = []
    for name, params in config.test_collects.items():
      # These are expensive and equivalent for train and test phases, so only
      # do one of them.
      sim_summary, sim_return = tf.cond(
          tf.equal(graph.phase, 'test'),
          lambda: utility.simulate_episodes(
              config, params, graph, cleanups,
              expensive_summaries=False,
              gif_summary=True,
              name=name),
          lambda: ('', 0.0),
          name='should_simulate_' + params.task.name)
      summaries.append(sim_summary)
      sim_returns.append(sim_return)

  summaries = tf.summary.merge(summaries)
  score = tf.reduce_mean(sim_returns)[None]
  return summaries, score
118 |
--------------------------------------------------------------------------------
/planet/training/test_running.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import os
20 | import pickle
21 | import threading
22 | import time
23 |
24 | import numpy as np
25 | import tensorflow as tf
26 |
27 | from planet.training import running
28 |
29 |
class TestExperiment(tf.test.TestCase):
  """Integration tests for the multi-worker Experiment runner."""

  def _assert_glob_count(self, basedir, pattern, expected):
    # Count marker files across all run directories and return the matches.
    filepaths = tf.gfile.Glob(os.path.join(basedir, pattern))
    self.assertEqual(expected, len(filepaths))
    return filepaths

  def test_no_kills(self):
    tf.logging.set_verbosity(tf.logging.INFO)
    basedir = os.path.join(tf.test.get_temp_dir(), 'test_no_kills')
    workers = []
    for index in range(20):
      worker = threading.Thread(
          target=_worker_normal, args=(basedir, str(index)))
      worker.start()
      workers.append(worker)
    for worker in workers:
      worker.join()
    self._assert_glob_count(basedir, '*/DONE', 100)
    self._assert_glob_count(basedir, '*/PING', 100)
    self._assert_glob_count(basedir, '*/started', 100)
    self._assert_glob_count(basedir, '*/resumed', 0)
    # NOTE(review): this checks '*/failed' while test_dying_workers checks
    # '*/FAIL'; confirm which marker name the runner actually writes.
    self._assert_glob_count(basedir, '*/failed', 0)
    filepaths = self._assert_glob_count(basedir, '*/numbers', 100)
    for filepath in filepaths:
      with tf.gfile.GFile(filepath, 'rb') as file_:
        self.assertEqual(10, len(pickle.load(file_)))

  def test_dying_workers(self):
    tf.logging.set_verbosity(tf.logging.INFO)
    basedir = os.path.join(tf.test.get_temp_dir(), 'test_dying_workers')
    # First wave: workers die after 15 scores, leaving runs half finished.
    workers = []
    for index in range(20):
      worker = threading.Thread(
          target=_worker_dying, args=(basedir, 15, str(index)))
      worker.start()
      workers.append(worker)
    for worker in workers:
      worker.join()
    # Second wave: normal workers resume and finish the interrupted runs.
    workers = []
    for index in range(20):
      worker = threading.Thread(
          target=_worker_normal, args=(basedir, str(index)))
      worker.start()
      workers.append(worker)
    for worker in workers:
      worker.join()
    self._assert_glob_count(basedir, '*/DONE', 100)
    self._assert_glob_count(basedir, '*/PING', 100)
    self._assert_glob_count(basedir, '*/FAIL', 0)
    self._assert_glob_count(basedir, '*/started', 100)
    self._assert_glob_count(basedir, '*/resumed', 20)
    filepaths = self._assert_glob_count(basedir, '*/numbers', 100)
    for filepath in filepaths:
      with tf.gfile.GFile(filepath, 'rb') as file_:
        self.assertEqual(10, len(pickle.load(file_)))
90 |
91 |
def _worker_normal(basedir, worker_name):
  """Worker that processes runs to completion without dying."""
  experiment = running.Experiment(
      basedir, _process_fn, _start_fn, _resume_fn,
      num_runs=100, worker_name=worker_name, ping_every=1.0)
  for run in experiment:
    # Drain the run generator; scores themselves are not needed here.
    for _ in run:
      pass
99 |
100 |
def _worker_dying(basedir, die_at_step, worker_name):
  """Worker that exits after yielding `die_at_step` scores in total."""
  experiment = running.Experiment(
      basedir, _process_fn, _start_fn, _resume_fn,
      num_runs=100, worker_name=worker_name, ping_every=1.0)
  steps_done = 0
  for run in experiment:
    for _ in run:
      steps_done += 1
      if steps_done >= die_at_step:
        # Simulate a crash by returning mid-run.
        return
111 |
112 |
def _start_fn(logdir):
  """Initialize a fresh run directory and return its empty score list."""
  # A fresh run must not carry any completion or progress markers yet.
  assert not tf.gfile.Exists(os.path.join(logdir, 'DONE'))
  assert not tf.gfile.Exists(os.path.join(logdir, 'started'))
  assert not tf.gfile.Exists(os.path.join(logdir, 'resumed'))
  with tf.gfile.GFile(os.path.join(logdir, 'started'), 'w') as file_:
    file_.write('\n')
  with tf.gfile.GFile(os.path.join(logdir, 'numbers'), 'wb') as file_:
    pickle.dump([], file_)
  return []
122 |
123 |
def _resume_fn(logdir):
  """Reload the partial score list of an interrupted run."""
  assert not tf.gfile.Exists(os.path.join(logdir, 'DONE'))
  assert tf.gfile.Exists(os.path.join(logdir, 'started'))
  with tf.gfile.GFile(os.path.join(logdir, 'resumed'), 'w') as file_:
    file_.write('\n')
  with tf.gfile.GFile(os.path.join(logdir, 'numbers'), 'rb') as file_:
    scores = pickle.load(file_)
  # Workers in this test die after 15 total steps, i.e. 5 into their second
  # run, so a resumed run must hold exactly 5 scores.
  if len(scores) == 5:
    return scores
  raise Exception('Expected to be resumed in the middle for this test.')
134 |
135 |
def _process_fn(logdir, numbers):
  """Yield scores until ten exist, checkpointing the list after each one."""
  assert tf.gfile.Exists(os.path.join(logdir, 'started'))
  checkpoint_path = os.path.join(logdir, 'numbers')
  while len(numbers) < 10:
    score = np.random.uniform(0, 0.1)
    # Sleep a random amount so workers interleave and can die mid-run.
    time.sleep(score)
    numbers.append(score)
    with tf.gfile.GFile(checkpoint_path, 'wb') as file_:
      pickle.dump(numbers, file_)
    yield score
145 |
146 |
# Run the test suite when executed as a script.
if __name__ == '__main__':
  tf.test.main()
149 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The PlaNet Authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Setup script for PlaNet."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import setuptools
22 |
23 |
# Runtime dependencies; TensorFlow and its probability package are pinned
# to the versions this code base was developed against.
REQUIRED_PACKAGES = [
    'dm_control',
    'gym',
    'matplotlib',
    'ruamel.yaml',
    'scikit-image',
    'scipy',
    'tensorflow-gpu==1.13.1',
    'tensorflow_probability==0.6.0',
]

setuptools.setup(
    name='planetrl',
    version='1.0.0',
    description=(
        'Deep Planning Network: Control from pixels by latent planning ' +
        'with learned dynamics.'),
    license='Apache 2.0',
    url='http://github.com/google-research/planet',
    install_requires=REQUIRED_PACKAGES,
    packages=setuptools.find_packages(),
    classifiers=[
        'Programming Language :: Python :: 2',
        'Programming Language :: Python :: 3',
        'License :: OSI Approved :: Apache Software License',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'Intended Audience :: Science/Research',
    ],
)
51 |
--------------------------------------------------------------------------------