├── .gitignore
├── .pylintrc
├── .travis.yml
├── AUTHORS
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── agents
│   ├── __init__.py
│   ├── algorithms
│   │   ├── __init__.py
│   │   └── ppo
│   │       ├── __init__.py
│   │       ├── ppo.py
│   │       └── utility.py
│   ├── parts
│   │   ├── __init__.py
│   │   ├── iterate_sequences.py
│   │   ├── memory.py
│   │   └── normalize.py
│   ├── scripts
│   │   ├── __init__.py
│   │   ├── configs.py
│   │   ├── networks.py
│   │   ├── train.py
│   │   ├── train_ppo_test.py
│   │   ├── utility.py
│   │   └── visualize.py
│   └── tools
│       ├── __init__.py
│       ├── attr_dict.py
│       ├── attr_dict_test.py
│       ├── batch_env.py
│       ├── count_weights.py
│       ├── count_weights_test.py
│       ├── in_graph_batch_env.py
│       ├── in_graph_env.py
│       ├── loop.py
│       ├── loop_test.py
│       ├── mock_algorithm.py
│       ├── mock_environment.py
│       ├── nested.py
│       ├── nested_test.py
│       ├── simulate.py
│       ├── simulate_test.py
│       ├── streaming_mean.py
│       ├── wrappers.py
│       └── wrappers_test.py
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python
2 | __pycache__/
3 | *.py[cod]
4 |
5 | # Pip
6 | pip-selfcheck.json
7 | *.whl
8 | *.egg
9 | *.egg-info
10 |
11 | # Setuptools
12 | /build
13 | /dist
14 | *.eggs
15 |
--------------------------------------------------------------------------------
/.pylintrc:
--------------------------------------------------------------------------------
1 | [MASTER]
2 |
3 | load-plugins=pylint.extensions.docparams
4 |
5 | jobs=4
6 |
7 | [MESSAGES CONTROL]
8 |
9 | enable=pylint.extensions.docparams
10 |
11 | disable=g-bad-import-order,import-self,duplicate-code,too-many-arguments,
12 | import-error,bad-option-value,missing-docstring,no-self-use,
13 | too-few-public-methods,invalid-name,too-many-locals,no-else-return,
14 | blacklisted-name,too-many-instance-attributes,no-member,
15 | c-extension-no-member,unsubscriptable-object,no-name-in-module,
16 | inconsistent-return-statements,not-context-manager,
17 | no-value-for-parameter,undefined-variable,missing-type-doc,
18 | missing-return-type-doc,missing-yield-type-doc,too-many-function-args
19 |
20 | [FORMAT]
21 |
22 | max-line-length=79
23 |
24 | indent-string=' '
25 |
26 | max-module-lines=2000
27 |
28 | expected-line-ending-format=LF
29 |
30 | [REPORTS]
31 |
32 | reports=no
33 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | python:
3 | - "2.7"
4 | - "3.6"
5 | env:
6 | matrix:
7 | - TF_VERSION="1.12.*"
8 | install:
9 | - pip install -q -U "tensorflow==$TF_VERSION"
10 | - pip install -q -U gym
11 | - pip install -q -U ruamel.yaml
12 | - pip install -q -U numpy
13 | - pip install -q -U pylint
14 | - pip install -q -e .
15 | script:
16 | - python -c "import agents"
17 | - python -m pylint agents || true
18 | - python -m unittest discover -p "*_test.py"
19 |
--------------------------------------------------------------------------------
/AUTHORS:
--------------------------------------------------------------------------------
1 | # This is the list of Agents authors for copyright purposes.
2 | #
3 | # This does not necessarily list everyone who has contributed code, since in
4 | # some cases, their employer may be the copyright holder. To see the full list
5 | # of contributors, see the revision history in source control.
6 | Google Inc.
7 | Danijar Hafner
8 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution,
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2017 The Agents Authors. All rights reserved.
2 |
3 | Apache License
4 | Version 2.0, January 2004
5 | http://www.apache.org/licenses/
6 |
7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
8 |
9 | 1. Definitions.
10 |
11 | "License" shall mean the terms and conditions for use, reproduction,
12 | and distribution as defined by Sections 1 through 9 of this document.
13 |
14 | "Licensor" shall mean the copyright owner or entity authorized by
15 | the copyright owner that is granting the License.
16 |
17 | "Legal Entity" shall mean the union of the acting entity and all
18 | other entities that control, are controlled by, or are under common
19 | control with that entity. For the purposes of this definition,
20 | "control" means (i) the power, direct or indirect, to cause the
21 | direction or management of such entity, whether by contract or
22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
23 | outstanding shares, or (iii) beneficial ownership of such entity.
24 |
25 | "You" (or "Your") shall mean an individual or Legal Entity
26 | exercising permissions granted by this License.
27 |
28 | "Source" form shall mean the preferred form for making modifications,
29 | including but not limited to software source code, documentation
30 | source, and configuration files.
31 |
32 | "Object" form shall mean any form resulting from mechanical
33 | transformation or translation of a Source form, including but
34 | not limited to compiled object code, generated documentation,
35 | and conversions to other media types.
36 |
37 | "Work" shall mean the work of authorship, whether in Source or
38 | Object form, made available under the License, as indicated by a
39 | copyright notice that is included in or attached to the work
40 | (an example is provided in the Appendix below).
41 |
42 | "Derivative Works" shall mean any work, whether in Source or Object
43 | form, that is based on (or derived from) the Work and for which the
44 | editorial revisions, annotations, elaborations, or other modifications
45 | represent, as a whole, an original work of authorship. For the purposes
46 | of this License, Derivative Works shall not include works that remain
47 | separable from, or merely link (or bind by name) to the interfaces of,
48 | the Work and Derivative Works thereof.
49 |
50 | "Contribution" shall mean any work of authorship, including
51 | the original version of the Work and any modifications or additions
52 | to that Work or Derivative Works thereof, that is intentionally
53 | submitted to Licensor for inclusion in the Work by the copyright owner
54 | or by an individual or Legal Entity authorized to submit on behalf of
55 | the copyright owner. For the purposes of this definition, "submitted"
56 | means any form of electronic, verbal, or written communication sent
57 | to the Licensor or its representatives, including but not limited to
58 | communication on electronic mailing lists, source code control systems,
59 | and issue tracking systems that are managed by, or on behalf of, the
60 | Licensor for the purpose of discussing and improving the Work, but
61 | excluding communication that is conspicuously marked or otherwise
62 | designated in writing by the copyright owner as "Not a Contribution."
63 |
64 | "Contributor" shall mean Licensor and any individual or Legal Entity
65 | on behalf of whom a Contribution has been received by Licensor and
66 | subsequently incorporated within the Work.
67 |
68 | 2. Grant of Copyright License. Subject to the terms and conditions of
69 | this License, each Contributor hereby grants to You a perpetual,
70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
71 | copyright license to reproduce, prepare Derivative Works of,
72 | publicly display, publicly perform, sublicense, and distribute the
73 | Work and such Derivative Works in Source or Object form.
74 |
75 | 3. Grant of Patent License. Subject to the terms and conditions of
76 | this License, each Contributor hereby grants to You a perpetual,
77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
78 | (except as stated in this section) patent license to make, have made,
79 | use, offer to sell, sell, import, and otherwise transfer the Work,
80 | where such license applies only to those patent claims licensable
81 | by such Contributor that are necessarily infringed by their
82 | Contribution(s) alone or by combination of their Contribution(s)
83 | with the Work to which such Contribution(s) was submitted. If You
84 | institute patent litigation against any entity (including a
85 | cross-claim or counterclaim in a lawsuit) alleging that the Work
86 | or a Contribution incorporated within the Work constitutes direct
87 | or contributory patent infringement, then any patent licenses
88 | granted to You under this License for that Work shall terminate
89 | as of the date such litigation is filed.
90 |
91 | 4. Redistribution. You may reproduce and distribute copies of the
92 | Work or Derivative Works thereof in any medium, with or without
93 | modifications, and in Source or Object form, provided that You
94 | meet the following conditions:
95 |
96 | (a) You must give any other recipients of the Work or
97 | Derivative Works a copy of this License; and
98 |
99 | (b) You must cause any modified files to carry prominent notices
100 | stating that You changed the files; and
101 |
102 | (c) You must retain, in the Source form of any Derivative Works
103 | that You distribute, all copyright, patent, trademark, and
104 | attribution notices from the Source form of the Work,
105 | excluding those notices that do not pertain to any part of
106 | the Derivative Works; and
107 |
108 | (d) If the Work includes a "NOTICE" text file as part of its
109 | distribution, then any Derivative Works that You distribute must
110 | include a readable copy of the attribution notices contained
111 | within such NOTICE file, excluding those notices that do not
112 | pertain to any part of the Derivative Works, in at least one
113 | of the following places: within a NOTICE text file distributed
114 | as part of the Derivative Works; within the Source form or
115 | documentation, if provided along with the Derivative Works; or,
116 | within a display generated by the Derivative Works, if and
117 | wherever such third-party notices normally appear. The contents
118 | of the NOTICE file are for informational purposes only and
119 | do not modify the License. You may add Your own attribution
120 | notices within Derivative Works that You distribute, alongside
121 | or as an addendum to the NOTICE text from the Work, provided
122 | that such additional attribution notices cannot be construed
123 | as modifying the License.
124 |
125 | You may add Your own copyright statement to Your modifications and
126 | may provide additional or different license terms and conditions
127 | for use, reproduction, or distribution of Your modifications, or
128 | for any such Derivative Works as a whole, provided Your use,
129 | reproduction, and distribution of the Work otherwise complies with
130 | the conditions stated in this License.
131 |
132 | 5. Submission of Contributions. Unless You explicitly state otherwise,
133 | any Contribution intentionally submitted for inclusion in the Work
134 | by You to the Licensor shall be under the terms and conditions of
135 | this License, without any additional terms or conditions.
136 | Notwithstanding the above, nothing herein shall supersede or modify
137 | the terms of any separate license agreement you may have executed
138 | with Licensor regarding such Contributions.
139 |
140 | 6. Trademarks. This License does not grant permission to use the trade
141 | names, trademarks, service marks, or product names of the Licensor,
142 | except as required for reasonable and customary use in describing the
143 | origin of the Work and reproducing the content of the NOTICE file.
144 |
145 | 7. Disclaimer of Warranty. Unless required by applicable law or
146 | agreed to in writing, Licensor provides the Work (and each
147 | Contributor provides its Contributions) on an "AS IS" BASIS,
148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
149 | implied, including, without limitation, any warranties or conditions
150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
151 | PARTICULAR PURPOSE. You are solely responsible for determining the
152 | appropriateness of using or redistributing the Work and assume any
153 | risks associated with Your exercise of permissions under this License.
154 |
155 | 8. Limitation of Liability. In no event and under no legal theory,
156 | whether in tort (including negligence), contract, or otherwise,
157 | unless required by applicable law (such as deliberate and grossly
158 | negligent acts) or agreed to in writing, shall any Contributor be
159 | liable to You for damages, including any direct, indirect, special,
160 | incidental, or consequential damages of any character arising as a
161 | result of this License or out of the use or inability to use the
162 | Work (including but not limited to damages for loss of goodwill,
163 | work stoppage, computer failure or malfunction, or any and all
164 | other commercial damages or losses), even if such Contributor
165 | has been advised of the possibility of such damages.
166 |
167 | 9. Accepting Warranty or Additional Liability. While redistributing
168 | the Work or Derivative Works thereof, You may choose to offer,
169 | and charge a fee for, acceptance of support, warranty, indemnity,
170 | or other liability obligations and/or rights consistent with this
171 | License. However, in accepting such obligations, You may act only
172 | on Your own behalf and on Your sole responsibility, not on behalf
173 | of any other Contributor, and only if You agree to indemnify,
174 | defend, and hold each Contributor harmless for any liability
175 | incurred by, or claims asserted against, such Contributor by reason
176 | of your accepting any such warranty or additional liability.
177 |
178 | END OF TERMS AND CONDITIONS
179 |
180 | APPENDIX: How to apply the Apache License to your work.
181 |
182 | To apply the Apache License to your work, attach the following
183 | boilerplate notice, with the fields enclosed by brackets "[]"
184 | replaced with your own identifying information. (Don't include
185 | the brackets!) The text should be enclosed in the appropriate
186 | comment syntax for the file format. We also recommend that a
187 | file or class name and description of purpose be included on the
188 | same "printed page" as the copyright notice for easier
189 | identification within third-party archives.
190 |
191 | Copyright [yyyy] [name of copyright owner]
192 |
193 | Licensed under the Apache License, Version 2.0 (the "License");
194 | you may not use this file except in compliance with the License.
195 | You may obtain a copy of the License at
196 |
197 | http://www.apache.org/licenses/LICENSE-2.0
198 |
199 | Unless required by applicable law or agreed to in writing, software
200 | distributed under the License is distributed on an "AS IS" BASIS,
201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
202 | See the License for the specific language governing permissions and
203 | limitations under the License.
204 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | Batch PPO
4 | =========
5 |
6 | This project provides optimized infrastructure for reinforcement learning. It
7 | extends the [OpenAI gym interface][post-gym] to multiple parallel environments
8 | and allows agents to be implemented in TensorFlow and perform batched
9 | computation. As a starting point, we provide BatchPPO, an optimized
10 | implementation of [Proximal Policy Optimization][post-ppo].
11 |
12 | Please cite the [TensorFlow Agents paper][paper-agents] if you use code from
13 | this project in your research:
14 |
15 | ```bibtex
16 | @article{hafner2017agents,
17 | title={TensorFlow Agents: Efficient Batched Reinforcement Learning in TensorFlow},
18 | author={Hafner, Danijar and Davidson, James and Vanhoucke, Vincent},
19 | journal={arXiv preprint arXiv:1709.02878},
20 | year={2017}
21 | }
22 | ```
23 |
24 | Dependencies: Python 2/3, TensorFlow 1.3+, Gym, ruamel.yaml
25 |
26 | [paper-agents]: https://arxiv.org/pdf/1709.02878.pdf
27 | [post-gym]: https://blog.openai.com/openai-gym-beta/
28 | [post-ppo]: https://blog.openai.com/openai-baselines-ppo/
29 |
30 | Instructions
31 | ------------
32 |
33 | Clone the repository and run the PPO algorithm by typing:
34 |
35 | ```shell
36 | python3 -m agents.scripts.train --logdir=/path/to/logdir --config=pendulum
37 | ```
38 |
39 | The algorithm to use is defined in the configuration; the `pendulum` configuration
40 | started here uses the included PPO implementation. More pre-defined
41 | configurations can be found in `agents/scripts/configs.py`.
42 |
43 | If you want to resume a previously started run, add the `--timestamp=<timestamp>`
44 | flag to the last command, using the timestamp from the directory name of
45 | your run.
46 |
47 | To visualize metrics, start TensorBoard from another terminal, then point your
48 | browser to `http://localhost:2222`:
49 |
50 | ```shell
51 | tensorboard --logdir=/path/to/logdir --port=2222
52 | ```
53 |
54 | To render videos and gather OpenAI Gym statistics to upload to the scoreboard,
55 | type:
56 |
57 | ```shell
58 | python3 -m agents.scripts.visualize --logdir=/path/to/logdir/<timestamp>-<config> --outdir=/path/to/outdir/
59 | ```
60 |
61 | Modifications
62 | -------------
63 |
64 | We release this project as a starting point that makes it easy to implement new
65 | reinforcement learning ideas. These files are good places to start when
66 | modifying the code:
67 |
68 | | File | Content |
69 | | ---- | ------- |
70 | | `scripts/configs.py` | Experiment configurations specifying the tasks and algorithms. |
71 | | `scripts/networks.py` | Neural network models. |
72 | | `scripts/train.py` | The executable file containing the training setup. |
73 | | `algorithms/ppo/ppo.py` | The TensorFlow graph for the PPO algorithm. |
74 |
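As an illustration of the first entry, a new experiment can be added to `scripts/configs.py` as a function that returns its hyperparameters. The sketch below is a hypothetical example; the function name, the parameter names, and their values are assumptions rather than the actual contents of the file:

```python
# Hypothetical configuration sketch; the real defaults and hyperparameter
# names are defined in agents/scripts/configs.py.
def my_pendulum():
  """Example configuration for a Pendulum experiment."""
  env = 'Pendulum-v0'    # Gym environment to train on.
  max_length = 200       # Maximum episode length.
  steps = 1e6            # Total number of training steps.
  discount = 0.985       # Reward discount factor.
  update_every = 30      # Episodes to collect between training updates.
  return locals()        # Configurations are plain name/value mappings.
```

Assuming configurations are looked up by function name, training would then be started with `--config=my_pendulum`, following the same pattern as the `pendulum` configuration above.
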
75 | To run unit tests and linting, type:
76 |
77 | ```shell
78 | python2 -m unittest discover -p "*_test.py"
79 | python3 -m unittest discover -p "*_test.py"
80 | python3 -m pylint agents
81 | ```
82 |
83 | For further questions, please open an issue on GitHub.
84 |
85 | Implementation
86 | --------------
87 |
88 | We include a batched interface for OpenAI Gym environments that fully integrates
89 | with TensorFlow for efficient algorithm implementations. This is achieved
90 | through these core components:
91 |
92 | - **`agents.tools.wrappers.ExternalProcess`** is an environment wrapper that
93 | constructs an OpenAI Gym environment inside of an external process. Calls to
94 | `step()` and `reset()`, as well as attribute access, are forwarded to the
95 | process and wait for the result. This makes it possible to run multiple
96 | environments in parallel without being restricted by Python's global interpreter lock.
97 | - **`agents.tools.BatchEnv`** extends the OpenAI Gym interface to batches of
98 | environments. It combines multiple OpenAI Gym environments, with `step()`
99 | accepting a batch of actions and returning a batch of observations, rewards,
100 | done flags, and info objects. If the individual environments live in external
101 | processes, they will be stepped in parallel.
102 | - **`agents.tools.InGraphBatchEnv`** integrates a batch environment into the
103 | TensorFlow graph and makes its `step()` and `reset()` functions accessible as
104 | operations. The current batch of observations, last actions, rewards, and done
105 | flags is stored in variables and made available as tensors.
106 | - **`agents.tools.simulate()`** fuses the step of an in-graph batch environment
107 | and a reinforcement learning algorithm together into a single operation to be
108 | called inside the training loop. This reduces the number of session calls and
109 | provides a simple way to train future algorithms.
110 |
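The components above can be composed roughly as follows. This is only an illustrative sketch: the constructor arguments and the commented-out `simulate()` call are assumptions, and the authoritative definitions live in `agents/tools`.

```python
import functools

import gym

from agents import tools

# Build several Gym environments, each living in its own external process,
# so stepping them is not serialized by Python's global interpreter lock.
constructor = functools.partial(gym.make, 'Pendulum-v0')
envs = [tools.wrappers.ExternalProcess(constructor) for _ in range(8)]

# Combine them into one batched environment and expose it inside the
# TensorFlow graph, so that step() and reset() become graph operations.
batch_env = tools.InGraphBatchEnv(tools.BatchEnv(envs, blocking=False))

# simulate() fuses one environment step and one algorithm step into a single
# op; `algorithm` stands for an agent instance such as agents.algorithms.PPO.
# done, score, summary = tools.simulate(batch_env, algorithm)
```
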
111 | To understand all the code, please make yourself familiar with TensorFlow's
112 | control flow operations, especially [`tf.cond()`][tf-cond],
113 | [`tf.scan()`][tf-scan], and
114 | [`tf.control_dependencies()`][tf-control-dependencies].
115 |
116 | [tf-cond]: https://www.tensorflow.org/api_docs/python/tf/cond
117 | [tf-scan]: https://www.tensorflow.org/api_docs/python/tf/scan
118 | [tf-control-dependencies]: https://www.tensorflow.org/api_docs/python/tf/control_dependencies
119 |
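For reference, the following self-contained sketch (not taken from this repository) demonstrates these three operations under a TensorFlow 1.x runtime:

```python
import tensorflow as tf

step = tf.Variable(0, trainable=False)
increment = tf.assign_add(step, 1)

# tf.control_dependencies() forces `increment` to run before `value` is read.
with tf.control_dependencies([increment]):
  value = tf.identity(step)

# tf.cond() selects one of two branches inside the graph.
parity = tf.cond(
    tf.equal(value % 2, 0),
    lambda: tf.constant('even'),
    lambda: tf.constant('odd'))

# tf.scan() accumulates over the leading dimension, here a cumulative sum.
cumulative = tf.scan(lambda agg, cur: agg + cur, tf.range(1, 5))

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run([value, parity, cumulative]))
```
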
120 | Disclaimer
121 | ----------
122 |
123 | This is not an official Google product.
124 |
--------------------------------------------------------------------------------
/agents/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Main package of TensorFlow agents."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from . import algorithms
22 | from . import scripts
23 | from . import tools
24 |
--------------------------------------------------------------------------------
/agents/algorithms/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Agent implementations."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from .ppo import PPO
22 |
--------------------------------------------------------------------------------
/agents/algorithms/ppo/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Proximal Policy Optimization algorithm."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from .ppo import PPO
22 |
--------------------------------------------------------------------------------
/agents/algorithms/ppo/ppo.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Proximal Policy Optimization agent.
16 |
17 | Based on John Schulman's implementation in Python and Theano:
18 | https://github.com/joschu/modular_rl/blob/master/modular_rl/ppo.py
19 | """
20 |
21 | from __future__ import absolute_import
22 | from __future__ import division
23 | from __future__ import print_function
24 |
25 | import functools
26 |
27 | import tensorflow as tf
28 |
29 | from agents import parts
30 | from agents import tools
31 | from agents.algorithms.ppo import utility
32 |
33 |
34 | class PPO(object):
35 | """A vectorized implementation of the PPO algorithm by John Schulman."""
36 |
37 | def __init__(self, batch_env, step, is_training, should_log, config):
38 | """Create an instance of the PPO algorithm.
39 |
40 | Args:
41 | batch_env: In-graph batch environment.
42 | step: Integer tensor holding the current training step.
43 | is_training: Boolean tensor for whether the algorithm should train.
44 | should_log: Boolean tensor for whether summaries should be returned.
45 | config: Object containing the agent configuration as attributes.
46 | """
47 | self._batch_env = batch_env
48 | self._step = step
49 | self._is_training = is_training
50 | self._should_log = should_log
51 | self._config = config
52 | self._observ_filter = parts.StreamingNormalize(
53 | self._batch_env.observ[0], center=True, scale=True, clip=5,
54 | name='normalize_observ')
55 | self._reward_filter = parts.StreamingNormalize(
56 | self._batch_env.reward[0], center=False, scale=True, clip=10,
57 | name='normalize_reward')
58 | self._use_gpu = self._config.use_gpu and utility.available_gpus()
59 | policy_params, state = self._initialize_policy()
60 | self._initialize_memory(policy_params)
61 | # Initialize the optimizer and penalty.
62 | with tf.device('/gpu:0' if self._use_gpu else '/cpu:0'):
63 | self._optimizer = self._config.optimizer(self._config.learning_rate)
64 | self._penalty = tf.Variable(
65 | self._config.kl_init_penalty, False, dtype=tf.float32)
66 | # If the policy is stateful, allocate space to store its state.
67 | with tf.variable_scope('ppo_temporary'):
68 | with tf.device('/gpu:0'):
69 | if state is None:
70 | self._last_state = None
71 | else:
72 | var_like = lambda x: tf.Variable(lambda: tf.zeros_like(x), False)
73 | self._last_state = tools.nested.map(var_like, state)
74 | # Remember the action and policy parameters to write into the memory.
75 | with tf.variable_scope('ppo_temporary'):
76 | self._last_action = tf.Variable(
77 | tf.zeros_like(self._batch_env.action), False, name='last_action')
78 | self._last_policy = tools.nested.map(
79 | lambda x: tf.Variable(tf.zeros_like(x[:, 0]), False), policy_params)
80 |
81 | def begin_episode(self, agent_indices):
82 | """Reset the recurrent states and stored episode.
83 |
84 | Args:
85 | agent_indices: Tensor containing current batch indices.
86 |
87 | Returns:
88 | Summary tensor.
89 | """
90 | with tf.name_scope('begin_episode/'):
91 | if self._last_state is None:
92 | reset_state = tf.no_op()
93 | else:
94 | reset_state = utility.reinit_nested_vars(
95 | self._last_state, agent_indices)
96 | reset_buffer = self._current_episodes.clear(agent_indices)
97 | with tf.control_dependencies([reset_state, reset_buffer]):
98 | return tf.constant('')
99 |
100 | def perform(self, agent_indices, observ):
101 | """Compute batch of actions and a summary for a batch of observation.
102 |
103 | Args:
104 | agent_indices: Tensor containing current batch indices.
105 | observ: Tensor of a batch of observations for all agents.
106 |
107 | Returns:
108 | Tuple of action batch tensor and summary tensor.
109 | """
110 | with tf.name_scope('perform/'):
111 | observ = self._observ_filter.transform(observ)
112 | if self._last_state is None:
113 | state = None
114 | else:
115 | state = tools.nested.map(
116 | lambda x: tf.gather(x, agent_indices), self._last_state)
117 | with tf.device('/gpu:0' if self._use_gpu else '/cpu:0'):
118 | output = self._network(
119 | observ[:, None], tf.ones(observ.shape[0]), state)
120 | action = tf.cond(
121 | self._is_training, output.policy.sample, output.policy.mode)
122 | logprob = output.policy.log_prob(action)[:, 0]
123 | # pylint: disable=g-long-lambda
124 | summary = tf.cond(self._should_log, lambda: tf.summary.merge([
125 | tf.summary.histogram('mode', output.policy.mode()[:, 0]),
126 | tf.summary.histogram('action', action[:, 0]),
127 | tf.summary.histogram('logprob', logprob)]), str)
128 | # Remember current policy to append to memory in the experience callback.
129 | if self._last_state is None:
130 | assign_state = tf.no_op()
131 | else:
132 | assign_state = utility.assign_nested_vars(
133 | self._last_state, output.state, agent_indices)
134 | remember_last_action = tf.scatter_update(
135 | self._last_action, agent_indices, action[:, 0])
136 | policy_params = tools.nested.filter(
137 | lambda x: isinstance(x, tf.Tensor), output.policy.parameters)
138 | assert policy_params, 'Policy has no parameters to store.'
139 | remember_last_policy = tools.nested.map(
140 | lambda var, val: tf.scatter_update(var, agent_indices, val[:, 0]),
141 | self._last_policy, policy_params, flatten=True)
142 | with tf.control_dependencies((
143 | assign_state, remember_last_action) + remember_last_policy):
144 | return action[:, 0], tf.identity(summary)
145 |
146 | def experience(
147 | self, agent_indices, observ, action, reward, unused_done, unused_nextob):
148 | """Process the transition tuple of the current step.
149 |
150 | When training, add the current transition tuple to the memory and update
151 | the streaming statistics for observations and rewards. A summary string is
152 | returned if requested at this step.
153 |
154 | Args:
155 | agent_indices: Tensor containing current batch indices.
156 | observ: Batch tensor of observations.
157 | action: Batch tensor of actions.
158 | reward: Batch tensor of rewards.
159 | unused_done: Batch tensor of done flags.
160 | unused_nextob: Batch tensor of successor observations.
161 |
162 | Returns:
163 | Summary tensor.
164 | """
165 | with tf.name_scope('experience/'):
166 | return tf.cond(
167 | self._is_training,
168 | # pylint: disable=g-long-lambda
169 | lambda: self._define_experience(
170 | agent_indices, observ, action, reward), str)
171 |
172 | def _define_experience(self, agent_indices, observ, action, reward):
173 | """Implement the branch of experience() entered during training."""
174 | update_filters = tf.summary.merge([
175 | self._observ_filter.update(observ),
176 | self._reward_filter.update(reward)])
177 | with tf.control_dependencies([update_filters]):
178 | if self._config.train_on_agent_action:
179 | # NOTE: Doesn't seem to change much.
180 | action = self._last_action
181 | policy = tools.nested.map(
182 | lambda x: tf.gather(x, agent_indices), self._last_policy)
183 | batch = (observ, action, policy, reward)
184 | append = self._current_episodes.append(batch, agent_indices)
185 | with tf.control_dependencies([append]):
186 | norm_observ = self._observ_filter.transform(observ)
187 | norm_reward = tf.reduce_mean(self._reward_filter.transform(reward))
188 | # pylint: disable=g-long-lambda
189 | summary = tf.cond(self._should_log, lambda: tf.summary.merge([
190 | update_filters,
191 | self._observ_filter.summary(),
192 | self._reward_filter.summary(),
193 | tf.summary.scalar('memory_size', self._num_finished_episodes),
194 | tf.summary.histogram('normalized_observ', norm_observ),
195 | tf.summary.histogram('action', self._last_action),
196 | tf.summary.scalar('normalized_reward', norm_reward)]), str)
197 | return summary
198 |
199 | def end_episode(self, agent_indices):
200 | """Add episodes to the memory and perform update steps if memory is full.
201 |
202 | During training, add the collected episodes of the batch indices that
203 | finished their episode to the memory. If the memory is full, train on it,
204 | and then clear the memory. A summary string is returned if requested at
205 | this step.
206 |
207 | Args:
208 | agent_indices: Tensor containing current batch indices.
209 |
210 | Returns:
211 | Summary tensor.
212 | """
213 | with tf.name_scope('end_episode/'):
214 | return tf.cond(
215 | self._is_training,
216 | lambda: self._define_end_episode(agent_indices), str)
217 |
218 | def _initialize_policy(self):
219 | """Initialize the policy.
220 |
221 | Run the policy network on dummy data to initialize its parameters for later
222 | reuse and to analyze the policy distribution. Initializes the attributes
223 | `self._network` and `self._policy_type`.
224 |
225 | Raises:
226 | ValueError: Invalid policy distribution.
227 |
228 | Returns:
229 | Parameters of the policy distribution and policy state.
230 | """
231 | with tf.device('/gpu:0' if self._use_gpu else '/cpu:0'):
232 | network = functools.partial(
233 | self._config.network, self._config, self._batch_env.action_space)
234 | self._network = tf.make_template('network', network)
235 | output = self._network(
236 | tf.zeros_like(self._batch_env.observ)[:, None],
237 | tf.ones(len(self._batch_env)))
238 | if output.policy.event_shape != self._batch_env.action.shape[1:]:
239 | message = 'Policy event shape {} does not match action shape {}.'
240 | message = message.format(
241 | output.policy.event_shape, self._batch_env.action.shape[1:])
242 | raise ValueError(message)
243 | self._policy_type = type(output.policy)
244 | is_tensor = lambda x: isinstance(x, tf.Tensor)
245 | policy_params = tools.nested.filter(is_tensor, output.policy.parameters)
246 | set_batch_dim = lambda x: utility.set_dimension(x, 0, len(self._batch_env))
247 | tools.nested.map(set_batch_dim, policy_params)
248 | if output.state is not None:
249 | tools.nested.map(set_batch_dim, output.state)
250 | return policy_params, output.state
251 |
252 | def _initialize_memory(self, policy_params):
253 | """Initialize temporary and permanent memory.
254 |
255 | Args:
256 | policy_params: Nested tuple of policy parameters with all dimensions set.
257 |
258 | Initializes the attributes `self._current_episodes`,
259 | `self._finished_episodes`, and `self._num_finished_episodes`. The episodes
260 | memory serves to collect multiple episodes in parallel. Finished episodes
261 | are copied into the next free slot of the second memory. The memory index
262 | points to the next free slot.
263 | """
264 | # We store observation, action, policy parameters, and reward.
265 | template = (
266 | self._batch_env.observ[0],
267 | self._batch_env.action[0],
268 | tools.nested.map(lambda x: x[0, 0], policy_params),
269 | self._batch_env.reward[0])
270 | with tf.variable_scope('ppo_temporary'):
271 | self._current_episodes = parts.EpisodeMemory(
272 | template, len(self._batch_env), self._config.max_length, 'episodes')
273 | self._finished_episodes = parts.EpisodeMemory(
274 | template, self._config.update_every, self._config.max_length, 'memory')
275 | self._num_finished_episodes = tf.Variable(0, False)
276 |
277 | def _define_end_episode(self, agent_indices):
278 | """Implement the branch of end_episode() entered during training."""
279 | episodes, length = self._current_episodes.data(agent_indices)
280 | space_left = self._config.update_every - self._num_finished_episodes
281 | use_episodes = tf.range(tf.minimum(
282 | tf.shape(agent_indices)[0], space_left))
283 | episodes = tools.nested.map(lambda x: tf.gather(x, use_episodes), episodes)
284 | append = self._finished_episodes.replace(
285 | episodes, tf.gather(length, use_episodes),
286 | use_episodes + self._num_finished_episodes)
287 | with tf.control_dependencies([append]):
288 | increment_index = self._num_finished_episodes.assign_add(
289 | tf.shape(use_episodes)[0])
290 | with tf.control_dependencies([increment_index]):
291 | memory_full = self._num_finished_episodes >= self._config.update_every
292 | return tf.cond(memory_full, self._training, str)
293 |
294 | def _training(self):
295 | """Perform multiple training iterations of both policy and value baseline.
296 |
297 | Training on the episodes collected in the memory. Reset the memory
298 | afterwards. Always returns a summary string.
299 |
300 | Returns:
301 | Summary tensor.
302 | """
303 | with tf.device('/gpu:0' if self._use_gpu else '/cpu:0'):
304 | with tf.name_scope('training'):
305 | assert_full = tf.assert_equal(
306 | self._num_finished_episodes, self._config.update_every)
307 | with tf.control_dependencies([assert_full]):
308 | data = self._finished_episodes.data()
309 | (observ, action, old_policy_params, reward), length = data
310 | # We set padding frames of the parameters to ones to prevent Gaussians
311 | # with zero variance. This would result in an infinite KL divergence,
312 | # which, even if masked out, would result in NaN gradients.
313 | old_policy_params = tools.nested.map(
314 | lambda param: self._mask(param, length, 1), old_policy_params)
315 | with tf.control_dependencies([tf.assert_greater(length, 0)]):
316 | length = tf.identity(length)
317 | observ = self._observ_filter.transform(observ)
318 | reward = self._reward_filter.transform(reward)
319 | update_summary = self._perform_update_steps(
320 | observ, action, old_policy_params, reward, length)
321 | with tf.control_dependencies([update_summary]):
322 | penalty_summary = self._adjust_penalty(
323 | observ, old_policy_params, length)
324 | with tf.control_dependencies([penalty_summary]):
325 | clear_memory = tf.group(
326 | self._finished_episodes.clear(),
327 | self._num_finished_episodes.assign(0))
328 | with tf.control_dependencies([clear_memory]):
329 | weight_summary = utility.variable_summaries(
330 | tf.trainable_variables(), self._config.weight_summaries)
331 | return tf.summary.merge([
332 | update_summary, penalty_summary, weight_summary])
333 |
334 | def _perform_update_steps(
335 | self, observ, action, old_policy_params, reward, length):
336 | """Perform multiple update steps of value function and policy.
337 |
338 | The advantage is computed once at the beginning and shared across
339 | iterations. We need to decide for the summary of one iteration, and thus
340 | choose the one after half of the iterations.
341 |
342 | Args:
343 | observ: Sequences of observations.
344 | action: Sequences of actions.
345 | old_policy_params: Parameters of the behavioral policy.
346 | reward: Sequences of rewards.
347 | length: Batch of sequence lengths.
348 |
349 | Returns:
350 | Summary tensor.
351 | """
352 | return_ = utility.discounted_return(
353 | reward, length, self._config.discount)
354 | value = self._network(observ, length).value
355 | if self._config.gae_lambda:
356 | advantage = utility.lambda_advantage(
357 | reward, value, length, self._config.discount,
358 | self._config.gae_lambda)
359 | else:
360 | advantage = return_ - value
361 | mean, variance = tf.nn.moments(advantage, axes=[0, 1], keep_dims=True)
362 | advantage = (advantage - mean) / (tf.sqrt(variance) + 1e-8)
363 | advantage = tf.Print(
364 | advantage, [tf.reduce_mean(return_), tf.reduce_mean(value)],
365 | 'return and value: ')
366 | advantage = tf.Print(
367 | advantage, [tf.reduce_mean(advantage)],
368 | 'normalized advantage: ')
369 | episodes = (observ, action, old_policy_params, reward, advantage)
370 | value_loss, policy_loss, summary = parts.iterate_sequences(
371 | self._update_step, [0., 0., ''], episodes, length,
372 | self._config.chunk_length,
373 | self._config.batch_size,
374 | self._config.update_epochs,
375 | padding_value=1)
376 | print_losses = tf.group(
377 | tf.Print(0, [tf.reduce_mean(value_loss)], 'value loss: '),
378 | tf.Print(0, [tf.reduce_mean(policy_loss)], 'policy loss: '))
379 | with tf.control_dependencies([value_loss, policy_loss, print_losses]):
380 | return summary[self._config.update_epochs // 2]
381 |
382 | def _update_step(self, sequence):
383 | """Compute the current combined loss and perform a gradient update step.
384 |
385 | The sequences must be a dict containing the keys `length` and `sequence`,
386 | where the latter is a tuple containing observations, actions, parameters of
387 | the behavioral policy, rewards, and advantages.
388 |
389 | Args:
390 | sequence: Sequences of episodes or chunks of episodes.
391 |
392 | Returns:
393 | Tuple of value loss, policy loss, and summary tensor.
394 | """
395 | observ, action, old_policy_params, reward, advantage = sequence['sequence']
396 | length = sequence['length']
397 | old_policy = self._policy_type(**old_policy_params)
398 | value_loss, value_summary = self._value_loss(observ, reward, length)
399 | network = self._network(observ, length)
400 | policy_loss, policy_summary = self._policy_loss(
401 | old_policy, network.policy, action, advantage, length)
402 | network_loss = network.get('loss', 0.0)
403 | loss = policy_loss + value_loss + tf.reduce_mean(network_loss)
404 | gradients, variables = (
405 | zip(*self._optimizer.compute_gradients(loss)))
406 | optimize = self._optimizer.apply_gradients(
407 | zip(gradients, variables))
408 | summary = tf.summary.merge([
409 | value_summary, policy_summary,
410 | tf.summary.histogram('network_loss', network_loss),
411 | tf.summary.scalar('avg_network_loss', tf.reduce_mean(network_loss)),
412 | tf.summary.scalar('gradient_norm', tf.global_norm(gradients)),
413 | utility.gradient_summaries(zip(gradients, variables))])
414 | with tf.control_dependencies([optimize]):
415 | return [tf.identity(x) for x in (value_loss, policy_loss, summary)]
416 |
417 | def _value_loss(self, observ, reward, length):
418 | """Compute the loss function for the value baseline.
419 |
420 | The value loss is the difference between empirical and approximated returns
421 | over the collected episodes. Returns the loss tensor and a summary string.
422 |
423 | Args:
424 | observ: Sequences of observations.
425 | reward: Sequences of rewards.
426 | length: Batch of sequence lengths.
427 |
428 | Returns:
429 | Tuple of loss tensor and summary tensor.
430 | """
431 | with tf.name_scope('value_loss'):
432 | value = self._network(observ, length).value
433 | return_ = utility.discounted_return(
434 | reward, length, self._config.discount)
435 | advantage = return_ - value
436 | value_loss = 0.5 * self._mask(advantage ** 2, length)
437 | summary = tf.summary.merge([
438 | tf.summary.histogram('value_loss', value_loss),
439 | tf.summary.scalar('avg_value_loss', tf.reduce_mean(value_loss))])
440 | value_loss = tf.reduce_mean(value_loss)
441 | return tf.check_numerics(value_loss, 'value_loss'), summary
442 |
443 | def _policy_loss(
444 | self, old_policy, policy, action, advantage, length):
445 | """Compute the policy loss composed of multiple components.
446 |
447 | 1. The policy gradient loss is importance sampled from the data-collecting
448 | policy at the beginning of training.
449 | 2. The second term is a KL penalty between the policy at the beginning of
450 | training and the current policy.
451 | 3. Additionally, if this KL already changed more than twice the target
452 | amount, we activate a strong penalty discouraging further divergence.
453 |
454 | Args:
455 | old_policy: Action distribution of the behavioral policy.
456 | policy: Sequences of distribution params of the current policy.
457 | action: Sequences of actions.
458 | advantage: Sequences of advantages.
459 | length: Batch of sequence lengths.
460 |
461 | Returns:
462 | Tuple of loss tensor and summary tensor.
463 | """
464 | with tf.name_scope('policy_loss'):
465 | kl = tf.contrib.distributions.kl_divergence(old_policy, policy)
466 | # Infinite values in the KL, even for padding frames that we mask out,
467 | # cause NaN gradients since TensorFlow computes gradients with respect to
468 | # the whole input tensor.
469 | kl = tf.check_numerics(kl, 'kl')
470 | kl = tf.reduce_mean(self._mask(kl, length), 1)
471 | policy_gradient = tf.exp(
472 | policy.log_prob(action) - old_policy.log_prob(action))
473 | surrogate_loss = -tf.reduce_mean(self._mask(
474 | policy_gradient * tf.stop_gradient(advantage), length), 1)
475 | surrogate_loss = tf.check_numerics(surrogate_loss, 'surrogate_loss')
476 | kl_penalty = self._penalty * kl
477 | cutoff_threshold = self._config.kl_target * self._config.kl_cutoff_factor
478 | cutoff_count = tf.reduce_sum(
479 | tf.cast(kl > cutoff_threshold, tf.int32))
480 | with tf.control_dependencies([tf.cond(
481 | cutoff_count > 0,
482 | lambda: tf.Print(0, [cutoff_count], 'kl cutoff! '), int)]):
483 | kl_cutoff = (
484 | self._config.kl_cutoff_coef *
485 | tf.cast(kl > cutoff_threshold, tf.float32) *
486 | (kl - cutoff_threshold) ** 2)
487 | policy_loss = surrogate_loss + kl_penalty + kl_cutoff
488 | entropy = tf.reduce_mean(policy.entropy(), axis=1)
489 | if self._config.entropy_regularization:
490 | policy_loss -= self._config.entropy_regularization * entropy
491 | summary = tf.summary.merge([
492 | tf.summary.histogram('entropy', entropy),
493 | tf.summary.histogram('kl', kl),
494 | tf.summary.histogram('surrogate_loss', surrogate_loss),
495 | tf.summary.histogram('kl_penalty', kl_penalty),
496 | tf.summary.histogram('kl_cutoff', kl_cutoff),
497 | tf.summary.histogram('kl_penalty_combined', kl_penalty + kl_cutoff),
498 | tf.summary.histogram('policy_loss', policy_loss),
499 | tf.summary.scalar('avg_surr_loss', tf.reduce_mean(surrogate_loss)),
500 | tf.summary.scalar('avg_kl_penalty', tf.reduce_mean(kl_penalty)),
501 | tf.summary.scalar('avg_policy_loss', tf.reduce_mean(policy_loss))])
502 | policy_loss = tf.reduce_mean(policy_loss, 0)
503 | return tf.check_numerics(policy_loss, 'policy_loss'), summary
504 |
505 | def _adjust_penalty(self, observ, old_policy_params, length):
506 | """Adjust the KL policy between the behavioral and current policy.
507 |
508 | Compute how much the policy actually changed during the multiple
509 | update steps. Adjust the penalty strength for the next training phase if we
510 | overshot or undershot the target divergence too much.
511 |
512 | Args:
513 | observ: Sequences of observations.
514 | old_policy_params: Parameters of the behavioral policy.
515 | length: Batch of sequence lengths.
516 |
517 | Returns:
518 | Summary tensor.
519 | """
520 | old_policy = self._policy_type(**old_policy_params)
521 | with tf.name_scope('adjust_penalty'):
522 | network = self._network(observ, length)
523 | print_penalty = tf.Print(0, [self._penalty], 'current penalty: ')
524 | with tf.control_dependencies([print_penalty]):
525 | kl_change = tf.reduce_mean(self._mask(
526 | tf.contrib.distributions.kl_divergence(old_policy, network.policy),
527 | length))
528 | kl_change = tf.Print(kl_change, [kl_change], 'kl change: ')
529 | maybe_increase = tf.cond(
530 | kl_change > 1.3 * self._config.kl_target,
531 | # pylint: disable=g-long-lambda
532 | lambda: tf.Print(self._penalty.assign(
533 | self._penalty * 1.5), [0], 'increase penalty '),
534 | float)
535 | maybe_decrease = tf.cond(
536 | kl_change < 0.7 * self._config.kl_target,
537 | # pylint: disable=g-long-lambda
538 | lambda: tf.Print(self._penalty.assign(
539 | self._penalty / 1.5), [0], 'decrease penalty '),
540 | float)
541 | with tf.control_dependencies([maybe_increase, maybe_decrease]):
542 | return tf.summary.merge([
543 | tf.summary.scalar('kl_change', kl_change),
544 | tf.summary.scalar('penalty', self._penalty)])
545 |
546 | def _mask(self, tensor, length, padding_value=0):
547 | """Set padding elements of a batch of sequences to a constant.
548 |
549 | Useful for setting padding elements to zero before summing along the time
550 | dimension, or for preventing infinite results in padding elements.
551 |
552 | Args:
553 | tensor: Tensor of sequences.
554 | length: Batch of sequence lengths.
555 | padding_value: Value to write into padding elements.
556 |
557 | Returns:
558 | Masked sequences.
559 | """
560 | with tf.name_scope('mask'):
561 | range_ = tf.range(tensor.shape[1].value)
562 | mask = range_[None, :] < length[:, None]
563 | if tensor.shape.ndims > 2:
564 | for _ in range(tensor.shape.ndims - 2):
565 | mask = mask[..., None]
566 | mask = tf.tile(mask, [1, 1] + tensor.shape[2:].as_list())
567 | masked = tf.where(mask, tensor, padding_value * tf.ones_like(tensor))
568 | return tf.check_numerics(masked, 'masked')
569 |
--------------------------------------------------------------------------------
/agents/algorithms/ppo/utility.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utilities for the PPO algorithm."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import re
23 |
24 | import tensorflow as tf
25 | from tensorflow.python.client import device_lib
26 |
27 |
28 | def reinit_nested_vars(variables, indices=None):
29 | """Reset all variables in a nested tuple to zeros.
30 |
31 | Args:
32 | variables: Nested tuple or list of variables.
33 | indices: Batch indices to reset, defaults to all.
34 |
35 | Returns:
36 | Operation.
37 | """
38 | if isinstance(variables, (tuple, list)):
39 | return tf.group(*[
40 | reinit_nested_vars(variable, indices) for variable in variables])
41 | if indices is None:
42 | return variables.assign(tf.zeros_like(variables))
43 | else:
44 | zeros = tf.zeros([tf.shape(indices)[0]] + variables.shape[1:].as_list())
45 | return tf.scatter_update(variables, indices, zeros)
46 |
47 |
48 | def assign_nested_vars(variables, tensors, indices=None):
49 | """Assign tensors to matching nested tuple of variables.
50 |
51 | Args:
52 | variables: Nested tuple or list of variables to update.
53 | tensors: Nested tuple or list of tensors to assign.
54 | indices: Batch indices to assign to; defaults to all.
55 |
56 | Returns:
57 | Operation.
58 | """
59 | if isinstance(variables, (tuple, list)):
60 | return tf.group(*[
61 | assign_nested_vars(variable, tensor, indices)
62 | for variable, tensor in zip(variables, tensors)])
63 | if indices is None:
64 | return variables.assign(tensors)
65 | else:
66 | return tf.scatter_update(variables, indices, tensors)
67 |
68 |
69 | def discounted_return(reward, length, discount):
70 | """Discounted Monte-Carlo returns."""
71 | timestep = tf.range(reward.shape[1].value)
72 | mask = tf.cast(timestep[None, :] < length[:, None], tf.float32)
73 | return_ = tf.reverse(tf.transpose(tf.scan(
74 | lambda agg, cur: cur + discount * agg,
75 | tf.transpose(tf.reverse(mask * reward, [1]), [1, 0]),
76 | tf.zeros_like(reward[:, -1]), 1, False), [1, 0]), [1])
77 | return tf.check_numerics(tf.stop_gradient(return_), 'return')
78 |
79 |
80 | def fixed_step_return(reward, value, length, discount, window):
81 | """N-step discounted return."""
82 | timestep = tf.range(reward.shape[1].value)
83 | mask = tf.cast(timestep[None, :] < length[:, None], tf.float32)
84 | return_ = tf.zeros_like(reward)
85 | for _ in range(window):
86 | return_ += reward
87 | reward = discount * tf.concat(
88 | [reward[:, 1:], tf.zeros_like(reward[:, -1:])], 1)
89 | return_ += discount ** window * tf.concat(
90 | [value[:, window:], tf.zeros_like(value[:, -window:])], 1)
91 | return tf.check_numerics(tf.stop_gradient(mask * return_), 'return')
92 |
93 |
94 | def lambda_return(reward, value, length, discount, lambda_):
95 | """TD-lambda returns."""
96 | timestep = tf.range(reward.shape[1].value)
97 | mask = tf.cast(timestep[None, :] < length[:, None], tf.float32)
98 | sequence = mask * reward + discount * value * (1 - lambda_)
99 | discount = mask * discount * lambda_
100 | sequence = tf.stack([sequence, discount], 2)
101 | return_ = tf.reverse(tf.transpose(tf.scan(
102 | lambda agg, cur: cur[0] + cur[1] * agg,
103 | tf.transpose(tf.reverse(sequence, [1]), [1, 2, 0]),
104 | tf.zeros_like(value[:, -1]), 1, False), [1, 0]), [1])
105 | return tf.check_numerics(tf.stop_gradient(return_), 'return')
106 |
107 |
108 | def lambda_advantage(reward, value, length, discount, gae_lambda):
109 | """Generalized Advantage Estimation."""
110 | timestep = tf.range(reward.shape[1].value)
111 | mask = tf.cast(timestep[None, :] < length[:, None], tf.float32)
112 | next_value = tf.concat([value[:, 1:], tf.zeros_like(value[:, -1:])], 1)
113 | delta = reward + discount * next_value - value
114 | advantage = tf.reverse(tf.transpose(tf.scan(
115 | lambda agg, cur: cur + gae_lambda * discount * agg,
116 | tf.transpose(tf.reverse(mask * delta, [1]), [1, 0]),
117 | tf.zeros_like(delta[:, -1]), 1, False), [1, 0]), [1])
118 | return tf.check_numerics(tf.stop_gradient(advantage), 'advantage')
119 |
120 |
121 | def available_gpus():
122 | """List of GPU device names detected by TensorFlow."""
123 | local_device_protos = device_lib.list_local_devices()
124 | return [x.name for x in local_device_protos if x.device_type == 'GPU']
125 |
126 |
127 | def gradient_summaries(grad_vars, groups=None, scope='gradients'):
128 | """Create histogram summaries of the gradient.
129 |
130 | Summaries can be grouped via regexes matching variables names.
131 |
132 | Args:
133 | grad_vars: List of (gradient, variable) tuples as returned by optimizers.
134 | groups: Mapping of name to regex for grouping summaries.
135 | scope: Name scope for this operation.
136 |
137 | Returns:
138 | Summary tensor.
139 | """
140 | groups = groups or {r'all': r'.*'}
141 | grouped = collections.defaultdict(list)
142 | for grad, var in grad_vars:
143 | if grad is None:
144 | continue
145 | for name, pattern in groups.items():
146 | if re.match(pattern, var.name):
147 | name = re.sub(pattern, name, var.name)
148 | grouped[name].append(grad)
149 | for name in groups:
150 | if name not in grouped:
151 | tf.logging.warn("No variables matching '{}' group.".format(name))
152 | summaries = []
153 | for name, grads in grouped.items():
154 | grads = [tf.reshape(grad, [-1]) for grad in grads]
155 | grads = tf.concat(grads, 0)
156 | summaries.append(tf.summary.histogram(scope + '/' + name, grads))
157 | return tf.summary.merge(summaries)
158 |
159 |
160 | def variable_summaries(vars_, groups=None, scope='weights'):
161 | """Create histogram summaries for the provided variables.
162 |
163 | Summaries can be grouped via regexes matching variables names.
164 |
165 | Args:
166 | vars_: List of variables to summarize.
167 | groups: Mapping of name to regex for grouping summaries.
168 | scope: Name scope for this operation.
169 |
170 | Returns:
171 | Summary tensor.
172 | """
173 | groups = groups or {r'all': r'.*'}
174 | grouped = collections.defaultdict(list)
175 | for var in vars_:
176 | for name, pattern in groups.items():
177 | if re.match(pattern, var.name):
178 | name = re.sub(pattern, name, var.name)
179 | grouped[name].append(var)
180 | for name in groups:
181 | if name not in grouped:
182 | tf.logging.warn("No variables matching '{}' group.".format(name))
183 | summaries = []
184 | # pylint: disable=redefined-argument-from-local
185 | for name, vars_ in grouped.items():
186 | vars_ = [tf.reshape(var, [-1]) for var in vars_]
187 | vars_ = tf.concat(vars_, 0)
188 | summaries.append(tf.summary.histogram(scope + '/' + name, vars_))
189 | return tf.summary.merge(summaries)
190 |
191 |
192 | def set_dimension(tensor, axis, value):
193 |   """Set the static length of a tensor along the specified dimension.
194 |
195 | Args:
196 | tensor: Tensor to define shape of.
197 | axis: Dimension to set the static shape for.
198 | value: Integer holding the length.
199 |
200 | Raises:
201 | ValueError: When the tensor already has a different length specified.
202 | """
203 | shape = tensor.shape.as_list()
204 | if shape[axis] not in (value, None):
205 | message = 'Cannot set dimension {} of tensor {} to {}; is already {}.'
206 | raise ValueError(message.format(axis, tensor.name, value, shape[axis]))
207 | shape[axis] = value
208 | tensor.set_shape(shape)
209 |
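A minimal sketch of exercising the return and advantage helpers above on their
own, assuming TensorFlow 1.x graph mode; the tensor values below are toy
illustrations, not taken from the repository:

    import tensorflow as tf

    from agents.algorithms.ppo import utility

    # Two episodes of up to three steps; the second ends after two steps.
    reward = tf.constant([[1.0, 1.0, 1.0], [0.5, 0.5, 0.0]])
    value = tf.constant([[0.9, 0.8, 0.7], [0.4, 0.3, 0.0]])
    length = tf.constant([3, 2])

    return_ = utility.lambda_return(
        reward, value, length, discount=0.99, lambda_=0.95)
    advantage = utility.lambda_advantage(
        reward, value, length, discount=0.99, gae_lambda=0.95)

    with tf.Session() as sess:
      print(sess.run([return_, advantage]))

Both helpers mask out time steps beyond each episode's length and stop
gradients through the result, so the outputs can be fed into a policy or value
loss directly.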
--------------------------------------------------------------------------------
/agents/parts/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Reusable parts for building agents."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from .normalize import StreamingNormalize
22 | from .memory import EpisodeMemory
23 | from .iterate_sequences import iterate_sequences
24 |
--------------------------------------------------------------------------------
/agents/parts/iterate_sequences.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Iterate over batches of chunks of sequences for multiple epochs."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 |
23 | from agents import tools
24 |
25 |
26 | def iterate_sequences(
27 | consumer_fn, output_template, sequences, length, chunk_length=None,
28 | batch_size=None, num_epochs=1, padding_value=0):
29 | """Iterate over batches of chunks of sequences for multiple epochs.
30 |
31 | The batch dimension of the length tensor must be set because it is used to
32 | infer buffer sizes.
33 |
34 | Args:
35 | consumer_fn: Function creating the operation to process the data.
36 | output_template: Nested tensors of same shape and dtype as outputs.
37 | sequences: Nested collection of tensors with batch and time dimension.
38 | length: Tensor containing the length for each sequence.
39 | chunk_length: Split sequences into chunks of this size; optional.
40 | batch_size: Split epochs into batches of this size; optional.
41 | num_epochs: How many times to repeat over the data.
42 | padding_value: Value used for padding the last chunk after the sequence.
43 |
44 | Raises:
45 | ValueError: Unknown batch size of the length tensor.
46 |
47 | Returns:
48 | Concatenated nested tensors returned by the consumer.
49 | """
50 | if not length.shape[0].value:
51 | raise ValueError('Batch size of length tensor must be set.')
52 | num_sequences = length.shape[0].value
53 | sequences = dict(sequence=sequences, length=length)
54 | dataset = tf.data.Dataset.from_tensor_slices(sequences)
55 | dataset = dataset.repeat(num_epochs)
56 | if chunk_length:
57 | dataset = dataset.map(remove_padding).flat_map(
58 | # pylint: disable=g-long-lambda
59 | lambda x: tf.data.Dataset.from_tensor_slices(
60 | chunk_sequence(x, chunk_length, padding_value)))
61 | num_chunks = tf.reduce_sum((length - 1) // chunk_length + 1)
62 | else:
63 | num_chunks = num_sequences
64 | if batch_size:
65 | dataset = dataset.shuffle(num_sequences // 2)
66 | dataset = dataset.batch(batch_size or num_sequences)
67 | dataset = dataset.prefetch(num_epochs)
68 | iterator = dataset.make_initializable_iterator()
69 | with tf.control_dependencies([iterator.initializer]):
70 | num_batches = num_epochs * num_chunks // (batch_size or num_sequences)
71 | return tf.scan(
72 | # pylint: disable=g-long-lambda
73 | lambda _1, index: consumer_fn(iterator.get_next()),
74 | tf.range(num_batches), output_template, parallel_iterations=1)
75 |
76 |
77 | def chunk_sequence(sequence, chunk_length=200, padding_value=0):
78 | """Split a nested dict of sequence tensors into a batch of chunks.
79 |
80 | This function does not expect a batch of sequences, but a single sequence. A
81 | `length` key is added if it did not exist already.
82 |
83 | Args:
84 | sequence: Nested dict of tensors with time dimension.
85 | chunk_length: Size of chunks the sequence will be split into.
86 | padding_value: Value used for padding the last chunk after the sequence.
87 |
88 | Returns:
89 | Nested dict of sequence tensors with chunk dimension.
90 | """
91 | if 'length' in sequence:
92 | length = sequence.pop('length')
93 | else:
94 | length = tf.shape(tools.nested.flatten(sequence)[0])[0]
95 | num_chunks = (length - 1) // chunk_length + 1
96 | padding_length = chunk_length * num_chunks - length
97 | padded = tools.nested.map(
98 | # pylint: disable=g-long-lambda
99 | lambda tensor: tf.concat([
100 | tensor, 0 * tensor[:padding_length] + padding_value], 0),
101 | sequence)
102 | chunks = tools.nested.map(
103 | # pylint: disable=g-long-lambda
104 | lambda tensor: tf.reshape(
105 | tensor, [num_chunks, chunk_length] + tensor.shape[1:].as_list()),
106 | padded)
107 | chunks['length'] = tf.concat([
108 | chunk_length * tf.ones((num_chunks - 1,), dtype=tf.int32),
109 | [chunk_length - padding_length]], 0)
110 | return chunks
111 |
112 |
113 | def remove_padding(sequence):
114 | """Selects the used frames of a sequence, up to its length.
115 |
116 | This function does not expect a batch of sequences, but a single sequence.
117 |   The sequence must be a dict with a `length` key, which will be removed
118 |   from the result.
119 |
120 | Args:
121 | sequence: Nested dict of tensors with time dimension.
122 |
123 | Returns:
124 | Nested dict of tensors with padding elements and `length` key removed.
125 | """
126 | length = sequence.pop('length')
127 | sequence = tools.nested.map(lambda tensor: tensor[:length], sequence)
128 | return sequence
129 |
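A minimal sketch of chunk_sequence applied to a single toy sequence, assuming
TensorFlow 1.x graph mode; the observ key and shapes are illustrative only:

    import tensorflow as tf

    from agents.parts.iterate_sequences import chunk_sequence

    # One sequence of seven steps with three-dimensional observations.
    sequence = dict(observ=tf.reshape(tf.range(21, dtype=tf.float32), [7, 3]))
    chunks = chunk_sequence(sequence, chunk_length=4, padding_value=0)

    with tf.Session() as sess:
      result = sess.run(chunks)
      print(result['observ'].shape)  # (2, 4, 3): two chunks of length four.
      print(result['length'])        # [4 3]: the last chunk has one padded step.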
--------------------------------------------------------------------------------
/agents/parts/memory.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Memory that stores episodes."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 |
23 | from agents import tools
24 |
25 |
26 | class EpisodeMemory(object):
27 | """Memory that stores episodes."""
28 |
29 | def __init__(self, template, capacity, max_length, scope):
30 | """Create a memory that stores episodes.
31 |
32 | Each transition tuple consists of quantities specified by the template.
33 |     These quantities would typically be observations, actions, rewards, and
34 | done indicators.
35 |
36 | Args:
37 | template: Nested tensors to derive shapes and dtypes of each transition.
38 |       capacity: Number of episodes, or rows, held by the memory.
39 | max_length: Allocated sequence length for the episodes.
40 | scope: Variable scope to use for internal variables.
41 | """
42 | self._capacity = capacity
43 | self._max_length = max_length
44 | with tf.variable_scope(scope) as var_scope:
45 | self._scope = var_scope
46 | self._length = tf.Variable(tf.zeros(capacity, tf.int32), False)
47 | self._buffers = tools.nested.map(
48 | lambda x: tf.Variable(tf.zeros(
49 | [capacity, max_length] + x.shape.as_list(), x.dtype), False),
50 | template)
51 |
52 | def length(self, rows=None):
53 | """Tensor holding the current length of episodes.
54 |
55 | Args:
56 | rows: Episodes to select length from, defaults to all.
57 |
58 | Returns:
59 | Batch tensor of sequence lengths.
60 | """
61 | rows = tf.range(self._capacity) if rows is None else rows
62 | return tf.gather(self._length, rows)
63 |
64 | def append(self, transitions, rows=None):
65 | """Append a batch of transitions to rows of the memory.
66 |
67 | Args:
68 | transitions: Tuple of transition quantities with batch dimension.
69 | rows: Episodes to append to, defaults to all.
70 |
71 | Returns:
72 | Operation.
73 | """
74 | rows = tf.range(self._capacity) if rows is None else rows
75 | assert rows.shape.ndims == 1
76 | assert_capacity = tf.assert_less(
77 | rows, self._capacity,
78 | message='capacity exceeded')
79 | with tf.control_dependencies([assert_capacity]):
80 | assert_max_length = tf.assert_less(
81 | tf.gather(self._length, rows), self._max_length,
82 | message='max length exceeded')
83 | with tf.control_dependencies([assert_max_length]):
84 | timestep = tf.gather(self._length, rows)
85 | indices = tf.stack([rows, timestep], 1)
86 | append_ops = tools.nested.map(
87 | lambda var, val: tf.scatter_nd_update(var, indices, val),
88 | self._buffers, transitions, flatten=True)
89 | with tf.control_dependencies(append_ops):
90 | episode_mask = tf.reduce_sum(tf.one_hot(
91 | rows, self._capacity, dtype=tf.int32), 0)
92 | return self._length.assign_add(episode_mask)
93 |
94 | def replace(self, episodes, length, rows=None):
95 | """Replace full episodes.
96 |
97 | Args:
98 | episodes: Tuple of transition quantities with batch and time dimensions.
99 | length: Batch of sequence lengths.
100 | rows: Episodes to replace, defaults to all.
101 |
102 | Returns:
103 | Operation.
104 | """
105 | rows = tf.range(self._capacity) if rows is None else rows
106 | assert rows.shape.ndims == 1
107 | assert_capacity = tf.assert_less(
108 | rows, self._capacity, message='capacity exceeded')
109 | with tf.control_dependencies([assert_capacity]):
110 | assert_max_length = tf.assert_less_equal(
111 | length, self._max_length, message='max length exceeded')
112 | with tf.control_dependencies([assert_max_length]):
113 | replace_ops = tools.nested.map(
114 | lambda var, val: tf.scatter_update(var, rows, val),
115 | self._buffers, episodes, flatten=True)
116 | with tf.control_dependencies(replace_ops):
117 | return tf.scatter_update(self._length, rows, length)
118 |
119 | def data(self, rows=None):
120 | """Access a batch of episodes from the memory.
121 |
122 | Padding elements after the length of each episode are unspecified and might
123 | contain old data.
124 |
125 | Args:
126 | rows: Episodes to select, defaults to all.
127 |
128 | Returns:
129 | Tuple containing a tuple of transition quantities with batch and time
130 | dimensions, and a batch of sequence lengths.
131 | """
132 | rows = tf.range(self._capacity) if rows is None else rows
133 | assert rows.shape.ndims == 1
134 | episode = tools.nested.map(lambda var: tf.gather(var, rows), self._buffers)
135 | length = tf.gather(self._length, rows)
136 | return episode, length
137 |
138 | def clear(self, rows=None):
139 | """Reset episodes in the memory.
140 |
141 | Internally, this only sets their lengths to zero. The memory entries will
142 |     be overwritten by future calls to append() or replace().
143 |
144 | Args:
145 | rows: Episodes to clear, defaults to all.
146 |
147 | Returns:
148 | Operation.
149 | """
150 | rows = tf.range(self._capacity) if rows is None else rows
151 | assert rows.shape.ndims == 1
152 | return tf.scatter_update(self._length, rows, tf.zeros_like(rows))
153 |
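A minimal sketch of creating an EpisodeMemory and appending one batch of
transitions, assuming TensorFlow 1.x graph mode; the template keys and shapes
are illustrative only:

    import tensorflow as tf

    from agents import parts

    # Transitions carry a 4-d observation and a 2-d action.
    template = dict(observ=tf.zeros([4]), action=tf.zeros([2]))
    memory = parts.EpisodeMemory(
        template, capacity=3, max_length=5, scope='memory')

    # Append one transition each to episodes 0 and 2.
    transitions = dict(observ=tf.ones([2, 4]), action=tf.ones([2, 2]))
    append_op = memory.append(transitions, rows=tf.constant([0, 2]))

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(append_op)
      print(sess.run(memory.length()))  # [1 0 1]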
--------------------------------------------------------------------------------
/agents/parts/normalize.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Normalize tensors based on streaming estimates of mean and variance."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 |
23 |
24 | class StreamingNormalize(object):
25 | """Normalize tensors based on streaming estimates of mean and variance."""
26 |
27 | def __init__(
28 | self, template, center=True, scale=True, clip=10, name='normalize'):
29 | """Normalize tensors based on streaming estimates of mean and variance.
30 |
31 | Centering the value, scaling it by the standard deviation, and clipping
32 | outlier values are optional.
33 |
34 | Args:
35 |       template: Example tensor providing shape and dtype of the value to track.
36 | center: Python boolean indicating whether to subtract mean from values.
37 | scale: Python boolean indicating whether to scale values by stddev.
38 |       clip: Clip normalized values to this magnitude; None or 0 to disable.
39 | name: Parent scope of operations provided by this class.
40 | """
41 | self._center = center
42 | self._scale = scale
43 | self._clip = clip
44 | self._name = name
45 | with tf.name_scope(name):
46 | self._count = tf.Variable(0, False)
47 | self._mean = tf.Variable(tf.zeros_like(template), False)
48 | self._var_sum = tf.Variable(tf.zeros_like(template), False)
49 |
50 | def transform(self, value):
51 | """Normalize a single or batch tensor.
52 |
53 |     Applies the transformations activated in the constructor, using the
54 |     current estimates of mean and variance.
55 |
56 | Args:
57 | value: Batch or single value tensor.
58 |
59 | Returns:
60 | Normalized batch or single value tensor.
61 | """
62 | with tf.name_scope(self._name + '/transform'):
63 | no_batch_dim = value.shape.ndims == self._mean.shape.ndims
64 | if no_batch_dim:
65 | # Add a batch dimension if necessary.
66 | value = value[None, ...]
67 | if self._center:
68 | value -= self._mean[None, ...]
69 | if self._scale:
70 | # We cannot scale before seeing at least two samples.
71 | value /= tf.cond(
72 | self._count > 1, lambda: self._std() + 1e-8,
73 | lambda: tf.ones_like(self._var_sum))[None]
74 | if self._clip:
75 | value = tf.clip_by_value(value, -self._clip, self._clip)
76 | # Remove batch dimension if necessary.
77 | if no_batch_dim:
78 | value = value[0]
79 | return tf.check_numerics(value, 'value')
80 |
81 | def update(self, value):
82 | """Update the mean and variance estimates.
83 |
84 | Args:
85 | value: Batch or single value tensor.
86 |
87 | Returns:
88 | Summary tensor.
89 | """
90 | with tf.name_scope(self._name + '/update'):
91 | if value.shape.ndims == self._mean.shape.ndims:
92 | # Add a batch dimension if necessary.
93 | value = value[None, ...]
94 | count = tf.shape(value)[0]
95 | with tf.control_dependencies([self._count.assign_add(count)]):
96 | step = tf.cast(self._count, tf.float32)
97 | mean_delta = tf.reduce_sum(value - self._mean[None, ...], 0)
98 | new_mean = self._mean + mean_delta / step
99 | new_mean = tf.cond(self._count > 1, lambda: new_mean, lambda: value[0])
100 | var_delta = (
101 | value - self._mean[None, ...]) * (value - new_mean[None, ...])
102 | new_var_sum = self._var_sum + tf.reduce_sum(var_delta, 0)
103 | with tf.control_dependencies([new_mean, new_var_sum]):
104 | update = self._mean.assign(new_mean), self._var_sum.assign(new_var_sum)
105 | with tf.control_dependencies(update):
106 | if value.shape.ndims == 1:
107 | value = tf.reduce_mean(value)
108 | return self._summary('value', tf.reduce_mean(value))
109 |
110 | def reset(self):
111 | """Reset the estimates of mean and variance.
112 |
113 | Resets the full state of this class.
114 |
115 | Returns:
116 | Operation.
117 | """
118 | with tf.name_scope(self._name + '/reset'):
119 | return tf.group(
120 | self._count.assign(0),
121 | self._mean.assign(tf.zeros_like(self._mean)),
122 | self._var_sum.assign(tf.zeros_like(self._var_sum)))
123 |
124 | def summary(self):
125 | """Summary string of mean and standard deviation.
126 |
127 | Returns:
128 | Summary tensor.
129 | """
130 | with tf.name_scope(self._name + '/summary'):
131 | mean_summary = tf.cond(
132 | self._count > 0, lambda: self._summary('mean', self._mean), str)
133 | std_summary = tf.cond(
134 | self._count > 1, lambda: self._summary('stddev', self._std()), str)
135 | return tf.summary.merge([mean_summary, std_summary])
136 |
137 | def _std(self):
138 | """Computes the current estimate of the standard deviation.
139 |
140 | Note that the standard deviation is not defined until at least two samples
141 |     have been seen.
142 |
143 | Returns:
144 |       Tensor of the current standard deviation estimate.
145 | """
146 | variance = tf.cond(
147 | self._count > 1,
148 | lambda: self._var_sum / tf.cast(self._count - 1, tf.float32),
149 | lambda: tf.ones_like(self._var_sum) * float('nan'))
150 | # The epsilon corrects for small negative variance values caused by
151 | # the algorithm. It was empirically chosen to work with all environments
152 | # tested.
153 | return tf.sqrt(variance + 1e-4)
154 |
155 | def _summary(self, name, tensor):
156 | """Create a scalar or histogram summary matching the rank of the tensor.
157 |
158 | Args:
159 | name: Name for the summary.
160 | tensor: Tensor to summarize.
161 |
162 | Returns:
163 | Summary tensor.
164 | """
165 | if tensor.shape.ndims == 0:
166 | return tf.summary.scalar(name, tensor)
167 | else:
168 | return tf.summary.histogram(name, tensor)
169 |
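A minimal sketch of whitening a batch of vectors with StreamingNormalize,
assuming TensorFlow 1.x graph mode; the shapes and distribution parameters are
illustrative only:

    import tensorflow as tf

    from agents import parts

    normalize = parts.StreamingNormalize(
        template=tf.zeros([3]), center=True, scale=True, clip=5)

    value = tf.random_normal([10, 3], mean=2.0, stddev=4.0)
    update = normalize.update(value)
    normalized = normalize.transform(value)

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(update)             # Fold a batch into the running estimates.
      print(sess.run(normalized))  # Roughly centered, scaled, clipped to [-5, 5].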
--------------------------------------------------------------------------------
/agents/scripts/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Executable scripts for reinforcement learning."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from . import train
22 | from . import utility
23 | from . import visualize
24 |
--------------------------------------------------------------------------------
/agents/scripts/configs.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Example configurations using the PPO algorithm."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | # pylint: disable=unused-variable
22 |
23 | import tensorflow as tf
24 |
25 | from agents import algorithms
26 | from agents.scripts import networks
27 |
28 |
29 | def default():
30 | """Default configuration for PPO."""
31 | # General
32 | algorithm = algorithms.PPO
33 | num_agents = 30
34 | eval_episodes = 30
35 | use_gpu = False
36 | # Environment
37 | normalize_ranges = True
38 | # Network
39 | network = networks.feed_forward_gaussian
40 | weight_summaries = dict(
41 | all=r'.*', policy=r'.*/policy/.*', value=r'.*/value/.*')
42 | policy_layers = 200, 100
43 | value_layers = 200, 100
44 | init_output_factor = 0.1
45 | init_std = 0.35
46 | # Optimization
47 | update_every = 30
48 | update_epochs = 25
49 | optimizer = tf.train.AdamOptimizer
50 | learning_rate = 1e-4
51 | # Losses
52 | discount = 0.995
53 | kl_target = 1e-2
54 | kl_cutoff_factor = 2
55 | kl_cutoff_coef = 1000
56 | kl_init_penalty = 1
57 | return locals()
58 |
59 |
60 | def pendulum():
61 | """Configuration for the pendulum classic control task."""
62 | locals().update(default())
63 | # Environment
64 | env = 'Pendulum-v0'
65 | max_length = 200
66 | steps = 1e6 # 1M
67 | # Optimization
68 | batch_size = 20
69 | chunk_length = 50
70 | return locals()
71 |
72 |
73 | def cartpole():
74 | """Configuration for the cart pole classic control task."""
75 | locals().update(default())
76 | # Environment
77 | env = 'CartPole-v1'
78 | max_length = 500
79 | steps = 2e5 # 200k
80 | normalize_ranges = False # The env reports wrong ranges.
81 | # Network
82 | network = networks.feed_forward_categorical
83 | return locals()
84 |
85 |
86 | def reacher():
87 | """Configuration for MuJoCo's reacher task."""
88 | locals().update(default())
89 | # Environment
90 | env = 'Reacher-v2'
91 | max_length = 1000
92 | steps = 5e6 # 5M
93 | discount = 0.985
94 | update_every = 60
95 | return locals()
96 |
97 |
98 | def cheetah():
99 | """Configuration for MuJoCo's half cheetah task."""
100 | locals().update(default())
101 | # Environment
102 | env = 'HalfCheetah-v2'
103 | max_length = 1000
104 | steps = 1e7 # 10M
105 | discount = 0.99
106 | return locals()
107 |
108 |
109 | def walker():
110 | """Configuration for MuJoCo's walker task."""
111 | locals().update(default())
112 | # Environment
113 | env = 'Walker2d-v2'
114 | max_length = 1000
115 | steps = 1e7 # 10M
116 | return locals()
117 |
118 |
119 | def hopper():
120 | """Configuration for MuJoCo's hopper task."""
121 | locals().update(default())
122 | # Environment
123 | env = 'Hopper-v2'
124 | max_length = 1000
125 | steps = 1e7 # 10M
126 | update_every = 60
127 | return locals()
128 |
129 |
130 | def ant():
131 | """Configuration for MuJoCo's ant task."""
132 | locals().update(default())
133 | # Environment
134 | env = 'Ant-v2'
135 | max_length = 1000
136 | steps = 2e7 # 20M
137 | return locals()
138 |
139 |
140 | def humanoid():
141 | """Configuration for MuJoCo's humanoid task."""
142 | locals().update(default())
143 | # Environment
144 | env = 'Humanoid-v2'
145 | max_length = 1000
146 | steps = 5e7 # 50M
147 | update_every = 60
148 | return locals()
149 |
150 |
151 | def bullet_ant():
152 | """Configuration for PyBullet's ant task."""
153 | locals().update(default())
154 | # Environment
155 | import pybullet_envs # noqa pylint: disable=unused-import
156 | env = 'AntBulletEnv-v0'
157 | max_length = 1000
158 | steps = 3e7 # 30M
159 | update_every = 60
160 | return locals()
161 |
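New experiments follow the same pattern: add a function to this module that
pulls in default() and overrides attributes, then select it by name with the
--config flag of agents.scripts.train. A sketch of a hypothetical
pendulum_small configuration; the name and values are illustrative only:

    def pendulum_small():
      """Hypothetical smaller configuration for the pendulum task."""
      locals().update(default())
      # Environment
      env = 'Pendulum-v0'
      max_length = 200
      steps = 5e5  # 500k
      # Network
      policy_layers = 100, 50
      value_layers = 100, 50
      # Optimization
      batch_size = 20
      chunk_length = 50
      return locals()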
--------------------------------------------------------------------------------
/agents/scripts/networks.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Policy networks for agents."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import functools
22 | import operator
23 |
24 | import gym
25 | import numpy as np
26 | import tensorflow as tf
27 |
28 | import agents
29 |
30 | tfd = tf.contrib.distributions
31 |
32 |
33 | # TensorFlow's default implementation of the KL divergence between two
34 | # tf.contrib.distributions.MultivariateNormalDiag instances sometimes results
35 | # in NaN values in the gradients (not in the forward pass). Until the default
36 | # implementation is fixed, we use our own KL implementation.
37 | class CustomKLDiagNormal(tfd.MultivariateNormalDiag):
38 | """Multivariate Normal with diagonal covariance and our custom KL code."""
39 | pass
40 |
41 |
42 | @tfd.RegisterKL(CustomKLDiagNormal, CustomKLDiagNormal)
43 | def _custom_diag_normal_kl(lhs, rhs, name=None): # pylint: disable=unused-argument
44 |   """Analytic KL divergence of two normals with diagonal covariance.
45 |
46 | Args:
47 | lhs: Diagonal Normal distribution.
48 | rhs: Diagonal Normal distribution.
49 | name: Name scope for the op.
50 |
51 | Returns:
52 | KL divergence from lhs to rhs.
53 | """
54 | with tf.name_scope(name or 'kl_divergence'):
55 | mean0 = lhs.mean()
56 | mean1 = rhs.mean()
57 | logstd0 = tf.log(lhs.stddev())
58 | logstd1 = tf.log(rhs.stddev())
59 | logstd0_2, logstd1_2 = 2 * logstd0, 2 * logstd1
60 | return 0.5 * (
61 | tf.reduce_sum(tf.exp(logstd0_2 - logstd1_2), -1) +
62 | tf.reduce_sum((mean1 - mean0) ** 2 / tf.exp(logstd1_2), -1) +
63 | tf.reduce_sum(logstd1_2, -1) - tf.reduce_sum(logstd0_2, -1) -
64 | mean0.shape[-1].value)
65 |
66 |
67 | def feed_forward_gaussian(
68 | config, action_space, observations, unused_length, state=None):
69 | """Independent feed forward networks for policy and value.
70 |
71 | The policy network outputs the mean action and the standard deviation is
72 |   learned as an independent parameter vector.
73 |
74 | Args:
75 | config: Configuration object.
76 | action_space: Action space of the environment.
77 | observations: Sequences of observations.
78 | unused_length: Batch of sequence lengths.
79 | state: Unused batch of initial states.
80 |
81 | Raises:
82 | ValueError: Unexpected action space.
83 |
84 | Returns:
85 | Attribute dictionary containing the policy, value, and unused state.
86 | """
87 | if not isinstance(action_space, gym.spaces.Box):
88 | raise ValueError('Network expects continuous actions.')
89 | if not len(action_space.shape) == 1:
90 | raise ValueError('Network only supports 1D action vectors.')
91 | action_size = action_space.shape[0]
92 | init_output_weights = tf.contrib.layers.variance_scaling_initializer(
93 | factor=config.init_output_factor)
94 | before_softplus_std_initializer = tf.constant_initializer(
95 | np.log(np.exp(config.init_std) - 1))
96 | flat_observations = tf.reshape(observations, [
97 | tf.shape(observations)[0], tf.shape(observations)[1],
98 | functools.reduce(operator.mul, observations.shape.as_list()[2:], 1)])
99 | with tf.variable_scope('policy'):
100 | x = flat_observations
101 | for size in config.policy_layers:
102 | x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
103 | mean = tf.contrib.layers.fully_connected(
104 | x, action_size, tf.tanh,
105 | weights_initializer=init_output_weights)
106 | std = tf.nn.softplus(tf.get_variable(
107 | 'before_softplus_std', mean.shape[2:], tf.float32,
108 | before_softplus_std_initializer))
109 | std = tf.tile(
110 | std[None, None],
111 | [tf.shape(mean)[0], tf.shape(mean)[1]] + [1] * (mean.shape.ndims - 2))
112 | with tf.variable_scope('value'):
113 | x = flat_observations
114 | for size in config.value_layers:
115 | x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
116 | value = tf.contrib.layers.fully_connected(x, 1, None)[..., 0]
117 | mean = tf.check_numerics(mean, 'mean')
118 | std = tf.check_numerics(std, 'std')
119 | value = tf.check_numerics(value, 'value')
120 | policy = CustomKLDiagNormal(mean, std)
121 | return agents.tools.AttrDict(policy=policy, value=value, state=state)
122 |
123 |
124 | def feed_forward_categorical(
125 | config, action_space, observations, unused_length, state=None):
126 | """Independent feed forward networks for policy and value.
127 |
128 |   The policy network outputs the logits of a categorical distribution over
129 |   the available discrete actions.
130 |
131 | Args:
132 | config: Configuration object.
133 | action_space: Action space of the environment.
134 | observations: Sequences of observations.
135 | unused_length: Batch of sequence lengths.
136 | state: Unused batch of initial recurrent states.
137 |
138 | Raises:
139 | ValueError: Unexpected action space.
140 |
141 | Returns:
142 | Attribute dictionary containing the policy, value, and unused state.
143 | """
144 | init_output_weights = tf.contrib.layers.variance_scaling_initializer(
145 | factor=config.init_output_factor)
146 | if not isinstance(action_space, gym.spaces.Discrete):
147 | raise ValueError('Network expects discrete actions.')
148 | flat_observations = tf.reshape(observations, [
149 | tf.shape(observations)[0], tf.shape(observations)[1],
150 | functools.reduce(operator.mul, observations.shape.as_list()[2:], 1)])
151 | with tf.variable_scope('policy'):
152 | x = flat_observations
153 | for size in config.policy_layers:
154 | x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
155 | logits = tf.contrib.layers.fully_connected(
156 | x, action_space.n, None, weights_initializer=init_output_weights)
157 | with tf.variable_scope('value'):
158 | x = flat_observations
159 | for size in config.value_layers:
160 | x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
161 | value = tf.contrib.layers.fully_connected(x, 1, None)[..., 0]
162 | policy = tfd.Categorical(logits)
163 | return agents.tools.AttrDict(policy=policy, value=value, state=state)
164 |
165 |
166 | def recurrent_gaussian(
167 | config, action_space, observations, length, state=None):
168 | """Independent recurrent policy and feed forward value networks.
169 |
170 | The policy network outputs the mean action and the standard deviation is
171 |   learned as an independent parameter vector. The last policy layer is
172 |   recurrent and uses a GRU cell.
173 |
174 | Args:
175 | config: Configuration object.
176 | action_space: Action space of the environment.
177 | observations: Sequences of observations.
178 | length: Batch of sequence lengths.
179 | state: Batch of initial recurrent states.
180 |
181 | Raises:
182 | ValueError: Unexpected action space.
183 |
184 | Returns:
185 | Attribute dictionary containing the policy, value, and state.
186 | """
187 | if not isinstance(action_space, gym.spaces.Box):
188 | raise ValueError('Network expects continuous actions.')
189 | if not len(action_space.shape) == 1:
190 | raise ValueError('Network only supports 1D action vectors.')
191 | action_size = action_space.shape[0]
192 | init_output_weights = tf.contrib.layers.variance_scaling_initializer(
193 | factor=config.init_output_factor)
194 | before_softplus_std_initializer = tf.constant_initializer(
195 | np.log(np.exp(config.init_std) - 1))
196 | cell = tf.contrib.rnn.GRUBlockCell(config.policy_layers[-1])
197 | flat_observations = tf.reshape(observations, [
198 | tf.shape(observations)[0], tf.shape(observations)[1],
199 | functools.reduce(operator.mul, observations.shape.as_list()[2:], 1)])
200 | with tf.variable_scope('policy'):
201 | x = flat_observations
202 | for size in config.policy_layers[:-1]:
203 | x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
204 | x, state = tf.nn.dynamic_rnn(cell, x, length, state, tf.float32)
205 | mean = tf.contrib.layers.fully_connected(
206 | x, action_size, tf.tanh,
207 | weights_initializer=init_output_weights)
208 | std = tf.nn.softplus(tf.get_variable(
209 | 'before_softplus_std', mean.shape[2:], tf.float32,
210 | before_softplus_std_initializer))
211 | std = tf.tile(
212 | std[None, None],
213 | [tf.shape(mean)[0], tf.shape(mean)[1]] + [1] * (mean.shape.ndims - 2))
214 | with tf.variable_scope('value'):
215 | x = flat_observations
216 | for size in config.value_layers:
217 | x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
218 | value = tf.contrib.layers.fully_connected(x, 1, None)[..., 0]
219 | mean = tf.check_numerics(mean, 'mean')
220 | std = tf.check_numerics(std, 'std')
221 | value = tf.check_numerics(value, 'value')
222 | policy = CustomKLDiagNormal(mean, std)
223 | return agents.tools.AttrDict(policy=policy, value=value, state=state)
224 |
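A minimal sketch of instantiating one of these networks outside the training
graph to inspect its outputs, assuming TensorFlow 1.x; the observation size
and action space below are illustrative only:

    import gym
    import tensorflow as tf

    from agents import tools
    from agents.scripts import configs, networks

    config = tools.AttrDict(configs.cartpole())
    action_space = gym.spaces.Discrete(2)
    # Batches of sequences: [batch, time, observation size].
    observations = tf.placeholder(tf.float32, [None, None, 4])
    length = tf.placeholder(tf.int32, [None])

    with tf.variable_scope('network'):
      outputs = networks.feed_forward_categorical(
          config, action_space, observations, length)
    # outputs.policy is a categorical distribution over the two actions and
    # outputs.value has shape [batch, time].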
--------------------------------------------------------------------------------
/agents/scripts/train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""Script to train a batch reinforcement learning algorithm.
16 |
17 | Command line:
18 |
19 | python3 -m agents.scripts.train --logdir=/path/to/logdir --config=pendulum
20 | """
21 |
22 | from __future__ import absolute_import
23 | from __future__ import division
24 | from __future__ import print_function
25 |
26 | import datetime
27 | import os
28 |
29 | import gym
30 | import tensorflow as tf
31 |
32 | from agents import tools
33 | from agents.scripts import configs
34 | from agents.scripts import utility
35 |
36 |
37 | def _create_environment(config):
38 | """Constructor for an instance of the environment.
39 |
40 | Args:
41 | config: Object providing configurations via attributes.
42 |
43 | Raises:
44 | NotImplementedError: For action spaces other than Box and Discrete.
45 |
46 | Returns:
47 | Wrapped OpenAI Gym environment.
48 | """
49 | if isinstance(config.env, str):
50 | env = gym.make(config.env)
51 | else:
52 | env = config.env()
53 | if config.max_length:
54 | env = tools.wrappers.LimitDuration(env, config.max_length)
55 | if isinstance(env.action_space, gym.spaces.Box):
56 | if config.normalize_ranges:
57 | env = tools.wrappers.RangeNormalize(env)
58 | env = tools.wrappers.ClipAction(env)
59 | elif isinstance(env.action_space, gym.spaces.Discrete):
60 | if config.normalize_ranges:
61 | env = tools.wrappers.RangeNormalize(env, action=False)
62 | else:
63 | message = "Unsupported action space '{}'".format(type(env.action_space))
64 | raise NotImplementedError(message)
65 | env = tools.wrappers.ConvertTo32Bit(env)
66 | env = tools.wrappers.CacheSpaces(env)
67 | return env
68 |
69 |
70 | def _define_loop(graph, logdir, train_steps, eval_steps):
71 | """Create and configure a training loop with training and evaluation phases.
72 |
73 | Args:
74 | graph: Object providing graph elements via attributes.
75 | logdir: Log directory for storing checkpoints and summaries.
76 | train_steps: Number of training steps per epoch.
77 | eval_steps: Number of evaluation steps per epoch.
78 |
79 | Returns:
80 | Loop object.
81 | """
82 | loop = tools.Loop(
83 | logdir, graph.step, graph.should_log, graph.do_report,
84 | graph.force_reset)
85 | loop.add_phase(
86 | 'train', graph.done, graph.score, graph.summary, train_steps,
87 | report_every=train_steps,
88 | log_every=train_steps // 2,
89 | checkpoint_every=None,
90 | feed={graph.is_training: True})
91 | loop.add_phase(
92 | 'eval', graph.done, graph.score, graph.summary, eval_steps,
93 | report_every=eval_steps,
94 | log_every=eval_steps // 2,
95 | checkpoint_every=10 * eval_steps,
96 | feed={graph.is_training: False})
97 | return loop
98 |
99 |
100 | def train(config, env_processes):
101 | """Training and evaluation entry point yielding scores.
102 |
103 | Resolves some configuration attributes, creates environments, graph, and
104 | training loop. By default, assigns all operations to the CPU.
105 |
106 | Args:
107 | config: Object providing configurations via attributes.
108 | env_processes: Whether to step environments in separate processes.
109 |
110 | Yields:
111 | Evaluation scores.
112 | """
113 | tf.reset_default_graph()
114 | if config.update_every % config.num_agents:
115 | tf.logging.warn('Number of agents should divide episodes per update.')
116 | with tf.device('/cpu:0'):
117 | batch_env = utility.define_batch_env(
118 | lambda: _create_environment(config),
119 | config.num_agents, env_processes)
120 | graph = utility.define_simulation_graph(
121 | batch_env, config.algorithm, config)
122 | loop = _define_loop(
123 | graph, config.logdir,
124 | config.update_every * config.max_length,
125 | config.eval_episodes * config.max_length)
126 | total_steps = int(
127 | config.steps / config.update_every *
128 | (config.update_every + config.eval_episodes))
129 | # Exclude episode related variables since the Python state of environments is
130 | # not checkpointed and thus new episodes start after resuming.
131 | saver = utility.define_saver(exclude=(r'.*_temporary.*',))
132 | sess_config = tf.ConfigProto(allow_soft_placement=True)
133 | sess_config.gpu_options.allow_growth = True
134 | with tf.Session(config=sess_config) as sess:
135 | utility.initialize_variables(sess, saver, config.logdir)
136 | for score in loop.run(sess, saver, total_steps):
137 | yield score
138 | batch_env.close()
139 |
140 |
141 | def main(_):
142 | """Create or load configuration and launch the trainer."""
143 | utility.set_up_logging()
144 | if not FLAGS.config:
145 | raise KeyError('You must specify a configuration.')
146 | logdir = FLAGS.logdir and os.path.expanduser(os.path.join(
147 | FLAGS.logdir, '{}-{}'.format(FLAGS.timestamp, FLAGS.config)))
148 | try:
149 | config = utility.load_config(logdir)
150 | except IOError:
151 | config = tools.AttrDict(getattr(configs, FLAGS.config)())
152 | config = utility.save_config(config, logdir)
153 | for score in train(config, FLAGS.env_processes):
154 | tf.logging.info('Score {}.'.format(score))
155 |
156 |
157 | if __name__ == '__main__':
158 | FLAGS = tf.app.flags.FLAGS
159 | tf.app.flags.DEFINE_string(
160 | 'logdir', None,
161 | 'Base directory to store logs.')
162 | tf.app.flags.DEFINE_string(
163 | 'timestamp', datetime.datetime.now().strftime('%Y%m%dT%H%M%S'),
164 | 'Sub directory to store logs.')
165 | tf.app.flags.DEFINE_string(
166 | 'config', None,
167 | 'Configuration to execute.')
168 | tf.app.flags.DEFINE_boolean(
169 | 'env_processes', True,
170 | 'Step environments in separate processes to circumvent the GIL.')
171 | tf.app.run()
172 |
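Training can also be driven from Python instead of the command line; a minimal
sketch, assuming gym provides Pendulum-v0 and using a placeholder log
directory:

    from agents import tools
    from agents.scripts import configs, train, utility

    utility.set_up_logging()
    logdir = '/tmp/agents-pendulum'  # Placeholder log directory.
    config = tools.AttrDict(configs.pendulum())
    config = utility.save_config(config, logdir)
    for score in train.train(config, env_processes=True):
      print('Evaluation score {}.'.format(score))

For a quick smoke test, attributes such as steps can be lowered inside a
with config.unlocked: block before calling train(), as the tests in
train_ppo_test.py do.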
--------------------------------------------------------------------------------
/agents/scripts/train_ppo_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for the PPO algorithm usage example."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import functools
22 | import itertools
23 |
24 | import tensorflow as tf
25 |
26 | from agents import algorithms
27 | from agents import tools
28 | from agents.scripts import configs
29 | from agents.scripts import networks
30 | from agents.scripts import train
31 |
32 |
33 | class PPOTest(tf.test.TestCase):
34 |
35 | def test_pendulum_no_crash(self):
36 | nets = networks.feed_forward_gaussian, networks.recurrent_gaussian
37 | for network in nets:
38 | config = self._define_config()
39 | with config.unlocked:
40 | config.env = 'Pendulum-v0'
41 | config.max_length = 200
42 | config.steps = 500
43 | config.network = network
44 | for score in train.train(config, env_processes=True):
45 | float(score)
46 |
47 | def test_no_crash_cartpole(self):
48 | config = self._define_config()
49 | with config.unlocked:
50 | config.env = 'CartPole-v1'
51 | config.max_length = 200
52 | config.steps = 500
53 | config.normalize_ranges = False # The env reports wrong ranges.
54 | config.network = networks.feed_forward_categorical
55 | for score in train.train(config, env_processes=True):
56 | float(score)
57 |
58 | def test_no_crash_observation_shape(self):
59 | nets = networks.feed_forward_gaussian, networks.recurrent_gaussian
60 | observ_shapes = (1,), (2, 3), (2, 3, 4)
61 | for network, observ_shape in itertools.product(nets, observ_shapes):
62 | config = self._define_config()
63 | with config.unlocked:
64 | config.env = functools.partial(
65 | tools.MockEnvironment, observ_shape, action_shape=(3,),
66 | min_duration=15, max_duration=15)
67 | config.max_length = 20
68 | config.steps = 50
69 | config.network = network
70 | for score in train.train(config, env_processes=False):
71 | float(score)
72 |
73 | def test_no_crash_variable_duration(self):
74 | config = self._define_config()
75 | with config.unlocked:
76 | config.env = functools.partial(
77 | tools.MockEnvironment, observ_shape=(2, 3), action_shape=(3,),
78 | min_duration=5, max_duration=25)
79 | config.max_length = 25
80 | config.steps = 100
81 | config.network = networks.recurrent_gaussian
82 | for score in train.train(config, env_processes=False):
83 | float(score)
84 |
85 | def test_no_crash_chunking(self):
86 | config = self._define_config()
87 | with config.unlocked:
88 | config.env = functools.partial(
89 | tools.MockEnvironment, observ_shape=(2, 3), action_shape=(3,),
90 | min_duration=5, max_duration=25)
91 | config.max_length = 25
92 | config.steps = 100
93 | config.network = networks.recurrent_gaussian
94 | config.chunk_length = 10
95 | config.batch_size = 5
96 | for score in train.train(config, env_processes=False):
97 | float(score)
98 |
99 | def _define_config(self):
100 | # Start from the example configuration.
101 | locals().update(configs.default())
102 | # pylint: disable=unused-variable
103 | # General
104 | algorithm = algorithms.PPO
105 | num_agents = 2
106 | update_every = 4
107 | use_gpu = False
108 | # Network
109 | policy_layers = 20, 10
110 | value_layers = 20, 10
111 | # Optimization
112 | update_epochs_policy = 2
113 | update_epochs_value = 2
114 | # pylint: enable=unused-variable
115 | return tools.AttrDict(locals())
116 |
117 |
118 | if __name__ == '__main__':
119 | tf.test.main()
120 |
--------------------------------------------------------------------------------
/agents/scripts/utility.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utilities for using reinforcement learning algorithms."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import logging
22 | import os
23 | import re
24 |
25 | import ruamel.yaml as yaml
26 | import tensorflow as tf
27 |
28 | from agents import tools
29 |
30 |
31 | def define_simulation_graph(batch_env, algo_cls, config):
32 | """Define the algorithm and environment interaction.
33 |
34 | Args:
35 | batch_env: In-graph environments object.
36 | algo_cls: Constructor of a batch algorithm.
37 | config: Configuration object for the algorithm.
38 |
39 | Returns:
40 | Object providing graph elements via attributes.
41 | """
42 | # pylint: disable=unused-variable
43 | step = tf.Variable(0, False, dtype=tf.int32, name='global_step')
44 | is_training = tf.placeholder(tf.bool, name='is_training')
45 | should_log = tf.placeholder(tf.bool, name='should_log')
46 | do_report = tf.placeholder(tf.bool, name='do_report')
47 | force_reset = tf.placeholder(tf.bool, name='force_reset')
48 | algo = algo_cls(batch_env, step, is_training, should_log, config)
49 | done, score, summary = tools.simulate(
50 | batch_env, algo, should_log, force_reset)
51 | message = 'Graph contains {} trainable variables.'
52 | tf.logging.info(message.format(tools.count_weights()))
53 | # pylint: enable=unused-variable
54 | return tools.AttrDict(locals())
55 |
56 |
57 | def define_batch_env(constructor, num_agents, env_processes):
58 | """Create environments and apply all desired wrappers.
59 |
60 | Args:
61 | constructor: Constructor of an OpenAI gym environment.
62 | num_agents: Number of environments to combine in the batch.
63 | env_processes: Whether to step environment in external processes.
64 |
65 | Returns:
66 | In-graph environments object.
67 | """
68 | with tf.variable_scope('environments'):
69 | if env_processes:
70 | envs = [
71 | tools.wrappers.ExternalProcess(constructor)
72 | for _ in range(num_agents)]
73 | else:
74 | envs = [constructor() for _ in range(num_agents)]
75 | batch_env = tools.BatchEnv(envs, blocking=not env_processes)
76 | batch_env = tools.InGraphBatchEnv(batch_env)
77 | return batch_env
78 |
79 |
80 | def define_saver(exclude=None):
81 | """Create a saver for the variables we want to checkpoint.
82 |
83 | Args:
84 | exclude: List of regexes to match variable names to exclude.
85 |
86 | Returns:
87 | Saver object.
88 | """
89 | variables = []
90 | exclude = exclude or []
91 | exclude = [re.compile(regex) for regex in exclude]
92 | for variable in tf.global_variables():
93 | if any(regex.match(variable.name) for regex in exclude):
94 | continue
95 | variables.append(variable)
96 | saver = tf.train.Saver(variables, keep_checkpoint_every_n_hours=5)
97 | return saver
98 |
99 |
100 | def initialize_variables(sess, saver, logdir, checkpoint=None, resume=None):
101 | """Initialize or restore variables from a checkpoint if available.
102 |
103 | Args:
104 | sess: Session to initialize variables in.
105 | saver: Saver to restore variables.
106 | logdir: Directory to search for checkpoints.
107 | checkpoint: Specify what checkpoint name to use; defaults to most recent.
108 | resume: Whether to expect recovering a checkpoint or starting a new run.
109 |
110 | Raises:
111 | ValueError: If resume expected but no log directory specified.
112 | RuntimeError: If no resume expected but a checkpoint was found.
113 | """
114 | sess.run(tf.group(
115 | tf.local_variables_initializer(),
116 | tf.global_variables_initializer()))
117 | if resume and not (logdir or checkpoint):
118 | raise ValueError('Need to specify logdir to resume a checkpoint.')
119 | if logdir:
120 | state = tf.train.get_checkpoint_state(logdir)
121 | if checkpoint:
122 | checkpoint = os.path.join(logdir, checkpoint)
123 | if not checkpoint and state and state.model_checkpoint_path:
124 | checkpoint = state.model_checkpoint_path
125 | if checkpoint and resume is False:
126 | message = 'Found unexpected checkpoint when starting a new run.'
127 | raise RuntimeError(message)
128 | if checkpoint:
129 | saver.restore(sess, checkpoint)
130 |
131 |
132 | def save_config(config, logdir=None):
133 | """Save a new configuration by name.
134 |
135 |   If a logging directory is specified, it will be created and the configuration
136 | will be stored there. Otherwise, a log message will be printed.
137 |
138 | Args:
139 | config: Configuration object.
140 | logdir: Location for writing summaries and checkpoints if specified.
141 |
142 | Returns:
143 | Configuration object.
144 | """
145 | if logdir:
146 | with config.unlocked:
147 | config.logdir = logdir
148 | message = 'Start a new run and write summaries and checkpoints to {}.'
149 | tf.logging.info(message.format(config.logdir))
150 | tf.gfile.MakeDirs(config.logdir)
151 | config_path = os.path.join(config.logdir, 'config.yaml')
152 | with tf.gfile.FastGFile(config_path, 'w') as file_:
153 | yaml.dump(config, file_, default_flow_style=False)
154 | else:
155 | message = (
156 | 'Start a new run without storing summaries and checkpoints since no '
157 | 'logging directory was specified.')
158 | tf.logging.info(message)
159 | return config
160 |
161 |
162 | def load_config(logdir):
163 | # pylint: disable=missing-raises-doc
164 | """Load a configuration from the log directory.
165 |
166 | Args:
167 | logdir: The logging directory containing the configuration file.
168 |
169 | Raises:
170 | IOError: The logging directory does not contain a configuration file.
171 |
172 | Returns:
173 | Configuration object.
174 | """
175 | config_path = logdir and os.path.join(logdir, 'config.yaml')
176 | if not config_path or not tf.gfile.Exists(config_path):
177 | message = (
178 | 'Cannot resume an existing run since the logging directory does not '
179 | 'contain a configuration file.')
180 | raise IOError(message)
181 | with tf.gfile.FastGFile(config_path, 'r') as file_:
182 | config = yaml.load(file_, Loader=yaml.Loader)
183 | message = 'Resume run and write summaries and checkpoints to {}.'
184 | tf.logging.info(message.format(config.logdir))
185 | return config
186 |
187 |
188 | def set_up_logging():
189 | """Configure the TensorFlow logger."""
190 | tf.logging.set_verbosity(tf.logging.INFO)
191 | logging.getLogger('tensorflow').propagate = False
192 |
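A minimal sketch of the configuration round trip implemented by save_config()
and load_config(); the log directory is a placeholder:

    from agents import tools
    from agents.scripts import configs, utility

    logdir = '/tmp/agents-config-demo'  # Placeholder log directory.
    config = tools.AttrDict(configs.pendulum())
    config = utility.save_config(config, logdir)  # Writes <logdir>/config.yaml.
    restored = utility.load_config(logdir)
    assert restored.env == config.env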
--------------------------------------------------------------------------------
/agents/scripts/visualize.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""Script to render videos of the Proximal Policy Optimization algorithm.
16 |
17 | Command line:
18 |
19 | python3 -m agents.scripts.visualize \
20 | --logdir=/path/to/logdir/- --outdir=/path/to/outdir/
21 | """
22 |
23 | from __future__ import absolute_import
24 | from __future__ import division
25 | from __future__ import print_function
26 |
27 | import os
28 |
29 | import gym
30 | import tensorflow as tf
31 |
32 | from agents import tools
33 | from agents.scripts import utility
34 |
35 |
36 | def _create_environment(config, outdir):
37 | """Constructor for an instance of the environment.
38 |
39 | Args:
40 | config: Object providing configurations via attributes.
41 | outdir: Directory to store videos in.
42 |
43 | Raises:
44 | NotImplementedError: For action spaces other than Box and Discrete.
45 |
46 | Returns:
47 | Wrapped OpenAI Gym environment.
48 | """
49 | if isinstance(config.env, str):
50 | env = gym.make(config.env)
51 | else:
52 | env = config.env()
53 | # Ensure that the environment has the specification attribute set as expected
54 | # by the monitor wrapper.
55 | if not hasattr(env, 'spec'):
56 | setattr(env, 'spec', getattr(env, 'spec', None))
57 | if config.max_length:
58 | env = tools.wrappers.LimitDuration(env, config.max_length)
59 | env = gym.wrappers.Monitor(
60 | env, outdir, lambda unused_episode_number: True)
61 | if isinstance(env.action_space, gym.spaces.Box):
62 | env = tools.wrappers.RangeNormalize(env)
63 | env = tools.wrappers.ClipAction(env)
64 | elif isinstance(env.action_space, gym.spaces.Discrete):
65 | env = tools.wrappers.RangeNormalize(env, action=False)
66 | else:
67 | message = "Unsupported action space '{}'".format(type(env.action_space))
68 | raise NotImplementedError(message)
69 | env = tools.wrappers.ConvertTo32Bit(env)
70 | env = tools.wrappers.CacheSpaces(env)
71 | return env
72 |
73 |
74 | def _define_loop(graph, eval_steps):
75 | """Create and configure an evaluation loop.
76 |
77 | Args:
78 | graph: Object providing graph elements via attributes.
79 | eval_steps: Number of evaluation steps per epoch.
80 |
81 | Returns:
82 | Loop object.
83 | """
84 | loop = tools.Loop(
85 | None, graph.step, graph.should_log, graph.do_report, graph.force_reset)
86 | loop.add_phase(
87 | 'eval', graph.done, graph.score, graph.summary, eval_steps,
88 | report_every=eval_steps,
89 | log_every=None,
90 | checkpoint_every=None,
91 | feed={graph.is_training: False})
92 | return loop
93 |
94 |
95 | def visualize(
96 | logdir, outdir, num_agents, num_episodes, checkpoint=None,
97 | env_processes=True):
98 | """Recover checkpoint and render videos from it.
99 |
100 | Args:
101 | logdir: Logging directory of the trained algorithm.
102 | outdir: Directory to store rendered videos in.
103 | num_agents: Number of environments to simulate in parallel.
104 | num_episodes: Total number of episodes to simulate.
105 | checkpoint: Checkpoint name to load; defaults to most recent.
106 | env_processes: Whether to step environments in separate processes.
107 | """
108 | config = utility.load_config(logdir)
109 | with tf.device('/cpu:0'):
110 | batch_env = utility.define_batch_env(
111 | lambda: _create_environment(config, outdir),
112 | num_agents, env_processes)
113 | graph = utility.define_simulation_graph(
114 | batch_env, config.algorithm, config)
115 | total_steps = num_episodes * config.max_length
116 | loop = _define_loop(graph, total_steps)
117 | saver = utility.define_saver(
118 | exclude=(r'.*_temporary.*', r'global_step'))
119 | sess_config = tf.ConfigProto(allow_soft_placement=True)
120 | sess_config.gpu_options.allow_growth = True
121 | with tf.Session(config=sess_config) as sess:
122 | utility.initialize_variables(
123 | sess, saver, config.logdir, checkpoint, resume=True)
124 | for unused_score in loop.run(sess, saver, total_steps):
125 | pass
126 | batch_env.close()
127 |
128 |
129 | def main(_):
130 | """Load a trained algorithm and render videos."""
131 | utility.set_up_logging()
132 | if not FLAGS.logdir or not FLAGS.outdir:
133 | raise KeyError('You must specify logging and output directories.')
134 | FLAGS.logdir = os.path.expanduser(FLAGS.logdir)
135 | FLAGS.outdir = os.path.expanduser(FLAGS.outdir)
136 | visualize(
137 | FLAGS.logdir, FLAGS.outdir, FLAGS.num_agents, FLAGS.num_episodes,
138 | FLAGS.checkpoint, FLAGS.env_processes)
139 |
140 |
141 | if __name__ == '__main__':
142 | FLAGS = tf.app.flags.FLAGS
143 | tf.app.flags.DEFINE_string(
144 | 'logdir', None,
145 | 'Directory containing the checkpoints of a training run.')
146 | tf.app.flags.DEFINE_string(
147 | 'outdir', None,
148 | 'Local directory for storing rendered videos and monitor output.')
149 | tf.app.flags.DEFINE_string(
150 | 'checkpoint', None,
151 | 'Checkpoint name to load; defaults to most recent.')
152 | tf.app.flags.DEFINE_integer(
153 | 'num_agents', 1,
154 | 'How many environments to step in parallel.')
155 | tf.app.flags.DEFINE_integer(
156 | 'num_episodes', 5,
157 | 'Minimum number of episodes to render.')
158 | tf.app.flags.DEFINE_boolean(
159 | 'env_processes', True,
160 | 'Step environments in separate processes to circumvent the GIL.')
161 | tf.app.run()
162 |
--------------------------------------------------------------------------------
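A minimal sketch of calling the visualize() entry point above from Python rather than the command line; it is illustrative only, and both paths are placeholders that assume the logdir was produced by agents.scripts.train and therefore contains the saved config and checkpoints.

import os

from agents.scripts import visualize

visualize.visualize(
    logdir=os.path.expanduser('~/logdir/my-run'),   # trained run to load
    outdir=os.path.expanduser('~/videos/my-run'),   # where Monitor saves videos
    num_agents=1,        # environments stepped in parallel
    num_episodes=5,      # total episodes to record
    checkpoint=None,     # None loads the most recent checkpoint
    env_processes=True)  # step environments in separate processes
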
/agents/tools/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tools for reinforcement learning."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from . import nested
22 | from . import wrappers
23 | from .attr_dict import AttrDict
24 | from .batch_env import BatchEnv
25 | from .count_weights import count_weights
26 | from .in_graph_batch_env import InGraphBatchEnv
27 | from .in_graph_env import InGraphEnv
28 | from .loop import Loop
29 | from .mock_algorithm import MockAlgorithm
30 | from .mock_environment import MockEnvironment
31 | from .simulate import simulate
32 | from .streaming_mean import StreamingMean
33 |
--------------------------------------------------------------------------------
/agents/tools/attr_dict.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Wrap a dictionary to access keys as attributes."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import contextlib
22 |
23 |
24 | class AttrDict(dict):
25 | """Wrap a dictionary to access keys as attributes."""
26 |
27 | def __init__(self, *args, **kwargs):
28 | super(AttrDict, self).__init__(*args, **kwargs)
29 | super(AttrDict, self).__setattr__('_mutable', False)
30 |
31 | def __getattr__(self, key):
32 | # Do not provide None for unimplemented magic attributes.
33 | if key.startswith('__'):
34 | raise AttributeError
35 | return self.get(key, None)
36 |
37 | def __setattr__(self, key, value):
38 | if not self._mutable:
39 | message = "Cannot set attribute '{}'.".format(key)
40 | message += " Use 'with obj.unlocked:' scope to set attributes."
41 | raise RuntimeError(message)
42 | if key.startswith('__'):
43 | raise AttributeError("Cannot set magic attribute '{}'".format(key))
44 | self[key] = value
45 |
46 | @property
47 | @contextlib.contextmanager
48 | def unlocked(self):
49 | super(AttrDict, self).__setattr__('_mutable', True)
50 | yield
51 | super(AttrDict, self).__setattr__('_mutable', False)
52 |
53 | def copy(self):
54 | return type(self)(super(AttrDict, self).copy())
55 |
--------------------------------------------------------------------------------
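A minimal usage sketch, not part of the repository, for the AttrDict class above: attribute access, the None default for missing keys, and the unlocked scope required for mutation.

from agents.tools import AttrDict

config = AttrDict(learning_rate=1e-4, update_every=30)
print(config.learning_rate)   # 0.0001, same as config['learning_rate']
print(config.batch_size)      # None; missing keys default to None
with config.unlocked:         # writes outside this scope raise RuntimeError
  config.batch_size = 20
print(config.batch_size)      # 20
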
/agents/tools/attr_dict_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for the attribute dictionary."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 |
23 | from agents.tools import attr_dict
24 |
25 |
26 | class AttrDictTest(tf.test.TestCase):
27 |
28 | def test_construct_from_dict(self):
29 | initial = dict(foo=13, bar=42)
30 | obj = attr_dict.AttrDict(initial)
31 | self.assertEqual(13, obj.foo)
32 | self.assertEqual(42, obj.bar)
33 |
34 | def test_construct_from_kwargs(self):
35 | obj = attr_dict.AttrDict(foo=13, bar=42)
36 | self.assertEqual(13, obj.foo)
37 | self.assertEqual(42, obj.bar)
38 |
39 | def test_has_attribute(self):
40 | obj = attr_dict.AttrDict(foo=13)
41 | self.assertTrue('foo' in obj)
42 | self.assertFalse('bar' in obj)
43 |
44 | def test_access_default(self):
45 | obj = attr_dict.AttrDict()
46 | self.assertEqual(None, obj.foo)
47 |
48 | def test_access_magic(self):
49 | obj = attr_dict.AttrDict()
50 | with self.assertRaises(AttributeError):
51 | obj.__getstate__ # pylint: disable=pointless-statement
52 |
53 | def test_immutable_create(self):
54 | obj = attr_dict.AttrDict()
55 | with self.assertRaises(RuntimeError):
56 | obj.foo = 42
57 |
58 | def test_immutable_modify(self):
59 | obj = attr_dict.AttrDict(foo=13)
60 | with self.assertRaises(RuntimeError):
61 | obj.foo = 42
62 |
63 | def test_immutable_unlocked(self):
64 | obj = attr_dict.AttrDict()
65 | with obj.unlocked:
66 | obj.foo = 42
67 | self.assertEqual(42, obj.foo)
68 |
69 |
70 | if __name__ == '__main__':
71 | tf.test.main()
72 |
--------------------------------------------------------------------------------
/agents/tools/batch_env.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Combine multiple environments to step them in batch."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import numpy as np
22 |
23 |
24 | class BatchEnv(object):
25 | """Combine multiple environments to step them in batch."""
26 |
27 | def __init__(self, envs, blocking):
28 | """Combine multiple environments to step them in batch.
29 |
30 | To step environments in parallel, environments must support a
31 | `blocking=False` argument to their step and reset functions that makes them
32 | return callables instead, which can be called later to receive the result.
33 |
34 | Args:
35 | envs: List of environments.
36 | blocking: Step environments one after another rather than in parallel.
37 |
38 | Raises:
39 | ValueError: Environments have different observation or action spaces.
40 | """
41 | self._envs = envs
42 | self._blocking = blocking
43 | observ_space = self._envs[0].observation_space
44 | if not all(env.observation_space == observ_space for env in self._envs):
45 | raise ValueError('All environments must use the same observation space.')
46 | action_space = self._envs[0].action_space
47 | if not all(env.action_space == action_space for env in self._envs):
48 | raise ValueError('All environments must use the same action space.')
49 |
50 | def __len__(self):
51 | """Number of combined environments."""
52 | return len(self._envs)
53 |
54 | def __getitem__(self, index):
55 | """Access an underlying environment by index."""
56 | return self._envs[index]
57 |
58 | def __getattr__(self, name):
59 | """Forward unimplemented attributes to one of the original environments.
60 |
61 | Args:
62 | name: Attribute that was accessed.
63 |
64 | Returns:
65 | Value behind the attribute name in one of the wrapped environments.
66 | """
67 | return getattr(self._envs[0], name)
68 |
69 | def step(self, actions):
70 | """Forward a batch of actions to the wrapped environments.
71 |
72 | Args:
73 | actions: Batched action to apply to the environment.
74 |
75 | Raises:
76 | ValueError: Invalid actions.
77 |
78 | Returns:
79 | Batch of observations, rewards, and done flags.
80 | """
81 | for index, (env, action) in enumerate(zip(self._envs, actions)):
82 | if not env.action_space.contains(action):
83 | message = 'Invalid action at index {}: {}'
84 | raise ValueError(message.format(index, action))
85 | if self._blocking:
86 | transitions = [
87 | env.step(action)
88 | for env, action in zip(self._envs, actions)]
89 | else:
90 | transitions = [
91 | env.step(action, blocking=False)
92 | for env, action in zip(self._envs, actions)]
93 | transitions = [transition() for transition in transitions]
94 | observs, rewards, dones, infos = zip(*transitions)
95 | observ = np.stack(observs)
96 | reward = np.stack(rewards)
97 | done = np.stack(dones)
98 | info = tuple(infos)
99 | return observ, reward, done, info
100 |
101 | def reset(self, indices=None):
102 | """Reset the environment and convert the resulting observation.
103 |
104 | Args:
105 | indices: The batch indices of environments to reset; defaults to all.
106 |
107 | Returns:
108 | Batch of observations.
109 | """
110 | if indices is None:
111 | indices = np.arange(len(self._envs))
112 | if self._blocking:
113 | observs = [self._envs[index].reset() for index in indices]
114 | else:
115 | observs = [self._envs[index].reset(blocking=False) for index in indices]
116 | observs = [observ() for observ in observs]
117 | observ = np.stack(observs)
118 | return observ
119 |
120 | def close(self):
121 | """Send close messages to the external process and join them."""
122 | for env in self._envs:
123 | if hasattr(env, 'close'):
124 | env.close()
125 |
--------------------------------------------------------------------------------
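A minimal sketch, not part of the repository, of combining two Gym environments into the BatchEnv above and stepping them with a batch of sampled actions; blocking=True keeps the example free of process management, and 'CartPole-v0' is just a placeholder environment.

import gym

from agents.tools import BatchEnv

envs = [gym.make('CartPole-v0') for _ in range(2)]
batch_env = BatchEnv(envs, blocking=True)
observ = batch_env.reset()                       # stacked observations, shape (2, 4)
actions = [env.action_space.sample() for env in envs]
observ, reward, done, info = batch_env.step(actions)
print(observ.shape, reward, done)
batch_env.close()
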
/agents/tools/count_weights.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Count learnable parameters."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import re
22 |
23 | import numpy as np
24 | import tensorflow as tf
25 |
26 |
27 | def count_weights(scope=None, exclude=None, graph=None):
28 | """Count learnable parameters.
29 |
30 | Args:
31 | scope: Restrict the count to a variable scope.
32 | exclude: Regex to match variable names to exclude.
33 | graph: Operate on a graph other than the current default graph.
34 |
35 | Returns:
36 | Number of learnable parameters as integer.
37 | """
38 | if scope:
39 | scope = scope if scope.endswith('/') else scope + '/'
40 | graph = graph or tf.get_default_graph()
41 | vars_ = graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
42 | if scope:
43 | vars_ = [var for var in vars_ if var.name.startswith(scope)]
44 | if exclude:
45 | exclude = re.compile(exclude)
46 | vars_ = [var for var in vars_ if not exclude.match(var.name)]
47 | shapes = [var.get_shape().as_list() for var in vars_]
48 | return int(sum(np.prod(shape) for shape in shapes))
49 |
--------------------------------------------------------------------------------
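A minimal sketch, not part of the repository, of count_weights() above on a small two-layer network built with tf.layers: counting all trainable parameters, restricting to a single variable scope, and excluding variables matched by a regex.

import tensorflow as tf

from agents.tools import count_weights

with tf.variable_scope('policy'):
  hidden = tf.layers.dense(tf.zeros([1, 4]), 32)   # kernel 4*32 + bias 32 = 160
with tf.variable_scope('value'):
  value = tf.layers.dense(hidden, 1)               # kernel 32*1 + bias 1 = 33

print(count_weights())                        # 193, all trainable variables
print(count_weights(scope='policy'))          # 160, only the policy scope
print(count_weights(exclude=r'.*kernel.*'))   # 33, only the biases remain
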
/agents/tools/count_weights_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for the weight counting utility."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 |
23 | from agents.tools import count_weights
24 |
25 |
26 | class CountWeightsTest(tf.test.TestCase):
27 |
28 | def test_count_trainable(self):
29 | tf.Variable(tf.zeros((5, 3)), trainable=True)
30 | tf.Variable(tf.zeros((1, 1)), trainable=True)
31 | tf.Variable(tf.zeros((5,)), trainable=True)
32 | self.assertEqual(15 + 1 + 5, count_weights())
33 |
34 | def test_ignore_non_trainable(self):
35 | tf.Variable(tf.zeros((5, 3)), trainable=False)
36 | tf.Variable(tf.zeros((1, 1)), trainable=False)
37 | tf.Variable(tf.zeros((5,)), trainable=False)
38 | self.assertEqual(0, count_weights())
39 |
40 | def test_trainable_and_non_trainable(self):
41 | tf.Variable(tf.zeros((5, 3)), trainable=True)
42 | tf.Variable(tf.zeros((8, 2)), trainable=False)
43 | tf.Variable(tf.zeros((1, 1)), trainable=True)
44 | tf.Variable(tf.zeros((5,)), trainable=True)
45 | tf.Variable(tf.zeros((3, 1)), trainable=False)
46 | self.assertEqual(15 + 1 + 5, count_weights())
47 |
48 | def test_include_scopes(self):
49 | tf.Variable(tf.zeros((3, 2)), trainable=True)
50 | with tf.variable_scope('foo'):
51 | tf.Variable(tf.zeros((5, 2)), trainable=True)
52 | self.assertEqual(6 + 10, count_weights())
53 |
54 | def test_restrict_scope(self):
55 | tf.Variable(tf.zeros((3, 2)), trainable=True)
56 | with tf.variable_scope('foo'):
57 | tf.Variable(tf.zeros((5, 2)), trainable=True)
58 | with tf.variable_scope('bar'):
59 | tf.Variable(tf.zeros((1, 2)), trainable=True)
60 | self.assertEqual(10 + 2, count_weights('foo'))
61 |
62 | def test_restrict_nested_scope(self):
63 | tf.Variable(tf.zeros((3, 2)), trainable=True)
64 | with tf.variable_scope('foo'):
65 | tf.Variable(tf.zeros((5, 2)), trainable=True)
66 | with tf.variable_scope('bar'):
67 | tf.Variable(tf.zeros((1, 2)), trainable=True)
68 | self.assertEqual(2, count_weights('foo/bar'))
69 |
70 | def test_restrict_invalid_scope(self):
71 | tf.Variable(tf.zeros((3, 2)), trainable=True)
72 | with tf.variable_scope('foo'):
73 | tf.Variable(tf.zeros((5, 2)), trainable=True)
74 | with tf.variable_scope('bar'):
75 | tf.Variable(tf.zeros((1, 2)), trainable=True)
76 | self.assertEqual(0, count_weights('bar'))
77 |
78 | def test_exclude_by_regex(self):
79 | tf.Variable(tf.zeros((3, 2)), trainable=True)
80 | with tf.variable_scope('foo'):
81 | tf.Variable(tf.zeros((5, 2)), trainable=True)
82 | with tf.variable_scope('bar'):
83 | tf.Variable(tf.zeros((1, 2)), trainable=True)
84 | self.assertEqual(0, count_weights(exclude=r'.*'))
85 | self.assertEqual(6, count_weights(exclude=r'(^|/)foo/.*'))
86 | self.assertEqual(16, count_weights(exclude=r'.*/bar/.*'))
87 |
88 | def test_non_default_graph(self):
89 | graph = tf.Graph()
90 | with graph.as_default():
91 | tf.Variable(tf.zeros((5, 3)), trainable=True)
92 | tf.Variable(tf.zeros((8, 2)), trainable=False)
93 | self.assertNotEqual(graph, tf.get_default_graph())
94 | self.assertEqual(15, count_weights(graph=graph))
95 |
96 |
97 | if __name__ == '__main__':
98 | tf.test.main()
99 |
--------------------------------------------------------------------------------
/agents/tools/in_graph_batch_env.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Batch of environments inside the TensorFlow graph."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import gym
22 | import tensorflow as tf
23 |
24 |
25 | class InGraphBatchEnv(object):
26 | """Batch of environments inside the TensorFlow graph.
27 |
28 | The batch of environments will be stepped and reset inside the graph using
29 | tf.py_func(). The current batch of observations, actions, rewards, and done
30 | flags is held in corresponding variables.
31 | """
32 |
33 | def __init__(self, batch_env):
34 | """Batch of environments inside the TensorFlow graph.
35 |
36 | Args:
37 | batch_env: Batch environment.
38 | """
39 | self._batch_env = batch_env
40 | batch_dims = (len(self._batch_env),)
41 | observ_shape = self._parse_shape(self._batch_env.observation_space)
42 | observ_dtype = self._parse_dtype(self._batch_env.observation_space)
43 | action_shape = self._parse_shape(self._batch_env.action_space)
44 | action_dtype = self._parse_dtype(self._batch_env.action_space)
45 | with tf.variable_scope('env_temporary'):
46 | self._observ = tf.Variable(
47 | lambda: tf.zeros(batch_dims + observ_shape, observ_dtype),
48 | name='observ', trainable=False)
49 | self._action = tf.Variable(
50 | lambda: tf.zeros(batch_dims + action_shape, action_dtype),
51 | name='action', trainable=False)
52 | self._reward = tf.Variable(
53 | lambda: tf.zeros(batch_dims, tf.float32),
54 | name='reward', trainable=False)
55 | self._done = tf.Variable(
56 | lambda: tf.cast(tf.ones(batch_dims), tf.bool),
57 | name='done', trainable=False)
58 |
59 | def __getattr__(self, name):
60 | """Forward unimplemented attributes to one of the original environments.
61 |
62 | Args:
63 | name: Attribute that was accessed.
64 |
65 | Returns:
66 | Value behind the attribute name in one of the original environments.
67 | """
68 | return getattr(self._batch_env, name)
69 |
70 | def __len__(self):
71 | """Number of combined environments."""
72 | return len(self._batch_env)
73 |
74 | def __getitem__(self, index):
75 | """Access an underlying environment by index."""
76 | return self._batch_env[index]
77 |
78 | def simulate(self, action):
79 | """Step the batch of environments.
80 |
81 | The results of the step can be accessed from the variables defined below.
82 |
83 | Args:
84 | action: Tensor holding the batch of actions to apply.
85 |
86 | Returns:
87 | Operation.
88 | """
89 | with tf.name_scope('environment/simulate'):
90 | if action.dtype in (tf.float16, tf.float32, tf.float64):
91 | action = tf.check_numerics(action, 'action')
92 | observ_dtype = self._parse_dtype(self._batch_env.observation_space)
93 | observ, reward, done = tf.py_func(
94 | lambda a: self._batch_env.step(a)[:3], [action],
95 | [observ_dtype, tf.float32, tf.bool], name='step')
96 | observ = tf.check_numerics(observ, 'observ')
97 | reward = tf.check_numerics(reward, 'reward')
98 | return tf.group(
99 | self._observ.assign(observ),
100 | self._action.assign(action),
101 | self._reward.assign(reward),
102 | self._done.assign(done))
103 |
104 | def reset(self, indices=None):
105 | """Reset the batch of environments.
106 |
107 | Args:
108 | indices: The batch indices of the environments to reset; defaults to all.
109 |
110 | Returns:
111 | Batch tensor of the new observations.
112 | """
113 | if indices is None:
114 | indices = tf.range(len(self._batch_env))
115 | observ_dtype = self._parse_dtype(self._batch_env.observation_space)
116 | observ = tf.py_func(
117 | self._batch_env.reset, [indices], observ_dtype, name='reset')
118 | observ = tf.check_numerics(observ, 'observ')
119 | reward = tf.zeros_like(indices, tf.float32)
120 | done = tf.zeros_like(indices, tf.bool)
121 | with tf.control_dependencies([
122 | tf.scatter_update(self._observ, indices, observ),
123 | tf.scatter_update(self._reward, indices, reward),
124 | tf.scatter_update(self._done, indices, done)]):
125 | return tf.identity(observ)
126 |
127 | @property
128 | def observ(self):
129 | """Access the variable holding the current observation."""
130 | return self._observ
131 |
132 | @property
133 | def action(self):
134 | """Access the variable holding the last received action."""
135 | return self._action
136 |
137 | @property
138 | def reward(self):
139 | """Access the variable holding the current reward."""
140 | return self._reward
141 |
142 | @property
143 | def done(self):
144 | """Access the variable indicating whether the episode is done."""
145 | return self._done
146 |
147 | def close(self):
148 | """Send close messages to the external process and join them."""
149 | self._batch_env.close()
150 |
151 | def _parse_shape(self, space):
152 | """Get a tensor shape from a OpenAI Gym space.
153 |
154 | Args:
155 | space: Gym space.
156 |
157 | Raises:
158 | NotImplementedError: For spaces other than Box and Discrete.
159 |
160 | Returns:
161 | Shape tuple.
162 | """
163 | if isinstance(space, gym.spaces.Discrete):
164 | return ()
165 | if isinstance(space, gym.spaces.Box):
166 | return space.shape
167 | raise NotImplementedError()
168 |
169 | def _parse_dtype(self, space):
170 | """Get a tensor dtype from a OpenAI Gym space.
171 |
172 | Args:
173 | space: Gym space.
174 |
175 | Raises:
176 | NotImplementedError: For spaces other than Box and Discrete.
177 |
178 | Returns:
179 | TensorFlow data type.
180 | """
181 | if isinstance(space, gym.spaces.Discrete):
182 | return tf.int32
183 | if isinstance(space, gym.spaces.Box):
184 | return tf.float32
185 | raise NotImplementedError()
186 |
--------------------------------------------------------------------------------
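A minimal sketch, not part of the repository, of wrapping a BatchEnv in the InGraphBatchEnv above so that stepping runs through tf.py_func inside the graph; it reuses the repository's MockEnvironment and ConvertTo32Bit wrapper to stay self-contained, and the shapes and durations are arbitrary.

import tensorflow as tf

from agents import tools

envs = [tools.wrappers.ConvertTo32Bit(tools.MockEnvironment(
    observ_shape=(3,), action_shape=(2,), min_duration=5, max_duration=5))
    for _ in range(4)]
batch_env = tools.InGraphBatchEnv(tools.BatchEnv(envs, blocking=True))

reset_op = batch_env.reset()
action = 0.5 * tf.ones([len(batch_env), 2])   # constant in-range action
step_op = batch_env.simulate(action)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(reset_op)
  sess.run(step_op)
  print(sess.run([batch_env.observ, batch_env.reward, batch_env.done]))
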
/agents/tools/in_graph_env.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Put an OpenAI Gym environment into the TensorFlow graph."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import gym
22 | import tensorflow as tf
23 |
24 |
25 | class InGraphEnv(object):
26 | """Put an OpenAI Gym environment into the TensorFlow graph.
27 |
28 | The environment will be stepped and reset inside the graph using
29 | tf.py_func(). The current observation, action, reward, and done flag are
30 | held in corresponding variables.
31 | """
32 |
33 | def __init__(self, env):
34 | """Put an OpenAI Gym environment into the TensorFlow graph.
35 |
36 | Args:
37 | env: OpenAI Gym environment.
38 | """
39 | self._env = env
40 | observ_shape = self._parse_shape(self._env.observation_space)
41 | observ_dtype = self._parse_dtype(self._env.observation_space)
42 | action_shape = self._parse_shape(self._env.action_space)
43 | action_dtype = self._parse_dtype(self._env.action_space)
44 | with tf.name_scope('environment'):
45 | self._observ = tf.Variable(
46 | tf.zeros(observ_shape, observ_dtype), name='observ', trainable=False)
47 | self._action = tf.Variable(
48 | tf.zeros(action_shape, action_dtype), name='action', trainable=False)
49 | self._reward = tf.Variable(
50 | 0.0, dtype=tf.float32, name='reward', trainable=False)
51 | self._done = tf.Variable(
52 | True, dtype=tf.bool, name='done', trainable=False)
53 | self._step = tf.Variable(
54 | 0, dtype=tf.int32, name='step', trainable=False)
55 |
56 | def __getattr__(self, name):
57 | """Forward unimplemented attributes to the original environment.
58 |
59 | Args:
60 | name: Attribute that was accessed.
61 |
62 | Returns:
63 | Value behind the attribute name in the wrapped environment.
64 | """
65 | return getattr(self._env, name)
66 |
67 | def simulate(self, action):
68 | """Step the environment.
69 |
70 | The result of the step can be accessed from the variables defined below.
71 |
72 | Args:
73 | action: Tensor holding the action to apply.
74 |
75 | Returns:
76 | Operation.
77 | """
78 | with tf.name_scope('environment/simulate'):
79 | if action.dtype in (tf.float16, tf.float32, tf.float64):
80 | action = tf.check_numerics(action, 'action')
81 | observ_dtype = self._parse_dtype(self._env.observation_space)
82 | observ, reward, done = tf.py_func(
83 | lambda a: self._env.step(a)[:3], [action],
84 | [observ_dtype, tf.float32, tf.bool], name='step')
85 | observ = tf.check_numerics(observ, 'observ')
86 | reward = tf.check_numerics(reward, 'reward')
87 | return tf.group(
88 | self._observ.assign(observ),
89 | self._action.assign(action),
90 | self._reward.assign(reward),
91 | self._done.assign(done),
92 | self._step.assign_add(1))
93 |
94 | def reset(self):
95 | """Reset the environment.
96 |
97 | Returns:
98 | Tensor of the current observation.
99 | """
100 | observ_dtype = self._parse_dtype(self._env.observation_space)
101 | observ = tf.py_func(self._env.reset, [], observ_dtype, name='reset')
102 | observ = tf.check_numerics(observ, 'observ')
103 | with tf.control_dependencies([
104 | self._observ.assign(observ),
105 | self._reward.assign(0),
106 | self._done.assign(False)]):
107 | return tf.identity(observ)
108 |
109 | @property
110 | def observ(self):
111 | """Access the variable holding the current observation."""
112 | return self._observ
113 |
114 | @property
115 | def action(self):
116 | """Access the variable holding the last received action."""
117 | return self._action
118 |
119 | @property
120 | def reward(self):
121 | """Access the variable holding the current reward."""
122 | return self._reward
123 |
124 | @property
125 | def done(self):
126 | """Access the variable indicating whether the episode is done."""
127 | return self._done
128 |
129 | @property
130 | def step(self):
131 | """Access the variable containing total steps of this environment."""
132 | return self._step
133 |
134 | def _parse_shape(self, space):
135 | """Get a tensor shape from a OpenAI Gym space.
136 |
137 | Args:
138 | space: Gym space.
139 |
140 | Raises:
141 | NotImplementedError: For spaces other than Box and Discrete.
142 |
143 | Returns:
144 | Shape tuple.
145 | """
146 | if isinstance(space, gym.spaces.Discrete):
147 | return ()
148 | if isinstance(space, gym.spaces.Box):
149 | return space.shape
150 | raise NotImplementedError()
151 |
152 | def _parse_dtype(self, space):
153 | """Get a tensor dtype from a OpenAI Gym space.
154 |
155 | Args:
156 | space: Gym space.
157 |
158 | Raises:
159 | NotImplementedError: For spaces other than Box and Discrete.
160 |
161 | Returns:
162 | TensorFlow data type.
163 | """
164 | if isinstance(space, gym.spaces.Discrete):
165 | return tf.int32
166 | if isinstance(space, gym.spaces.Box):
167 | return tf.float32
168 | raise NotImplementedError()
169 |
--------------------------------------------------------------------------------
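A minimal sketch, not part of the repository, of putting a single Gym environment into the graph with the InGraphEnv above; ConvertTo32Bit keeps the numpy dtypes in line with the float32/int32 tensors the class declares, and 'CartPole-v0' is a placeholder.

import gym
import tensorflow as tf

from agents import tools

env = tools.wrappers.ConvertTo32Bit(gym.make('CartPole-v0'))
in_graph = tools.InGraphEnv(env)

reset_op = in_graph.reset()
step_op = in_graph.simulate(tf.constant(1))   # discrete action: push right

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(reset_op)
  sess.run(step_op)
  print(sess.run(
      [in_graph.observ, in_graph.reward, in_graph.done, in_graph.step]))
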
/agents/tools/loop.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Execute operations in a loop and coordinate logging and checkpoints."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import os
23 |
24 | import tensorflow as tf
25 |
26 | from agents.tools import streaming_mean
27 |
28 |
29 | _Phase = collections.namedtuple(
30 | 'Phase',
31 | 'name, writer, op, batch, steps, feed, report_every, log_every,'
32 | 'checkpoint_every')
33 |
34 |
35 | class Loop(object):
36 | """Execute operations in a loop and coordinate logging and checkpoints.
37 |
38 | Supports multiple phases that define their own operations to run, and
39 | intervals for reporting scores, logging summaries, and storing checkpoints.
40 | All class state is stored in-graph to properly recover from checkpoints.
41 | """
42 |
43 | def __init__(self, logdir, step=None, log=None, report=None, reset=None):
44 | """Execute operations in a loop and coordinate logging and checkpoints.
45 |
46 | The step, log, report, and reset arguments will be created if not
47 | provided. Reset is used to indicate switching to a new phase, so that the
48 | model can start a new computation in case its computation is split over
49 | multiple training steps.
50 |
51 | Args:
52 | logdir: Will contain checkpoints and summaries for each phase.
53 | step: Variable of the global step (optional).
54 | log: Tensor indicating to the model to compute summary tensors.
55 | report: Tensor indicating to the loop to report the current mean score.
56 | reset: Tensor indicating to the model to start a new computation.
57 | """
58 | self._logdir = logdir
59 | self._step = (
60 | tf.Variable(0, False, name='global_step') if step is None else step)
61 | self._log = tf.placeholder(tf.bool) if log is None else log
62 | self._report = tf.placeholder(tf.bool) if report is None else report
63 | self._reset = tf.placeholder(tf.bool) if reset is None else reset
64 | self._phases = []
65 |
66 | def add_phase(
67 | self, name, done, score, summary, steps,
68 | report_every=None, log_every=None, checkpoint_every=None, feed=None):
69 | """Add a phase to the loop protocol.
70 |
71 | If the model breaks a long computation into multiple steps, the done tensor
72 | indicates whether the current score should be added to the mean counter.
73 | For example, in reinforcement learning we only have a valid score at the
74 | end of the episode.
75 |
76 | Score and done tensors can either be scalars or vectors, to support
77 | single and batched computations.
78 |
79 | Args:
80 | name: Name for the phase, used for the summary writer.
81 | done: Tensor indicating whether current score can be used.
82 | score: Tensor holding the current, possibly intermediate, score.
83 | summary: Tensor holding summary string to write if not an empty string.
84 | steps: Duration of the phase in steps.
85 | report_every: Yield mean score every this number of steps.
86 | log_every: Request summaries via `log` tensor every this number of steps.
87 | checkpoint_every: Write checkpoint every this number of steps.
88 | feed: Additional feed dictionary for the session run call.
89 |
90 | Raises:
91 | ValueError: Unknown rank for done or score tensors.
92 | """
93 | done = tf.convert_to_tensor(done, tf.bool)
94 | score = tf.convert_to_tensor(score, tf.float32)
95 | summary = tf.convert_to_tensor(summary, tf.string)
96 | feed = feed or {}
97 | if done.shape.ndims is None or score.shape.ndims is None:
98 | raise ValueError("Rank of 'done' and 'score' tensors must be known.")
99 | writer = self._logdir and tf.summary.FileWriter(
100 | os.path.join(self._logdir, name), tf.get_default_graph(),
101 | flush_secs=60)
102 | op = self._define_step(done, score, summary)
103 | batch = 1 if score.shape.ndims == 0 else score.shape[0].value
104 | self._phases.append(_Phase(
105 | name, writer, op, batch, int(steps), feed, report_every,
106 | log_every, checkpoint_every))
107 |
108 | def run(self, sess, saver, max_step=None):
109 | """Run the loop schedule for a specified number of steps.
110 |
111 | Call the operation of the current phase until the global step reaches the
112 | specified maximum step. Phases are repeated over and over in the order they
113 | were added.
114 |
115 | Args:
116 | sess: Session to use to run the phase operation.
117 | saver: Saver used for checkpointing.
118 | max_step: Run the operations until the step reaches this limit.
119 |
120 | Yields:
121 | Reported mean scores.
122 | """
123 | global_step = sess.run(self._step)
124 | steps_made = 1
125 | while True:
126 | if max_step and global_step >= max_step:
127 | break
128 | phase, epoch, steps_in = self._find_current_phase(global_step)
129 | phase_step = epoch * phase.steps + steps_in
130 | if steps_in % phase.steps < steps_made:
131 | message = '\n' + ('-' * 50) + '\n'
132 | message += 'Phase {} (phase step {}, global step {}).'
133 | tf.logging.info(message.format(phase.name, phase_step, global_step))
134 | # Populate bookkeeping tensors.
135 | phase.feed[self._reset] = (steps_in < steps_made)
136 | phase.feed[self._log] = (
137 | phase.writer and
138 | self._is_every_steps(phase_step, phase.batch, phase.log_every))
139 | phase.feed[self._report] = (
140 | self._is_every_steps(phase_step, phase.batch, phase.report_every))
141 | summary, mean_score, global_step, steps_made = sess.run(
142 | phase.op, phase.feed)
143 | if self._is_every_steps(phase_step, phase.batch, phase.checkpoint_every):
144 | self._store_checkpoint(sess, saver, global_step)
145 | if self._is_every_steps(phase_step, phase.batch, phase.report_every):
146 | yield mean_score
147 | if summary and phase.writer:
148 | # We want smaller phases to catch up at the beginning of each epoch so
149 | # that their graphs are aligned.
150 | longest_phase = max(phase.steps for phase in self._phases)
151 | summary_step = epoch * longest_phase + steps_in
152 | phase.writer.add_summary(summary, summary_step)
153 |
154 | def _is_every_steps(self, phase_step, batch, every):
155 | """Determine whether a periodic event should happen at this step.
156 |
157 | Args:
158 | phase_step: The incrementing step.
159 | batch: The number of steps progressed at once.
160 | every: The interval of the period.
161 |
162 | Returns:
163 | Boolean of whether the event should happen.
164 | """
165 | if not every:
166 | return False
167 | covered_steps = range(phase_step, phase_step + batch)
168 | return any((step + 1) % every == 0 for step in covered_steps)
169 |
170 | def _find_current_phase(self, global_step):
171 | """Determine the current phase based on the global step.
172 |
173 | This ensures that the correct phase continues after restoring checkpoints.
174 |
175 | Args:
176 | global_step: The global number of steps performed across all phases.
177 |
178 | Returns:
179 | Tuple of phase object, epoch number, and phase steps within the epoch.
180 | """
181 | epoch_size = sum(phase.steps for phase in self._phases)
182 | epoch = int(global_step // epoch_size)
183 | steps_in = global_step % epoch_size
184 | for phase in self._phases:
185 | if steps_in < phase.steps:
186 | return phase, epoch, steps_in
187 | steps_in -= phase.steps
188 |
189 | def _define_step(self, done, score, summary):
190 | """Combine operations of a phase.
191 |
192 | Keeps track of the mean score and when to report it.
193 |
194 | Args:
195 | done: Tensor indicating whether current score can be used.
196 | score: Tensor holding the current, possibly intermediate, score.
197 | summary: Tensor holding summary string to write if not an empty string.
198 |
199 | Returns:
200 | Tuple of summary tensor, mean score, new global step, and number of steps
201 | made. The mean score is zero for non-reporting steps.
202 | """
203 | if done.shape.ndims == 0:
204 | done = done[None]
205 | if score.shape.ndims == 0:
206 | score = score[None]
207 | score_mean = streaming_mean.StreamingMean((), tf.float32)
208 | with tf.control_dependencies([done, score, summary]):
209 | done_score = tf.gather(score, tf.where(done)[:, 0])
210 | submit_score = tf.cond(
211 | tf.reduce_any(done), lambda: score_mean.submit(done_score), tf.no_op)
212 | with tf.control_dependencies([submit_score]):
213 | mean_score = tf.cond(self._report, score_mean.clear, float)
214 | steps_made = tf.shape(score)[0]
215 | next_step = self._step.assign_add(steps_made)
216 | with tf.control_dependencies([mean_score, next_step]):
217 | return tf.identity(summary), mean_score, next_step, steps_made
218 |
219 | def _store_checkpoint(self, sess, saver, global_step):
220 | """Store a checkpoint if a log directory was provided to the constructor.
221 |
222 | The directory will be created if needed.
223 |
224 | Args:
225 | sess: Session containing variables to store.
226 | saver: Saver used for checkpointing.
227 | global_step: Step number of the checkpoint name.
228 | """
229 | if not self._logdir or not saver:
230 | return
231 | tf.gfile.MakeDirs(self._logdir)
232 | filename = os.path.join(self._logdir, 'model.ckpt')
233 | saver.save(sess, filename, global_step)
234 |
--------------------------------------------------------------------------------
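A minimal sketch, not part of the repository, of the Loop protocol above: two phases whose score is fed through a placeholder and reported every step, mirroring the pattern exercised by loop_test.py below. With no logdir, no summaries or checkpoints are written.

import tensorflow as tf

from agents import tools

score = tf.placeholder(tf.float32, [])
loop = tools.Loop(None)
loop.add_phase('train', done=True, score=score, summary='', steps=2,
               report_every=1, feed={score: 1.0})
loop.add_phase('eval', done=True, score=score, summary='', steps=1,
               report_every=1, feed={score: 2.0})

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for mean_score in loop.run(sess, saver=None, max_step=6):
    print(mean_score)   # 1.0, 1.0, 2.0, 1.0, 1.0, 2.0
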
/agents/tools/loop_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for the training loop."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 |
23 | from agents import tools
24 |
25 |
26 | class LoopTest(tf.test.TestCase):
27 |
28 | def test_report_every_step(self):
29 | step = tf.Variable(0, False, dtype=tf.int32, name='step')
30 | loop = tools.Loop(None, step)
31 | loop.add_phase(
32 | 'phase_1', done=True, score=0, summary='', steps=1, report_every=3)
33 | # Step: 0 1 2 3 4 5 6 7 8
34 | # Report: x x x
35 | with self.test_session() as sess:
36 | sess.run(tf.global_variables_initializer())
37 | scores = loop.run(sess, saver=None, max_step=9)
38 | next(scores)
39 | self.assertEqual(3, sess.run(step))
40 | next(scores)
41 | self.assertEqual(6, sess.run(step))
42 | next(scores)
43 | self.assertEqual(9, sess.run(step))
44 |
45 | def test_phases_feed(self):
46 | score = tf.placeholder(tf.float32, [])
47 | loop = tools.Loop(None)
48 | loop.add_phase(
49 | 'phase_1', done=True, score=score, summary='', steps=1, report_every=1,
50 | log_every=None, checkpoint_every=None, feed={score: 1})
51 | loop.add_phase(
52 | 'phase_2', done=True, score=score, summary='', steps=3, report_every=1,
53 | log_every=None, checkpoint_every=None, feed={score: 2})
54 | loop.add_phase(
55 | 'phase_3', done=True, score=score, summary='', steps=2, report_every=1,
56 | log_every=None, checkpoint_every=None, feed={score: 3})
57 | with self.test_session() as sess:
58 | sess.run(tf.global_variables_initializer())
59 | scores = list(loop.run(sess, saver=None, max_step=15))
60 | self.assertAllEqual([1, 2, 2, 2, 3, 3, 1, 2, 2, 2, 3, 3, 1, 2, 2], scores)
61 |
62 | def test_average_score_over_phases(self):
63 | loop = tools.Loop(None)
64 | loop.add_phase(
65 | 'phase_1', done=True, score=1, summary='', steps=1, report_every=2)
66 | loop.add_phase(
67 | 'phase_2', done=True, score=2, summary='', steps=2, report_every=5)
68 | # Score: 1 2 2 1 2 2 1 2 2 1 2 2 1 2 2 1 2
69 | # Report 1: x x x
70 | # Report 2: x x
71 | with self.test_session() as sess:
72 | sess.run(tf.global_variables_initializer())
73 | scores = list(loop.run(sess, saver=None, max_step=17))
74 | self.assertAllEqual([1, 2, 1, 2, 1], scores)
75 |
76 | def test_not_done(self):
77 | step = tf.Variable(0, False, dtype=tf.int32, name='step')
78 | done = tf.equal((step + 1) % 2, 0)
79 | score = tf.cast(step, tf.float32)
80 | loop = tools.Loop(None, step)
81 | loop.add_phase(
82 | 'phase_1', done, score, summary='', steps=1, report_every=3)
83 | # Score: 0 1 2 3 4 5 6 7 8
84 | # Done: x x x x
85 | # Report: x x x
86 | with self.test_session() as sess:
87 | sess.run(tf.global_variables_initializer())
88 | scores = list(loop.run(sess, saver=None, max_step=9))
89 | self.assertAllEqual([1, 4, 7], scores)
90 |
91 | def test_not_done_batch(self):
92 | step = tf.Variable(0, False, dtype=tf.int32, name='step')
93 | done = tf.equal([step % 3, step % 4], 0)
94 | score = tf.cast([step, step ** 2], tf.float32)
95 | loop = tools.Loop(None, step)
96 | loop.add_phase(
97 | 'phase_1', done, score, summary='', steps=1, report_every=8)
98 | # Step: 0 2 4 6
99 | # Score 1: 0 2 4 6
100 | # Done 1: x x
101 | # Score 2: 0 4 16 32
102 | # Done 2: x x
103 | with self.test_session() as sess:
104 | sess.run(tf.global_variables_initializer())
105 | scores = list(loop.run(sess, saver=None, max_step=8))
106 | self.assertEqual(8, sess.run(step))
107 | self.assertAllEqual([(0 + 0 + 16 + 6) / 4], scores)
108 |
109 |
110 | if __name__ == '__main__':
111 | tf.test.main()
112 |
--------------------------------------------------------------------------------
/agents/tools/mock_algorithm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Mock algorithm for testing reinforcement learning code."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 |
23 |
24 | class MockAlgorithm(object):
25 | """Produce random actions and empty summaries."""
26 |
27 | def __init__(self, envs):
28 | """Produce random actions and empty summaries.
29 |
30 | Args:
31 | envs: List of in-graph environments.
32 | """
33 | self._envs = envs
34 |
35 | def begin_episode(self, unused_agent_indices):
36 | return tf.constant('')
37 |
38 | def perform(self, agent_indices, unused_observ):
39 | shape = (tf.shape(agent_indices)[0],) + self._envs[0].action_space.shape
40 | low = self._envs[0].action_space.low
41 | high = self._envs[0].action_space.high
42 | action = tf.random_uniform(shape) * (high - low) + low
43 | return action, tf.constant('')
44 |
45 | def experience(self, unused_agent_indices, *unused_transition):
46 | return tf.constant('')
47 |
48 | def end_episode(self, unused_agent_indices):
49 | return tf.constant('')
50 |
--------------------------------------------------------------------------------
/agents/tools/mock_environment.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Mock environment for testing reinforcement learning code."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import gym
22 | import gym.spaces
23 | import numpy as np
24 |
25 |
26 | class MockEnvironment(object):
27 | """Generate random agent input and keep track of statistics."""
28 |
29 | def __init__(self, observ_shape, action_shape, min_duration, max_duration):
30 | """Generate random agent input and keep track of statistics.
31 |
32 | Args:
33 | observ_shape: Shape for the random observations.
34 | action_shape: Shape for the action space.
35 | min_duration: Minimum number of steps per episode.
36 | max_duration: Maximum number of steps per episode.
37 |
38 | Attributes:
39 | steps: List of actual simulated lengths for all episodes.
40 | durations: List of decided lengths for all episodes.
41 | """
42 | self._observ_shape = observ_shape
43 | self._action_shape = action_shape
44 | self._min_duration = min_duration
45 | self._max_duration = max_duration
46 | self._random = np.random.RandomState(0)
47 | self.steps = []
48 | self.durations = []
49 |
50 | @property
51 | def observation_space(self):
52 | low = np.zeros(self._observ_shape)
53 | high = np.ones(self._observ_shape)
54 | return gym.spaces.Box(low, high, dtype=np.float32)
55 |
56 | @property
57 | def action_space(self):
58 | low = np.zeros(self._action_shape)
59 | high = np.ones(self._action_shape)
60 | return gym.spaces.Box(low, high, dtype=np.float32)
61 |
62 | @property
63 | def unwrapped(self):
64 | return self
65 |
66 | def step(self, action):
67 | assert self.action_space.contains(action)
68 | assert self.steps[-1] < self.durations[-1]
69 | self.steps[-1] += 1
70 | observ = self._current_observation()
71 | reward = self._current_reward()
72 | done = self.steps[-1] >= self.durations[-1]
73 | info = {}
74 | return observ, reward, done, info
75 |
76 | def reset(self):
77 | duration = self._random.randint(self._min_duration, self._max_duration + 1)
78 | self.steps.append(0)
79 | self.durations.append(duration)
80 | return self._current_observation()
81 |
82 | def _current_observation(self):
83 | return self._random.uniform(0, 1, self._observ_shape)
84 |
85 | def _current_reward(self):
86 | return self._random.uniform(-1, 1)
87 |
--------------------------------------------------------------------------------
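A minimal sketch, not part of the repository, of driving the MockEnvironment above directly to see the episode bookkeeping it exposes for tests; the shapes are arbitrary, and the episode length is drawn between min_duration and max_duration at every reset.

from agents import tools

env = tools.MockEnvironment(
    observ_shape=(2,), action_shape=(3,), min_duration=2, max_duration=4)
observ = env.reset()
done = False
while not done:
  observ, reward, done, _ = env.step(env.action_space.sample())
print(env.durations)   # e.g. [3]; the length chosen for the first episode
print(env.steps)       # e.g. [3]; the number of steps actually simulated
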
/agents/tools/nested.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tools for manipulating nested tuples, list, and dictionaries."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | # Disable linter warning for using `flatten` as argument name.
22 | # pylint: disable=redefined-outer-name
23 |
24 | _builtin_zip = zip
25 | _builtin_map = map
26 | _builtin_filter = filter
27 |
28 |
29 | def zip_(*structures, **kwargs):
30 | # pylint: disable=differing-param-doc,missing-param-doc
31 | """Combine corresponding elements in multiple nested structure to tuples.
32 |
33 | The nested structures can consist of any combination of lists, tuples, and
34 | dicts. All provided structures must have the same nesting.
35 |
36 | Args:
37 | *structures: Nested structures.
38 | flatten: Whether to flatten the resulting structure into a tuple. Keys of
39 | dictionaries will be discarded.
40 |
41 | Returns:
42 | Nested structure.
43 | """
44 | # Named keyword arguments are not allowed after *args in Python 2.
45 | flatten = kwargs.pop('flatten', False)
46 | assert not kwargs, 'zip() got unexpected keyword arguments.'
47 | return map(
48 | lambda *x: x if len(x) > 1 else x[0],
49 | *structures,
50 | flatten=flatten)
51 |
52 |
53 | def map_(function, *structures, **kwargs):
54 | # pylint: disable=differing-param-doc,missing-param-doc
55 | """Apply a function to every element in a nested structure.
56 |
57 | If multiple structures are provided as input, their structure must match and
58 | the function will be applied to corresponding groups of elements. The nested
59 | structure can consist of any combination of lists, tuples, and dicts.
60 |
61 | Args:
62 | function: The function to apply to the elements of the structure. Receives
63 | one argument for every structure that is provided.
64 | *structures: One or more nested structures.
65 | flatten: Whether to flatten the resulting structure into a tuple. Keys of
66 | dictionaries will be discarded.
67 |
68 | Returns:
69 | Nested structure.
70 | """
71 | # Named keyword arguments are not allowed after *args in Python 2.
72 | flatten = kwargs.pop('flatten', False)
73 | assert not kwargs, 'map() got unexpected keyword arguments.'
74 |
75 | def impl(function, *structures):
76 | if len(structures) == 0: # pylint: disable=len-as-condition
77 | return structures
78 | if all(isinstance(s, (tuple, list)) for s in structures):
79 | if len(set(len(x) for x in structures)) > 1:
80 | raise ValueError('Cannot merge tuples or lists of different length.')
81 | args = tuple((impl(function, *x) for x in _builtin_zip(*structures)))
82 | if hasattr(structures[0], '_fields'): # namedtuple
83 | return type(structures[0])(*args)
84 | else: # tuple, list
85 | return type(structures[0])(args)
86 | if all(isinstance(s, dict) for s in structures):
87 | if len(set(frozenset(x.keys()) for x in structures)) > 1:
88 | raise ValueError('Cannot merge dicts with different keys.')
89 | merged = {
90 | k: impl(function, *(s[k] for s in structures))
91 | for k in structures[0]}
92 | return type(structures[0])(merged)
93 | return function(*structures)
94 |
95 | result = impl(function, *structures)
96 | if flatten:
97 | result = flatten_(result)
98 | return result
99 |
100 |
101 | def flatten_(structure):
102 | """Combine all leaves of a nested structure into a tuple.
103 |
104 | The nested structure can consist of any combination of tuples, lists, and
105 | dicts. Dictionary keys will be discarded, but values will be ordered by the
106 | sorting of the keys.
107 |
108 | Args:
109 | structure: Nested structure.
110 |
111 | Returns:
112 | Flat tuple.
113 | """
114 | if isinstance(structure, dict):
115 | if structure:
116 | structure = zip(*sorted(structure.items(), key=lambda x: x[0]))[1]
117 | else:
118 | # Zip doesn't work on the items of an empty dictionary.
119 | structure = ()
120 | if isinstance(structure, (tuple, list)):
121 | result = []
122 | for element in structure:
123 | result += flatten_(element)
124 | return tuple(result)
125 | return (structure,)
126 |
127 |
128 | def filter_(predicate, *structures, **kwargs):
129 | # pylint: disable=differing-param-doc,missing-param-doc, too-many-branches
130 | """Select elements of a nested structure based on a predicate function.
131 |
132 | If multiple structures are provided as input, their structure must match and
133 | the function will be applied to corresponding groups of elements. The nested
134 | structure can consist of any combination of lists, tuples, and dicts.
135 |
136 | Args:
137 | predicate: The function to determine whether an element should be kept.
138 | Receives one argument for every structure that is provided.
139 | *structures: One or more nested structures.
140 | flatten: Whether to flatten the resulting structure into a tuple. Keys of
141 | dictionaries will be discarded.
142 |
143 | Returns:
144 | Nested structure.
145 | """
146 | # Named keyword arguments are not allowed after *args in Python 2.
147 | flatten = kwargs.pop('flatten', False)
148 | assert not kwargs, 'filter() got unexpected keyword arguments.'
149 |
150 | def impl(predicate, *structures):
151 | if len(structures) == 0: # pylint: disable=len-as-condition
152 | return structures
153 | if all(isinstance(s, (tuple, list)) for s in structures):
154 | if len(set(len(x) for x in structures)) > 1:
155 | raise ValueError('Cannot merge tuples or lists of different length.')
156 | # Only wrap in tuples if more than one structure provided.
157 | if len(structures) > 1:
158 | filtered = (impl(predicate, *x) for x in _builtin_zip(*structures))
159 | else:
160 | filtered = (impl(predicate, x) for x in structures[0])
161 | # Remove empty containers and construct result structure.
162 | if hasattr(structures[0], '_fields'): # namedtuple
163 | filtered = (x if x != () else None for x in filtered)
164 | return type(structures[0])(*filtered)
165 | else: # tuple, list
166 | filtered = (
167 | x for x in filtered if not isinstance(x, (tuple, list, dict)) or x)
168 | return type(structures[0])(filtered)
169 | if all(isinstance(s, dict) for s in structures):
170 | if len(set(frozenset(x.keys()) for x in structures)) > 1:
171 | raise ValueError('Cannot merge dicts with different keys.')
172 | # Only wrap in tuples if more than one structure provided.
173 | if len(structures) > 1:
174 | filtered = {
175 | k: impl(predicate, *(s[k] for s in structures))
176 | for k in structures[0]}
177 | else:
178 | filtered = {k: impl(predicate, v) for k, v in structures[0].items()}
179 | # Remove empty containers and construct result structure.
180 | filtered = {
181 | k: v for k, v in filtered.items()
182 | if not isinstance(v, (tuple, list, dict)) or v}
183 | return type(structures[0])(filtered)
184 | if len(structures) > 1:
185 | return structures if predicate(*structures) else ()
186 | else:
187 | return structures[0] if predicate(structures[0]) else ()
188 |
189 | result = impl(predicate, *structures)
190 | if flatten:
191 | result = flatten_(result)
192 | return result
193 |
194 |
195 | # pylint: disable=redefined-builtin
196 | zip = zip_
197 | map = map_
198 | flatten = flatten_
199 | filter = filter_
200 |
--------------------------------------------------------------------------------
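
A minimal usage sketch of the helpers defined above, assuming the package is
installed so that agents.tools.nested is importable; the expected values
mirror the behavior exercised by the tests in the next file.

from agents.tools import nested

# Apply a function to corresponding leaves of matching structures.
combined = nested.map(
    lambda x, y: x + y, {'a': 1, 'b': (2, 3)}, {'a': 10, 'b': (20, 30)})
assert combined == {'a': 11, 'b': (22, 33)}

# Flatten a mixed structure into a tuple; dict values are ordered by key.
assert nested.flatten([(1, 2), 3, {'foo': [4]}]) == (1, 2, 3, 4)

# Keep only leaves that satisfy a predicate; empty containers are dropped.
kept = nested.filter(lambda x: x % 2 == 0, [(1, 2), 3, {'foo': [4]}])
assert kept == [(2,), {'foo': [4]}]
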
/agents/tools/nested_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests of tools for managing nested structures."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 |
23 | import tensorflow as tf
24 |
25 | from agents.tools import nested
26 |
27 |
28 | class ZipTest(tf.test.TestCase):
29 |
30 | def test_scalar(self):
31 | self.assertEqual(42, nested.zip(42))
32 | self.assertEqual((13, 42), nested.zip(13, 42))
33 |
34 | def test_empty(self):
35 | self.assertEqual({}, nested.zip({}, {}))
36 |
37 | def test_base_case(self):
38 | self.assertEqual((1, 2, 3), nested.zip(1, 2, 3))
39 |
40 | def test_shallow_list(self):
41 | a = [1, 2, 3]
42 | b = [4, 5, 6]
43 | c = [7, 8, 9]
44 | result = nested.zip(a, b, c)
45 | self.assertEqual([(1, 4, 7), (2, 5, 8), (3, 6, 9)], result)
46 |
47 | def test_shallow_tuple(self):
48 | a = (1, 2, 3)
49 | b = (4, 5, 6)
50 | c = (7, 8, 9)
51 | result = nested.zip(a, b, c)
52 | self.assertEqual(((1, 4, 7), (2, 5, 8), (3, 6, 9)), result)
53 |
54 | def test_shallow_dict(self):
55 | a = {'a': 1, 'b': 2, 'c': 3}
56 | b = {'a': 4, 'b': 5, 'c': 6}
57 | c = {'a': 7, 'b': 8, 'c': 9}
58 | result = nested.zip(a, b, c)
59 | self.assertEqual({'a': (1, 4, 7), 'b': (2, 5, 8), 'c': (3, 6, 9)}, result)
60 |
61 | def test_single(self):
62 | a = [[1, 2], 3]
63 | result = nested.zip(a)
64 | self.assertEqual(a, result)
65 |
66 | def test_mixed_structures(self):
67 | a = [(1, 2), 3, {'foo': [4]}]
68 | b = [(5, 6), 7, {'foo': [8]}]
69 | result = nested.zip(a, b)
70 | self.assertEqual([((1, 5), (2, 6)), (3, 7), {'foo': [(4, 8)]}], result)
71 |
72 | def test_different_types(self):
73 | a = [1, 2, 3]
74 | b = 'a b c'.split()
75 | result = nested.zip(a, b)
76 | self.assertEqual([(1, 'a'), (2, 'b'), (3, 'c')], result)
77 |
78 | def test_use_type_of_first(self):
79 | a = (1, 2, 3)
80 | b = [4, 5, 6]
81 | c = [7, 8, 9]
82 | result = nested.zip(a, b, c)
83 | self.assertEqual(((1, 4, 7), (2, 5, 8), (3, 6, 9)), result)
84 |
85 | def test_namedtuple(self):
86 | Foo = collections.namedtuple('Foo', 'value')
87 | foo, bar = Foo(42), Foo(13)
88 | self.assertEqual(Foo((42, 13)), nested.zip(foo, bar))
89 |
90 |
91 | class MapTest(tf.test.TestCase):
92 |
93 | def test_scalar(self):
94 | self.assertEqual(42, nested.map(lambda x: x, 42))
95 |
96 | def test_empty(self):
97 | self.assertEqual({}, nested.map(lambda x: x, {}))
98 |
99 | def test_shallow_list(self):
100 | self.assertEqual([2, 4, 6], nested.map(lambda x: 2 * x, [1, 2, 3]))
101 |
102 | def test_shallow_dict(self):
103 | data = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
104 | self.assertEqual(data, nested.map(lambda x: x, data))
105 |
106 | def test_mixed_structure(self):
107 | structure = [(1, 2), 3, {'foo': [4]}]
108 | result = nested.map(lambda x: 2 * x, structure)
109 | self.assertEqual([(2, 4), 6, {'foo': [8]}], result)
110 |
111 | def test_mixed_types(self):
112 | self.assertEqual([14, 'foofoo'], nested.map(lambda x: x * 2, [7, 'foo']))
113 |
114 | def test_multiple_lists(self):
115 | a = [1, 2, 3]
116 | b = [4, 5, 6]
117 | c = [7, 8, 9]
118 | result = nested.map(lambda x, y, z: x + y + z, a, b, c)
119 | self.assertEqual([12, 15, 18], result)
120 |
121 | def test_namedtuple(self):
122 | Foo = collections.namedtuple('Foo', 'value')
123 | foo, bar = [Foo(42)], [Foo(13)]
124 | function = nested.map(lambda x, y: (y, x), foo, bar)
125 | self.assertEqual([Foo((13, 42))], function)
126 | function = nested.map(lambda x, y: x + y, foo, bar)
127 | self.assertEqual([Foo(55)], function)
128 |
129 |
130 | class FlattenTest(tf.test.TestCase):
131 |
132 | def test_scalar(self):
133 | self.assertEqual((42,), nested.flatten(42))
134 |
135 | def test_empty(self):
136 | self.assertEqual((), nested.flatten({}))
137 |
138 | def test_base_case(self):
139 | self.assertEqual((1,), nested.flatten(1))
140 |
141 | def test_convert_type(self):
142 | self.assertEqual((1, 2, 3), nested.flatten([1, 2, 3]))
143 |
144 | def test_mixed_structure(self):
145 | self.assertEqual((1, 2, 3, 4), nested.flatten([(1, 2), 3, {'foo': [4]}]))
146 |
147 | def test_value_ordering(self):
148 | self.assertEqual((1, 2, 3), nested.flatten({'a': 1, 'b': 2, 'c': 3}))
149 |
150 |
151 | class FilterTest(tf.test.TestCase):
152 |
153 | def test_empty(self):
154 | self.assertEqual({}, nested.filter(lambda x: True, {}))
155 | self.assertEqual({}, nested.filter(lambda x: False, {}))
156 |
157 | def test_base_case(self):
158 | self.assertEqual((), nested.filter(lambda x: False, 1))
159 |
160 | def test_single_dict(self):
161 | predicate = lambda x: x % 2 == 0
162 | data = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
163 | self.assertEqual({'b': 2, 'd': 4}, nested.filter(predicate, data))
164 |
165 | def test_multiple_lists(self):
166 | a = [1, 2, 3]
167 | b = [4, 5, 6]
168 | c = [7, 8, 9]
169 | predicate = lambda *args: any(x % 4 == 0 for x in args)
170 | result = nested.filter(predicate, a, b, c)
171 | self.assertEqual([(1, 4, 7), (2, 5, 8)], result)
172 |
173 | def test_multiple_dicts(self):
174 | a = {'a': 1, 'b': 2, 'c': 3}
175 | b = {'a': 4, 'b': 5, 'c': 6}
176 | c = {'a': 7, 'b': 8, 'c': 9}
177 | predicate = lambda *args: any(x % 4 == 0 for x in args)
178 | result = nested.filter(predicate, a, b, c)
179 | self.assertEqual({'a': (1, 4, 7), 'b': (2, 5, 8)}, result)
180 |
181 | def test_mixed_structure(self):
182 | predicate = lambda x: x % 2 == 0
183 | data = [(1, 2), 3, {'foo': [4]}]
184 | self.assertEqual([(2,), {'foo': [4]}], nested.filter(predicate, data))
185 |
186 | def test_remove_empty_containers(self):
187 | data = [(1, 2, 3), 4, {'foo': [5, 6], 'bar': 7}]
188 | self.assertEqual([], nested.filter(lambda x: False, data))
189 |
190 | def test_namedtuple(self):
191 | Foo = collections.namedtuple('Foo', 'value1, value2')
192 | self.assertEqual(Foo(1, None), nested.filter(lambda x: x == 1, Foo(1, 2)))
193 |
194 | def test_namedtuple_multiple(self):
195 | Foo = collections.namedtuple('Foo', 'value1, value2')
196 | foo = Foo(1, 2)
197 | bar = Foo(2, 3)
198 | result = nested.filter(lambda x, y: x + y > 3, foo, bar)
199 | self.assertEqual(Foo(None, (2, 3)), result)
200 |
201 | def test_namedtuple_nested(self):
202 | Foo = collections.namedtuple('Foo', 'value1, value2')
203 | foo = Foo(1, [1, 2, 3])
204 | self.assertEqual(Foo(None, [2, 3]), nested.filter(lambda x: x > 1, foo))
205 |
--------------------------------------------------------------------------------
/agents/tools/simulate.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """In-graph simulation step of a vectorized algorithm with environments."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 |
23 | from agents.tools import streaming_mean
24 |
25 |
26 | def simulate(batch_env, algo, log=True, reset=False):
27 | """Simulation step of a vectorized algorithm with in-graph environments.
28 |
29 | Integrates the operations implemented by the algorithm and the environments
30 | into a combined operation.
31 |
32 | Args:
33 | batch_env: In-graph batch environment.
34 | algo: Algorithm instance implementing required operations.
35 | log: Tensor indicating whether to compute and return summaries.
36 | reset: Tensor causing all environments to reset.
37 |
38 | Returns:
39 | Tuple of tensors containing done flags for the current episodes, possibly
40 | intermediate scores for the episodes, and a summary tensor.
41 | """
42 |
43 | def _define_begin_episode(agent_indices):
44 | """Reset environments, intermediate scores and durations for new episodes.
45 |
46 | Args:
47 | agent_indices: Tensor containing batch indices starting an episode.
48 |
49 | Returns:
50 | Summary tensor.
51 | """
52 | assert agent_indices.shape.ndims == 1
53 | zero_scores = tf.zeros_like(agent_indices, tf.float32)
54 | zero_durations = tf.zeros_like(agent_indices)
55 | reset_ops = [
56 | batch_env.reset(agent_indices),
57 | tf.scatter_update(score, agent_indices, zero_scores),
58 | tf.scatter_update(length, agent_indices, zero_durations)]
59 | with tf.control_dependencies(reset_ops):
60 | return algo.begin_episode(agent_indices)
61 |
62 | def _define_step():
63 | """Request actions from the algorithm and apply them to the environments.
64 |
65 | Increments the lengths of all episodes and increases their scores by the
66 | current reward. After stepping the environments, provides the full
67 | transition tuple to the algorithm.
68 |
69 | Returns:
70 | Summary tensor.
71 | """
72 | prevob = batch_env.observ + 0 # Ensure a copy of the variable value.
73 | agent_indices = tf.range(len(batch_env))
74 | action, step_summary = algo.perform(agent_indices, prevob)
75 | action.set_shape(batch_env.action.shape)
76 | with tf.control_dependencies([batch_env.simulate(action)]):
77 | add_score = score.assign_add(batch_env.reward)
78 | inc_length = length.assign_add(tf.ones(len(batch_env), tf.int32))
79 | with tf.control_dependencies([add_score, inc_length]):
80 | agent_indices = tf.range(len(batch_env))
81 | experience_summary = algo.experience(
82 | agent_indices, prevob, batch_env.action, batch_env.reward,
83 | batch_env.done, batch_env.observ)
84 | return tf.summary.merge([step_summary, experience_summary])
85 |
86 | def _define_end_episode(agent_indices):
87 | """Notify the algorithm of ending episodes.
88 |
89 | Also updates the mean score and length counters used for summaries.
90 |
91 | Args:
92 | agent_indices: Tensor holding batch indices that end their episodes.
93 |
94 | Returns:
95 | Summary tensor.
96 | """
97 | assert agent_indices.shape.ndims == 1
98 | submit_score = mean_score.submit(tf.gather(score, agent_indices))
99 | submit_length = mean_length.submit(
100 | tf.cast(tf.gather(length, agent_indices), tf.float32))
101 | with tf.control_dependencies([submit_score, submit_length]):
102 | return algo.end_episode(agent_indices)
103 |
104 | def _define_summaries():
105 | """Reset the average score and duration, and return them as summary.
106 |
107 | Returns:
108 | Summary string.
109 | """
110 | score_summary = tf.cond(
111 | tf.logical_and(log, tf.cast(mean_score.count, tf.bool)),
112 | lambda: tf.summary.scalar('mean_score', mean_score.clear()), str)
113 | length_summary = tf.cond(
114 | tf.logical_and(log, tf.cast(mean_length.count, tf.bool)),
115 | lambda: tf.summary.scalar('mean_length', mean_length.clear()), str)
116 | return tf.summary.merge([score_summary, length_summary])
117 |
118 | with tf.name_scope('simulate'):
119 | log = tf.convert_to_tensor(log)
120 | reset = tf.convert_to_tensor(reset)
121 | with tf.variable_scope('simulate_temporary'):
122 | score = tf.Variable(
123 | lambda: tf.zeros(len(batch_env), dtype=tf.float32),
124 | trainable=False, name='score')
125 | length = tf.Variable(
126 | lambda: tf.zeros(len(batch_env), dtype=tf.int32),
127 | trainable=False, name='length')
128 | mean_score = streaming_mean.StreamingMean((), tf.float32)
129 | mean_length = streaming_mean.StreamingMean((), tf.float32)
130 | agent_indices = tf.cond(
131 | reset,
132 | lambda: tf.range(len(batch_env)),
133 | lambda: tf.cast(tf.where(batch_env.done)[:, 0], tf.int32))
134 | begin_episode = tf.cond(
135 | tf.cast(tf.shape(agent_indices)[0], tf.bool),
136 | lambda: _define_begin_episode(agent_indices), str)
137 | with tf.control_dependencies([begin_episode]):
138 | step = _define_step()
139 | with tf.control_dependencies([step]):
140 | agent_indices = tf.cast(tf.where(batch_env.done)[:, 0], tf.int32)
141 | end_episode = tf.cond(
142 | tf.cast(tf.shape(agent_indices)[0], tf.bool),
143 | lambda: _define_end_episode(agent_indices), str)
144 | with tf.control_dependencies([end_episode]):
145 | summary = tf.summary.merge([
146 | _define_summaries(), begin_episode, step, end_episode])
147 | with tf.control_dependencies([summary]):
148 | done, score = tf.identity(batch_env.done), tf.identity(score)
149 | return done, score, summary
150 |
--------------------------------------------------------------------------------
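
A brief sketch of how the simulate operation can be driven from a session
loop, assuming the mock batch environment and algorithm that the tests in the
next file construct; the shapes and durations here are illustrative only.

import tensorflow as tf

from agents import tools

# Two mock environments wrapped into a single in-graph batch environment.
envs = [tools.MockEnvironment(observ_shape=(2, 3), action_shape=(3,),
                              min_duration=5, max_duration=5)
        for _ in range(2)]
envs = [tools.wrappers.ConvertTo32Bit(env) for env in envs]
batch_env = tools.InGraphBatchEnv(tools.BatchEnv(envs, blocking=True))
algo = tools.MockAlgorithm(batch_env)

# One fused op that resets finished episodes, steps every environment once,
# and feeds the resulting transitions back to the algorithm.
done, score, summary = tools.simulate(batch_env, algo, log=False, reset=False)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for _ in range(10):
    sess.run(done)
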
/agents/tools/simulate_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for the simulation operation."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 |
23 | from agents import tools
24 |
25 |
26 | class SimulateTest(tf.test.TestCase):
27 |
28 | def test_done_automatic(self):
29 | batch_env = self._create_test_batch_env((1, 2, 3, 4))
30 | algo = tools.MockAlgorithm(batch_env)
31 | done, _, _ = tools.simulate(batch_env, algo, log=False, reset=False)
32 | with self.test_session() as sess:
33 | sess.run(tf.global_variables_initializer())
34 | self.assertAllEqual([True, False, False, False], sess.run(done))
35 | self.assertAllEqual([True, True, False, False], sess.run(done))
36 | self.assertAllEqual([True, False, True, False], sess.run(done))
37 | self.assertAllEqual([True, True, False, True], sess.run(done))
38 |
39 | def test_done_forced(self):
40 | reset = tf.placeholder_with_default(False, ())
41 | batch_env = self._create_test_batch_env((2, 4))
42 | algo = tools.MockAlgorithm(batch_env)
43 | done, _, _ = tools.simulate(batch_env, algo, False, reset)
44 | with self.test_session() as sess:
45 | sess.run(tf.global_variables_initializer())
46 | self.assertAllEqual([False, False], sess.run(done))
47 | self.assertAllEqual([False, False], sess.run(done, {reset: True}))
48 | self.assertAllEqual([True, False], sess.run(done))
49 | self.assertAllEqual([False, False], sess.run(done, {reset: True}))
50 | self.assertAllEqual([True, False], sess.run(done))
51 | self.assertAllEqual([False, False], sess.run(done))
52 | self.assertAllEqual([True, True], sess.run(done))
53 |
54 | def test_reset_automatic(self):
55 | batch_env = self._create_test_batch_env((1, 2, 3, 4))
56 | algo = tools.MockAlgorithm(batch_env)
57 | done, _, _ = tools.simulate(batch_env, algo, log=False, reset=False)
58 | with self.test_session() as sess:
59 | sess.run(tf.global_variables_initializer())
60 | for _ in range(10):
61 | sess.run(done)
62 | self.assertAllEqual([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], batch_env[0].steps)
63 | self.assertAllEqual([2, 2, 2, 2, 2], batch_env[1].steps)
64 | self.assertAllEqual([3, 3, 3, 1], batch_env[2].steps)
65 | self.assertAllEqual([4, 4, 2], batch_env[3].steps)
66 |
67 | def test_reset_forced(self):
68 | reset = tf.placeholder_with_default(False, ())
69 | batch_env = self._create_test_batch_env((2, 4))
70 | algo = tools.MockAlgorithm(batch_env)
71 | done, _, _ = tools.simulate(batch_env, algo, False, reset)
72 | with self.test_session() as sess:
73 | sess.run(tf.global_variables_initializer())
74 | sess.run(done)
75 | sess.run(done, {reset: True})
76 | sess.run(done)
77 | sess.run(done, {reset: True})
78 | sess.run(done)
79 | sess.run(done)
80 | sess.run(done)
81 | self.assertAllEqual([1, 2, 2, 2], batch_env[0].steps)
82 | self.assertAllEqual([1, 2, 4], batch_env[1].steps)
83 |
84 | def _create_test_batch_env(self, durations):
85 | envs = []
86 | for duration in durations:
87 | env = tools.MockEnvironment(
88 | observ_shape=(2, 3), action_shape=(3,),
89 | min_duration=duration, max_duration=duration)
90 | env = tools.wrappers.ConvertTo32Bit(env)
91 | envs.append(env)
92 | batch_env = tools.BatchEnv(envs, blocking=True)
93 | batch_env = tools.InGraphBatchEnv(batch_env)
94 | return batch_env
95 |
96 |
97 | if __name__ == '__main__':
98 | tf.test.main()
99 |
--------------------------------------------------------------------------------
/agents/tools/streaming_mean.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Compute a streaming estimation of the mean of submitted tensors."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 |
23 |
24 | class StreamingMean(object):
25 | """Compute a streaming estimation of the mean of submitted tensors."""
26 |
27 | def __init__(self, shape, dtype):
28 | """Specify the shape and dtype of the mean to be estimated.
29 |
30 |     Note that the float mean of zero submitted elements is NaN, while the
31 |     integer mean of zero elements raises a division by zero error.
32 |
33 | Args:
34 | shape: Shape of the mean to compute.
35 | dtype: Data type of the mean to compute.
36 | """
37 | self._dtype = dtype
38 | self._sum = tf.Variable(lambda: tf.zeros(shape, dtype), False)
39 | self._count = tf.Variable(lambda: 0, trainable=False)
40 |
41 | @property
42 | def value(self):
43 | """The current value of the mean."""
44 | return self._sum / tf.cast(self._count, self._dtype)
45 |
46 | @property
47 | def count(self):
48 | """The number of submitted samples."""
49 | return self._count
50 |
51 | def submit(self, value):
52 | """Submit a single or batch tensor to refine the streaming mean."""
53 | # Add a batch dimension if necessary.
54 | if value.shape.ndims == self._sum.shape.ndims:
55 | value = value[None, ...]
56 | return tf.group(
57 | self._sum.assign_add(tf.reduce_sum(value, 0)),
58 | self._count.assign_add(tf.shape(value)[0]))
59 |
60 | def clear(self):
61 | """Return the mean estimate and reset the streaming statistics."""
62 | value = self._sum / tf.cast(self._count, self._dtype)
63 | with tf.control_dependencies([value]):
64 | reset_value = self._sum.assign(tf.zeros_like(self._sum))
65 | reset_count = self._count.assign(0)
66 | with tf.control_dependencies([reset_value, reset_count]):
67 | return tf.identity(value)
68 |
--------------------------------------------------------------------------------
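
A short sketch of the StreamingMean interface above, assuming a TF1-style
session: submit accumulates a batch of samples, and clear returns the running
mean while resetting the internal sum and count.

import tensorflow as tf

from agents.tools import streaming_mean

mean = streaming_mean.StreamingMean((), tf.float32)
submit = mean.submit(tf.constant([1.0, 2.0, 3.0]))  # a batch of three samples
read_and_reset = mean.clear()

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(submit)
  print(sess.run(read_and_reset))  # 2.0, and the statistics are reset
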
/agents/tools/wrappers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Wrappers for OpenAI Gym environments."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import atexit
22 | import multiprocessing
23 | import sys
24 | import traceback
25 |
26 | import gym
27 | import gym.spaces
28 | import numpy as np
29 | import tensorflow as tf
30 |
31 |
32 | class AutoReset(object):
33 | """Automatically reset environment when the episode is done."""
34 |
35 | def __init__(self, env):
36 | self._env = env
37 | self._done = True
38 |
39 | def __getattr__(self, name):
40 | return getattr(self._env, name)
41 |
42 | def step(self, action):
43 | if self._done:
44 | observ, reward, done, info = self._env.reset(), 0.0, False, {}
45 | else:
46 | observ, reward, done, info = self._env.step(action)
47 | self._done = done
48 | return observ, reward, done, info
49 |
50 | def reset(self):
51 | self._done = False
52 | return self._env.reset()
53 |
54 |
55 | class ActionRepeat(object):
56 | """Repeat the agent action multiple steps."""
57 |
58 | def __init__(self, env, amount):
59 | self._env = env
60 | self._amount = amount
61 |
62 | def __getattr__(self, name):
63 | return getattr(self._env, name)
64 |
65 | def step(self, action):
66 | done = False
67 | total_reward = 0
68 | current_step = 0
69 | while current_step < self._amount and not done:
70 | observ, reward, done, info = self._env.step(action)
71 | total_reward += reward
72 | current_step += 1
73 | return observ, total_reward, done, info
74 |
75 |
76 | class RandomStart(object):
77 |   """Perform a random number of random actions at the start of an episode."""
78 |
79 | def __init__(self, env, max_steps):
80 | self._env = env
81 | self._max_steps = max_steps
82 |
83 | def __getattr__(self, name):
84 | return getattr(self._env, name)
85 |
86 | def reset(self):
87 | observ = self._env.reset()
88 | random_steps = np.random.randint(0, self._max_steps)
89 | for _ in range(random_steps):
90 | action = self._env.action_space.sample()
91 | observ, unused_reward, done, unused_info = self._env.step(action)
92 | if done:
93 | tf.logging.warning('Episode ended during random start.')
94 | return self.reset()
95 | return observ
96 |
97 |
98 | class FrameHistory(object):
99 | """Augment the observation with past observations."""
100 |
101 | def __init__(self, env, past_indices, flatten):
102 | """Augment the observation with past observations.
103 |
104 | Implemented as a Numpy ring buffer holding the necessary past observations.
105 |
106 | Args:
107 | env: OpenAI Gym environment to wrap.
108 | past_indices: List of non-negative integers indicating the time offsets
109 | from the current time step of observations to include.
110 | flatten: Concatenate the past observations rather than stacking them.
111 |
112 | Raises:
113 | KeyError: The current observation is not included in the indices.
114 | """
115 | if 0 not in past_indices:
116 | raise KeyError('Past indices should include 0 for the current frame.')
117 | self._env = env
118 | self._past_indices = past_indices
119 | self._step = 0
120 | self._buffer = None
121 | self._capacity = max(past_indices) + 1
122 | self._flatten = flatten
123 |
124 | def __getattr__(self, name):
125 | return getattr(self._env, name)
126 |
127 | @property
128 | def observation_space(self):
129 | low = self._env.observation_space.low
130 | high = self._env.observation_space.high
131 | low = np.repeat(low[None, ...], len(self._past_indices), 0)
132 | high = np.repeat(high[None, ...], len(self._past_indices), 0)
133 | if self._flatten:
134 | low = np.reshape(low, (-1,) + low.shape[2:])
135 | high = np.reshape(high, (-1,) + high.shape[2:])
136 | return gym.spaces.Box(low, high, dtype=np.float32)
137 |
138 | def step(self, action):
139 | observ, reward, done, info = self._env.step(action)
140 | self._step += 1
141 | self._buffer[self._step % self._capacity] = observ
142 | observ = self._select_frames()
143 | return observ, reward, done, info
144 |
145 | def reset(self):
146 | observ = self._env.reset()
147 | self._buffer = np.repeat(observ[None, ...], self._capacity, 0)
148 | self._step = 0
149 | return self._select_frames()
150 |
151 | def _select_frames(self):
152 | indices = [
153 | (self._step - index) % self._capacity for index in self._past_indices]
154 | observ = self._buffer[indices]
155 | if self._flatten:
156 | observ = np.reshape(observ, (-1,) + observ.shape[2:])
157 | return observ
158 |
159 |
160 | class FrameDelta(object):
161 | """Convert the observation to a difference from the previous observation."""
162 |
163 | def __init__(self, env):
164 | self._env = env
165 | self._last = None
166 |
167 | def __getattr__(self, name):
168 | return getattr(self._env, name)
169 |
170 | @property
171 | def observation_space(self):
172 | low = self._env.observation_space.low
173 | high = self._env.observation_space.high
174 | low, high = low - high, high - low
175 | return gym.spaces.Box(low, high, dtype=np.float32)
176 |
177 | def step(self, action):
178 | observ, reward, done, info = self._env.step(action)
179 | delta = observ - self._last
180 | self._last = observ
181 | return delta, reward, done, info
182 |
183 | def reset(self):
184 | observ = self._env.reset()
185 | self._last = observ
186 | return observ
187 |
188 |
189 | class RangeNormalize(object):
190 |   """Normalize finite observation and action ranges to [-1, 1]."""
191 |
192 | def __init__(self, env, observ=None, action=None):
193 | self._env = env
194 | self._should_normalize_observ = (
195 | observ is not False and self._is_finite(self._env.observation_space))
196 | if observ is True and not self._should_normalize_observ:
197 | raise ValueError('Cannot normalize infinite observation range.')
198 | if observ is None and not self._should_normalize_observ:
199 | tf.logging.info('Not normalizing infinite observation range.')
200 | self._should_normalize_action = (
201 | action is not False and self._is_finite(self._env.action_space))
202 | if action is True and not self._should_normalize_action:
203 | raise ValueError('Cannot normalize infinite action range.')
204 | if action is None and not self._should_normalize_action:
205 | tf.logging.info('Not normalizing infinite action range.')
206 |
207 | def __getattr__(self, name):
208 | return getattr(self._env, name)
209 |
210 | @property
211 | def observation_space(self):
212 | space = self._env.observation_space
213 | if not self._should_normalize_observ:
214 | return space
215 | low, high = -np.ones(space.shape), np.ones(space.shape)
216 | return gym.spaces.Box(low, high, dtype=np.float32)
217 |
218 | @property
219 | def action_space(self):
220 | space = self._env.action_space
221 | if not self._should_normalize_action:
222 | return space
223 | low, high = -np.ones(space.shape), np.ones(space.shape)
224 | return gym.spaces.Box(low, high, dtype=np.float32)
225 |
226 | def step(self, action):
227 | if self._should_normalize_action:
228 | action = self._denormalize_action(action)
229 | observ, reward, done, info = self._env.step(action)
230 | if self._should_normalize_observ:
231 | observ = self._normalize_observ(observ)
232 | return observ, reward, done, info
233 |
234 | def reset(self):
235 | observ = self._env.reset()
236 | if self._should_normalize_observ:
237 | observ = self._normalize_observ(observ)
238 | return observ
239 |
240 | def _denormalize_action(self, action):
241 | min_ = self._env.action_space.low
242 | max_ = self._env.action_space.high
243 | action = (action + 1) / 2 * (max_ - min_) + min_
244 | return action
245 |
246 | def _normalize_observ(self, observ):
247 | min_ = self._env.observation_space.low
248 | max_ = self._env.observation_space.high
249 | observ = 2 * (observ - min_) / (max_ - min_) - 1
250 | return observ
251 |
252 | def _is_finite(self, space):
253 | return np.isfinite(space.low).all() and np.isfinite(space.high).all()
254 |
255 |
256 | class ClipAction(object):
257 | """Clip out of range actions to the action space of the environment."""
258 |
259 | def __init__(self, env):
260 | self._env = env
261 |
262 | def __getattr__(self, name):
263 | return getattr(self._env, name)
264 |
265 | @property
266 | def action_space(self):
267 | shape = self._env.action_space.shape
268 | low, high = -np.inf * np.ones(shape), np.inf * np.ones(shape)
269 | return gym.spaces.Box(low, high, dtype=np.float32)
270 |
271 | def step(self, action):
272 | action_space = self._env.action_space
273 | action = np.clip(action, action_space.low, action_space.high)
274 | return self._env.step(action)
275 |
276 |
277 | class LimitDuration(object):
278 |   """End episodes after a specified number of steps."""
279 |
280 | def __init__(self, env, duration):
281 | self._env = env
282 | self._duration = duration
283 | self._step = None
284 |
285 | def __getattr__(self, name):
286 | return getattr(self._env, name)
287 |
288 | def step(self, action):
289 | if self._step is None:
290 | raise RuntimeError('Must reset environment.')
291 | observ, reward, done, info = self._env.step(action)
292 | self._step += 1
293 | if self._step >= self._duration:
294 | done = True
295 | self._step = None
296 | return observ, reward, done, info
297 |
298 | def reset(self):
299 | self._step = 0
300 | return self._env.reset()
301 |
302 |
303 | class ExternalProcess(object):
304 |   """Step the environment in a separate process for lock-free parallelism."""
305 |
306 | # Message types for communication via the pipe.
307 | _ACCESS = 1
308 | _CALL = 2
309 | _RESULT = 3
310 | _EXCEPTION = 4
311 | _CLOSE = 5
312 |
313 | def __init__(self, constructor):
314 |     """Step the environment in a separate process for lock-free parallelism.
315 |
316 | The environment will be created in the external process by calling the
317 | specified callable. This can be an environment class, or a function
318 | creating the environment and potentially wrapping it. The returned
319 | environment should not access global variables.
320 |
321 | Args:
322 | constructor: Callable that creates and returns an OpenAI gym environment.
323 |
324 | Attributes:
325 | observation_space: The cached observation space of the environment.
326 | action_space: The cached action space of the environment.
327 | """
328 | self._conn, conn = multiprocessing.Pipe()
329 | self._process = multiprocessing.Process(
330 | target=self._worker, args=(constructor, conn))
331 | atexit.register(self.close)
332 | self._process.start()
333 | self._observ_space = None
334 | self._action_space = None
335 |
336 | @property
337 | def observation_space(self):
338 | if not self._observ_space:
339 | self._observ_space = self.__getattr__('observation_space')
340 | return self._observ_space
341 |
342 | @property
343 | def action_space(self):
344 | if not self._action_space:
345 | self._action_space = self.__getattr__('action_space')
346 | return self._action_space
347 |
348 | def __getattr__(self, name):
349 | """Request an attribute from the environment.
350 |
351 | Note that this involves communication with the external process, so it can
352 | be slow.
353 |
354 | Args:
355 | name: Attribute to access.
356 |
357 | Returns:
358 | Value of the attribute.
359 | """
360 | self._conn.send((self._ACCESS, name))
361 | return self._receive()
362 |
363 | def call(self, name, *args, **kwargs):
364 | """Asynchronously call a method of the external environment.
365 |
366 | Args:
367 | name: Name of the method to call.
368 | *args: Positional arguments to forward to the method.
369 | **kwargs: Keyword arguments to forward to the method.
370 |
371 | Returns:
372 | Promise object that blocks and provides the return value when called.
373 | """
374 | payload = name, args, kwargs
375 | self._conn.send((self._CALL, payload))
376 | return self._receive
377 |
378 | def close(self):
379 | """Send a close message to the external process and join it."""
380 | try:
381 | self._conn.send((self._CLOSE, None))
382 | self._conn.close()
383 | except IOError:
384 | # The connection was already closed.
385 | pass
386 | self._process.join()
387 |
388 | def step(self, action, blocking=True):
389 | """Step the environment.
390 |
391 | Args:
392 | action: The action to apply to the environment.
393 | blocking: Whether to wait for the result.
394 |
395 | Returns:
396 | Transition tuple when blocking, otherwise callable that returns the
397 | transition tuple.
398 | """
399 | promise = self.call('step', action)
400 | if blocking:
401 | return promise()
402 | else:
403 | return promise
404 |
405 | def reset(self, blocking=True):
406 | """Reset the environment.
407 |
408 | Args:
409 | blocking: Whether to wait for the result.
410 |
411 | Returns:
412 | New observation when blocking, otherwise callable that returns the new
413 | observation.
414 | """
415 | promise = self.call('reset')
416 | if blocking:
417 | return promise()
418 | else:
419 | return promise
420 |
421 | def _receive(self):
422 | """Wait for a message from the worker process and return its payload.
423 |
424 | Raises:
425 | Exception: An exception was raised inside the worker process.
426 | KeyError: The received message is of an unknown type.
427 |
428 | Returns:
429 | Payload object of the message.
430 | """
431 | message, payload = self._conn.recv()
432 | # Re-raise exceptions in the main process.
433 | if message == self._EXCEPTION:
434 | stacktrace = payload
435 | raise Exception(stacktrace)
436 | if message == self._RESULT:
437 | return payload
438 | raise KeyError('Received message of unexpected type {}'.format(message))
439 |
440 | def _worker(self, constructor, conn):
441 | """The process waits for actions and sends back environment results.
442 |
443 | Args:
444 | constructor: Constructor for the OpenAI Gym environment.
445 | conn: Connection for communication to the main process.
446 |
447 | Raises:
448 | KeyError: When receiving a message of unknown type.
449 | """
450 | try:
451 | env = constructor()
452 | while True:
453 | try:
454 |           # Only block briefly so keyboard interrupts are raised promptly.
455 | if not conn.poll(0.1):
456 | continue
457 | message, payload = conn.recv()
458 | except (EOFError, KeyboardInterrupt):
459 | break
460 | if message == self._ACCESS:
461 | name = payload
462 | result = getattr(env, name)
463 | conn.send((self._RESULT, result))
464 | continue
465 | if message == self._CALL:
466 | name, args, kwargs = payload
467 | result = getattr(env, name)(*args, **kwargs)
468 | conn.send((self._RESULT, result))
469 | continue
470 | if message == self._CLOSE:
471 | assert payload is None
472 | break
473 | raise KeyError('Received message of unknown type {}'.format(message))
474 | except Exception: # pylint: disable=broad-except
475 | stacktrace = ''.join(traceback.format_exception(*sys.exc_info()))
476 | tf.logging.error('Error in environment process: {}'.format(stacktrace))
477 | conn.send((self._EXCEPTION, stacktrace))
478 | conn.close()
479 |
480 |
481 | class ConvertTo32Bit(object):
482 | """Convert data types of an OpenAI Gym environment to 32 bit."""
483 |
484 | def __init__(self, env):
485 | """Convert data types of an OpenAI Gym environment to 32 bit.
486 |
487 | Args:
488 | env: OpenAI Gym environment.
489 | """
490 | self._env = env
491 |
492 | def __getattr__(self, name):
493 | """Forward unimplemented attributes to the original environment.
494 |
495 | Args:
496 | name: Attribute that was accessed.
497 |
498 | Returns:
499 | Value behind the attribute name in the wrapped environment.
500 | """
501 | return getattr(self._env, name)
502 |
503 | def step(self, action):
504 | """Forward action to the wrapped environment.
505 |
506 | Args:
507 | action: Action to apply to the environment.
508 |
509 | Raises:
510 | ValueError: Invalid action.
511 |
512 | Returns:
513 | Converted observation, converted reward, done flag, and info object.
514 | """
515 | observ, reward, done, info = self._env.step(action)
516 | observ = self._convert_observ(observ)
517 | reward = self._convert_reward(reward)
518 | return observ, reward, done, info
519 |
520 | def reset(self):
521 | """Reset the environment and convert the resulting observation.
522 |
523 | Returns:
524 | Converted observation.
525 | """
526 | observ = self._env.reset()
527 | observ = self._convert_observ(observ)
528 | return observ
529 |
530 | def _convert_observ(self, observ):
531 | """Convert the observation to 32 bits.
532 |
533 | Args:
534 | observ: Numpy observation.
535 |
536 | Raises:
537 | ValueError: Observation contains infinite values.
538 |
539 | Returns:
540 | Numpy observation with 32-bit data type.
541 | """
542 | if not np.isfinite(observ).all():
543 | raise ValueError('Infinite observation encountered.')
544 | if observ.dtype == np.float64:
545 | return observ.astype(np.float32)
546 | if observ.dtype == np.int64:
547 | return observ.astype(np.int32)
548 | return observ
549 |
550 | def _convert_reward(self, reward):
551 | """Convert the reward to 32 bits.
552 |
553 | Args:
554 | reward: Numpy reward.
555 |
556 | Raises:
557 | ValueError: Rewards contain infinite values.
558 |
559 | Returns:
560 | Numpy reward with 32-bit data type.
561 | """
562 | if not np.isfinite(reward).all():
563 | raise ValueError('Infinite reward encountered.')
564 | return np.array(reward, dtype=np.float32)
565 |
566 |
567 | class CacheSpaces(object):
568 | """Cache observation and action space to not recompute them repeatedly."""
569 |
570 | def __init__(self, env):
571 | """Cache observation and action space to not recompute them repeatedly.
572 |
573 | Args:
574 | env: OpenAI Gym environment.
575 | """
576 | self._env = env
577 | self._observation_space = self._env.observation_space
578 | self._action_space = self._env.action_space
579 |
580 | def __getattr__(self, name):
581 | """Forward unimplemented attributes to the original environment.
582 |
583 | Args:
584 | name: Attribute that was accessed.
585 |
586 | Returns:
587 | Value behind the attribute name in the wrapped environment.
588 | """
589 | return getattr(self._env, name)
590 |
591 | @property
592 | def observation_space(self):
593 | return self._observation_space
594 |
595 | @property
596 | def action_space(self):
597 | return self._action_space
598 |
--------------------------------------------------------------------------------
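
The wrappers above forward unknown attributes to the wrapped environment, so
they can be stacked freely. A small composition sketch, assuming gym provides
the classic Pendulum-v0 task; the environment id and the 200-step limit are
illustrative assumptions rather than values prescribed by this repository.

import gym

from agents.tools import wrappers

env = gym.make('Pendulum-v0')
env = wrappers.RangeNormalize(env)      # scale observations/actions to [-1, 1]
env = wrappers.ClipAction(env)          # clip out-of-range actions
env = wrappers.LimitDuration(env, 200)  # end episodes after 200 steps
env = wrappers.ConvertTo32Bit(env)      # cast float64/int64 data to 32 bit

observ = env.reset()
observ, reward, done, info = env.step(env.action_space.sample())
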
/agents/tools/wrappers_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for environment wrappers."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import functools
22 |
23 | import tensorflow as tf
24 |
25 | from agents import tools
26 |
27 |
28 | class ExternalProcessTest(tf.test.TestCase):
29 |
30 | def test_close_no_hang_after_init(self):
31 | constructor = functools.partial(
32 | tools.MockEnvironment,
33 | observ_shape=(2, 3), action_shape=(2,),
34 | min_duration=2, max_duration=2)
35 | env = tools.wrappers.ExternalProcess(constructor)
36 | env.close()
37 |
38 | def test_close_no_hang_after_step(self):
39 | constructor = functools.partial(
40 | tools.MockEnvironment,
41 | observ_shape=(2, 3), action_shape=(2,),
42 | min_duration=5, max_duration=5)
43 | env = tools.wrappers.ExternalProcess(constructor)
44 | env.reset()
45 | env.step(env.action_space.sample())
46 | env.step(env.action_space.sample())
47 | env.close()
48 |
49 | def test_reraise_exception_in_init(self):
50 | constructor = MockEnvironmentCrashInInit
51 | env = tools.wrappers.ExternalProcess(constructor)
52 | with self.assertRaises(Exception):
53 | env.step(env.action_space.sample())
54 |
55 | def test_reraise_exception_in_step(self):
56 | constructor = functools.partial(
57 | MockEnvironmentCrashInStep, crash_at_step=3)
58 | env = tools.wrappers.ExternalProcess(constructor)
59 | env.reset()
60 | env.step(env.action_space.sample())
61 | env.step(env.action_space.sample())
62 | with self.assertRaises(Exception):
63 | env.step(env.action_space.sample())
64 |
65 |
66 | class MockEnvironmentCrashInInit(object):
67 | """Raise an error when instantiated."""
68 |
69 | def __init__(self, *unused_args, **unused_kwargs):
70 | raise RuntimeError()
71 |
72 |
73 | class MockEnvironmentCrashInStep(tools.MockEnvironment):
74 | """Raise an error after specified number of steps in an episode."""
75 |
76 | def __init__(self, crash_at_step):
77 | super(MockEnvironmentCrashInStep, self).__init__(
78 | observ_shape=(2, 3), action_shape=(2,),
79 | min_duration=crash_at_step + 1, max_duration=crash_at_step + 1)
80 | self._crash_at_step = crash_at_step
81 |
82 | def step(self, *args, **kwargs): # pylint: disable=arguments-differ
83 | transition = super(MockEnvironmentCrashInStep, self).step(*args, **kwargs)
84 | if self.steps[-1] == self._crash_at_step:
85 | raise RuntimeError()
86 | return transition
87 |
88 |
89 | if __name__ == '__main__':
90 | tf.test.main()
91 |
--------------------------------------------------------------------------------
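
Beyond the blocking calls exercised in the tests above, ExternalProcess also
supports non-blocking interaction: step and reset can return a promise that
is resolved later. A sketch using the mock environment from this repository:

import functools

from agents import tools

# The constructor runs inside the worker process, so the environment it
# returns should not rely on global state.
constructor = functools.partial(
    tools.MockEnvironment, observ_shape=(2, 3), action_shape=(2,),
    min_duration=5, max_duration=5)
env = tools.wrappers.ExternalProcess(constructor)

observ = env.reset()  # blocking by default
promise = env.step(env.action_space.sample(), blocking=False)
observ, reward, done, info = promise()  # resolve the transition when needed
env.close()
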
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Agents Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Setup script for TensorFlow Agents."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import setuptools
22 |
23 |
24 | setuptools.setup(
25 | name='batch-ppo',
26 | version='1.4.0',
27 | description=(
28 | 'Efficient TensorFlow implementation of ' +
29 | 'Proximal Policy Optimization.'),
30 | license='Apache 2.0',
31 | url='http://github.com/google-research/batch-ppo',
32 | install_requires=[
33 | 'tensorflow',
34 | 'gym',
35 | 'ruamel.yaml',
36 | ],
37 | packages=setuptools.find_packages(),
38 | classifiers=[
39 | 'Programming Language :: Python :: 2',
40 | 'Programming Language :: Python :: 3',
41 | 'License :: OSI Approved :: Apache Software License',
42 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
43 | 'Intended Audience :: Science/Research',
44 | ],
45 | )
46 |
--------------------------------------------------------------------------------
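
The distribution is named batch-ppo, while the importable Python package is
called agents. A minimal check after an editable install (for example with
pip install -e . from the repository root), assuming TensorFlow, gym, and
ruamel.yaml are available:

import agents
from agents import tools

print(agents.__name__)  # 'agents'
print(tools.simulate)   # the in-graph simulation op defined in agents/tools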