├── CONTRIBUTING.md
├── Dockerfile
├── LICENSE
├── README.md
├── batcher.cc
├── dmlab30.py
├── dynamic_batching.py
├── dynamic_batching_test.py
├── environments.py
├── experiment.py
├── py_process.py
├── py_process_test.py
├── vtrace.py
└── vtrace_test.py
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows [Google's Open Source Community
28 | Guidelines](https://opensource.google.com/conduct/).
29 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM ubuntu:18.04
2 |
3 | # Install dependencies.
4 | # g++ (v. 5.4) does not work: https://github.com/tensorflow/tensorflow/issues/13308
5 | RUN apt-get update && apt-get install -y \
6 | curl \
7 | zip \
8 | unzip \
9 | software-properties-common \
10 | pkg-config \
11 | g++-4.8 \
12 | zlib1g-dev \
13 | python \
14 | lua5.1 \
15 | liblua5.1-0-dev \
16 | libffi-dev \
17 | gettext \
18 | freeglut3 \
19 | libsdl2-dev \
20 | libosmesa6-dev \
21 | libglu1-mesa \
22 | libglu1-mesa-dev \
23 | python-dev \
24 | build-essential \
25 | git \
26 | python-setuptools \
27 | python-pip \
28 | libjpeg-dev
29 |
30 | # Install bazel
31 | RUN echo "deb [arch=amd64] http://storage.googleapis.com/bazel-apt stable jdk1.8" | \
32 | tee /etc/apt/sources.list.d/bazel.list && \
33 | curl https://bazel.build/bazel-release.pub.gpg | \
34 | apt-key add - && \
35 | apt-get update && apt-get install -y bazel
36 |
37 | # Install TensorFlow and other dependencies
38 | RUN pip install tensorflow==1.9.0 dm-sonnet==1.23
39 |
40 | # Build and install DeepMind Lab pip package.
41 | # We explicitly set the Numpy path as shown here:
42 | # https://github.com/deepmind/lab/blob/master/docs/users/build.md
43 | RUN NP_INC="$(python -c 'import numpy as np; print(np.get_include()[5:])')" && \
44 | git clone https://github.com/deepmind/lab.git && \
45 | cd lab && \
46 | sed -i 's@hdrs = glob(\[@hdrs = glob(["'"$NP_INC"'/\*\*/*.h", @g' python.BUILD && \
47 | sed -i 's@includes = \[@includes = ["'"$NP_INC"'", @g' python.BUILD && \
48 | bazel build -c opt python/pip_package:build_pip_package && \
49 | pip install wheel && \
50 | ./bazel-bin/python/pip_package/build_pip_package /tmp/dmlab_pkg && \
51 | pip install /tmp/dmlab_pkg/DeepMind_Lab-1.0-py2-none-any.whl --force-reinstall
52 |
53 | # Install dataset (from https://github.com/deepmind/lab/tree/master/data/brady_konkle_oliva2008)
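# (The pipeline below extracts the ```sh block from the dataset README: tr folds the
# file onto one line, sed keeps only the text between ```sh and ```, tr restores the
# newlines, and the extracted commands are piped to bash.)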
54 | RUN mkdir dataset && \
55 | cd dataset && \
56 | pip install Pillow && \
57 | curl -sS https://raw.githubusercontent.com/deepmind/lab/master/data/brady_konkle_oliva2008/README.md | \
58 | tr '\n' '\r' | \
59 | sed -e 's/.*```sh\(.*\)```.*/\1/' | \
60 | tr '\r' '\n' | \
61 | bash
62 |
63 | # Clone.
64 | RUN git clone https://github.com/deepmind/scalable_agent.git
65 | WORKDIR scalable_agent
66 |
67 | # Build dynamic batching module.
68 | RUN TF_INC="$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_include())')" && \
69 | TF_LIB="$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_lib())')" && \
70 | g++-4.8 -std=c++11 -shared batcher.cc -o batcher.so -fPIC -I $TF_INC -O2 -D_GLIBCXX_USE_CXX11_ABI=0 -L$TF_LIB -ltensorflow_framework
71 |
72 | # Run tests.
73 | RUN python py_process_test.py
74 | RUN python dynamic_batching_test.py
75 | RUN python vtrace_test.py
76 |
77 | # Run.
78 | CMD ["sh", "-c", "python experiment.py --total_environment_frames=10000 --dataset_path=../dataset && python experiment.py --mode=test --test_num_episodes=5"]
79 |
80 | # Docker commands:
81 | # docker rm scalable_agent -v
82 | # docker build -t scalable_agent .
83 | # docker run --name scalable_agent scalable_agent
84 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures
2 |
3 | This repository contains an implementation of "Importance Weighted Actor-Learner
4 | Architectures", along with a *dynamic batching* module. This is not an
5 | officially supported Google product.
6 |
7 | For a detailed description of the architecture please read [our paper][arxiv].
8 | Please cite the paper if you use the code from this repository in your work.
9 |
10 | ### Bibtex
11 |
12 | ```
13 | @inproceedings{impala2018,
14 | title={IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures},
15 | author={Espeholt, Lasse and Soyer, Hubert and Munos, Remi and Simonyan, Karen and Mnih, Volodymir and Ward, Tom and Doron, Yotam and Firoiu, Vlad and Harley, Tim and Dunning, Iain and others},
16 | booktitle={Proceedings of the International Conference on Machine Learning (ICML)},
17 | year={2018}
18 | }
19 | ```
20 |
21 | ## Running the Code
22 |
23 | ### Prerequisites
24 |
25 | [TensorFlow][tensorflow] >=1.9.0-dev20180530, the environment
26 | [DeepMind Lab][deepmind_lab] and the neural network library
27 | [DeepMind Sonnet][sonnet]. Although we use [DeepMind Lab][deepmind_lab] in this
28 | release, the agent has been successfully applied to other domains such as
29 | [Atari][arxiv], [Street View][learning_nav] and has been modified to
30 | [generate images][generate_images].
31 |
32 | We include a [Dockerfile][dockerfile] that serves as a reference for the
33 | prerequisites and commands needed to run the code.
34 |
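As a minimal sketch, the image can be built and run with the Docker commands
listed at the bottom of the [Dockerfile][dockerfile]:

```sh
docker build -t scalable_agent .
docker run --name scalable_agent scalable_agent
```
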
35 | ### Single Machine Training on a Single Level
36 |
37 | Training on `explore_goal_locations_small`. Most runs should end up with average
38 | episode returns between 200 and 250 after 1B frames.
39 |
40 | ```sh
41 | python experiment.py --num_actors=48 --batch_size=32
42 | ```
43 |
44 | Adjust the number of actors (i.e. number of environments) and batch size to
45 | match the size of the machine it runs on. A single actor, including DeepMind
46 | Lab, requires a few hundred MB of RAM.
47 |
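For example, a small desktop machine might use something like the following
(illustrative values only, not tuned):

```sh
python experiment.py --num_actors=8 --batch_size=8
```
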
48 | ### Distributed Training on DMLab-30
49 |
50 | Training on the full [DMLab-30][dmlab30]. Across 10 runs with different seeds
51 | (`--seed=[seed]`) but identical hyperparameters, we observed a capped human
52 | normalized training score between 45 and 50. Test scores are usually ~2%
53 | (absolute) lower.
54 |
55 | #### Learner
56 |
57 | ```sh
58 | python experiment.py --job_name=learner --task=0 --num_actors=150 \
59 | --level_name=dmlab30 --batch_size=32 --entropy_cost=0.0033391318945337044 \
60 | --learning_rate=0.00031866995608948655 \
61 | --total_environment_frames=10000000000 --reward_clipping=soft_asymmetric
62 | ```
63 |
64 | #### Actor(s)
65 |
66 | ```sh
67 | for i in $(seq 0 149); do
68 | python experiment.py --job_name=actor --task=$i \
69 | --num_actors=150 --level_name=dmlab30 --dataset_path=[...] &
70 | done;
71 | wait
72 | ```
73 |
74 | #### Test Score
75 |
76 | ```sh
77 | python experiment.py --mode=test --level_name=dmlab30 --dataset_path=[...] \
78 | --test_num_episodes=10
79 | ```
80 |
81 | [arxiv]: https://arxiv.org/abs/1802.01561
82 | [deepmind_lab]: https://github.com/deepmind/lab
83 | [sonnet]: https://github.com/deepmind/sonnet
84 | [learning_nav]: https://arxiv.org/abs/1804.00168
85 | [generate_images]: https://deepmind.com/blog/learning-to-generate-images/
86 | [tensorflow]: https://github.com/tensorflow/tensorflow
87 | [dockerfile]: Dockerfile
88 | [dmlab30]: https://github.com/deepmind/lab/tree/master/game_scripts/levels/contributed/dmlab30
89 |
--------------------------------------------------------------------------------
/batcher.cc:
--------------------------------------------------------------------------------
1 | // Copyright 2018 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | // TensorFlow operations for dynamic batching.
16 |
17 | #include <algorithm>
18 | #include <chrono>
19 | #include <deque>
20 | #include <vector>
21 |
22 | #include "tensorflow/core/framework/op.h"
23 | #include "tensorflow/core/framework/op_kernel.h"
24 | #include "tensorflow/core/framework/resource_op_kernel.h"
25 | #include "tensorflow/core/framework/shape_inference.h"
26 | #include "tensorflow/core/lib/gtl/flatmap.h"
27 | #include "tensorflow/core/lib/gtl/optional.h"
28 | #include "tensorflow/core/lib/strings/strcat.h"
29 | #include "tensorflow/core/util/batch_util.h"
30 | #include "tensorflow/core/util/work_sharder.h"
31 |
32 | namespace tensorflow {
33 | namespace {
34 |
35 | REGISTER_OP("Batcher")
36 | .Output("handle: resource")
37 | .Attr("minimum_batch_size: int")
38 | .Attr("maximum_batch_size: int")
39 | .Attr("timeout_ms: int")
40 | .Attr("container: string = ''")
41 | .Attr("shared_name: string = ''")
42 | .SetIsStateful()
43 | .SetShapeFn(shape_inference::ScalarShape)
44 | .Doc(R"doc(
45 | A Batcher which batches up computations into the same batch.
46 | )doc");
47 |
48 | REGISTER_OP("BatcherCompute")
49 | .Input("handle: resource")
50 | .Input("input_list: Tinput_list")
51 | .Attr("Tinput_list: list(type) >= 1")
52 | .Attr("Toutput_list: list(type) >= 1")
53 | .Output("output_list: Toutput_list")
54 | .SetShapeFn(shape_inference::UnknownShape)
55 | .Doc(R"doc(
56 | Puts the input into the computation queue, waits and returns the result.
57 | )doc");
58 |
59 | REGISTER_OP("BatcherGetInputs")
60 | .Input("handle: resource")
61 | .Attr("Toutput_list: list(type) >= 1")
62 | .Output("output_list: Toutput_list")
63 | .Output("computation_id: int64")
64 | .SetShapeFn([](shape_inference::InferenceContext* c) {
65 | for (int i = 0; i < c->num_outputs() - 1; ++i) {
66 | c->set_output(i, c->UnknownShape());
67 | }
68 | return c->set_output("computation_id", {c->Scalar()});
69 | })
70 | .Doc(R"doc(
71 | Gets a batch of inputs to compute the results of.
72 | )doc");
73 |
74 | REGISTER_OP("BatcherSetOutputs")
75 | .Input("handle: resource")
76 | .Input("input_list: Tinput_list")
77 | .Input("computation_id: int64")
78 | .Attr("Tinput_list: list(type) >= 1")
79 | .SetShapeFn(shape_inference::UnknownShape)
80 | .Doc(R"doc(
81 | Sets the outputs of a batch for the function.
82 | )doc");
83 |
84 | REGISTER_OP("BatcherClose")
85 | .Input("handle: resource")
86 | .SetShapeFn(shape_inference::NoOutputs)
87 | .Doc(R"doc(
88 | Closes the batcher and cancels all pending batcher operations.
89 | )doc");
90 |
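// How these ops fit together (as wired up in dynamic_batching.py): each caller
// issues a BatcherCompute with batch-size-1 inputs and blocks until its result
// is ready; a background QueueRunner repeatedly runs BatcherGetInputs to obtain
// a stacked batch of pending inputs plus a computation_id, evaluates the batched
// function, and runs BatcherSetOutputs with the batched result and the same
// computation_id, which completes the waiting BatcherCompute ops row by row.
// BatcherClose cancels anything still pending.
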
91 | class Batcher : public ResourceBase {
92 | public:
93 | using DoneCallback = AsyncOpKernel::DoneCallback;
94 |
95 | Batcher(int32 minimum_batch_size, int32 maximum_batch_size,
96 |           gtl::optional<std::chrono::milliseconds> timeout)
97 | : ResourceBase(),
98 | curr_computation_id_(0),
99 | is_closed_(false),
100 | minimum_batch_size_(minimum_batch_size),
101 | maximum_batch_size_(maximum_batch_size),
102 | timeout_(std::move(timeout)) {}
103 |
104 | string DebugString() override {
105 | mutex_lock l(mu_);
106 | return strings::StrCat("Batcher with ", inputs_.size(), " waiting inputs.");
107 | }
108 |
109 | void Compute(OpKernelContext* context, const OpInputList& input_list,
110 | DoneCallback callback);
111 |
112 | void GetInputs(OpKernelContext* context, OpOutputList* output_list);
113 |
114 | void SetOutputs(OpKernelContext* context, const OpInputList& input_list,
115 | int64 computation_id);
116 |
117 | void Close(OpKernelContext* context);
118 |
119 | private:
120 | class Input {
121 | public:
122 | Input(OpKernelContext* context, const OpInputList& input_list,
123 | DoneCallback callback)
124 | : context_(context),
125 | input_list_(input_list),
126 | callback_(std::move(callback)) {}
127 |
128 | // Moveable but not copyable.
129 | Input(Input&& rhs)
130 | : Input(rhs.context_, rhs.input_list_, std::move(rhs.callback_)) {
131 | rhs.context_ = nullptr; // Mark invalid.
132 | }
133 |
134 | Input& operator=(Input&& rhs) {
135 | this->context_ = rhs.context_;
136 | this->input_list_ = rhs.input_list_;
137 | this->callback_ = std::move(rhs.callback_);
138 | rhs.context_ = nullptr; // Mark invalid.
139 | return *this;
140 | }
141 |
142 | OpKernelContext* context() const {
143 | CHECK(is_valid());
144 | return context_;
145 | }
146 |
147 | const OpInputList& input_list() const {
148 | CHECK(is_valid());
149 | return input_list_;
150 | }
151 |
152 | bool is_valid() const { return context_ != nullptr; }
153 |
154 | void Done() {
155 | CHECK(is_valid());
156 |
157 |     // After the callback is called, context_, input_list_ and callback_ become
158 |     // invalid and shouldn't be used.
159 | context_ = nullptr;
160 | callback_();
161 | }
162 |
163 | private:
164 | // Not owned.
165 | OpKernelContext* context_;
166 | OpInputList input_list_;
167 | DoneCallback callback_;
168 | };
169 |
170 | void CancelInput(Input* input) EXCLUSIVE_LOCKS_REQUIRED(mu_);
171 |
172 | void GetInputsInternal(OpKernelContext* context, OpOutputList* output_list)
173 | EXCLUSIVE_LOCKS_REQUIRED(mu_);
174 |
175 | void SetOutputsInternal(OpKernelContext* context,
176 | const OpInputList& input_list, int64 computation_id)
177 | EXCLUSIVE_LOCKS_REQUIRED(mu_);
178 |
179 | // Cancels all pending Compute ops and marks the batcher closed.
180 | void CancelAndClose(OpKernelContext* context) EXCLUSIVE_LOCKS_REQUIRED(mu_);
181 |
182 | mutex mu_;
183 | condition_variable full_batch_or_cancelled_cond_var_;
184 |
185 |   // A counter of all batched computations that have been started; used to
186 |   // create a unique id for each batched computation.
187 | int64 curr_computation_id_ GUARDED_BY(mu_);
188 |
189 | // Inputs waiting to be computed.
190 |   std::deque<Input> inputs_ GUARDED_BY(mu_);
191 |
192 | // Batches that are currently being computed. Maps computation_id to a batch
193 | // of inputs.
194 |   gtl::FlatMap<int64, std::vector<Input>> being_computed_ GUARDED_BY(mu_);
195 |
196 | // Whether the Batcher has been closed (happens when there is an error or
197 | // Close() has been called.)
198 | bool is_closed_ GUARDED_BY(mu_);
199 | const int32 minimum_batch_size_;
200 | const int32 maximum_batch_size_;
201 |   const gtl::optional<std::chrono::milliseconds> timeout_;
202 |
203 | TF_DISALLOW_COPY_AND_ASSIGN(Batcher);
204 | };
205 |
206 | void Batcher::Compute(OpKernelContext* context, const OpInputList& input_list,
207 | DoneCallback callback) {
208 | bool should_notify;
209 |
210 | {
211 | mutex_lock l(mu_);
212 |
213 | OP_REQUIRES_ASYNC(context, !is_closed_,
214 | errors::Cancelled("Batcher is closed"), callback);
215 |
216 | // Add the inputs to the list of inputs.
217 | inputs_.emplace_back(context, input_list, std::move(callback));
218 |
219 | should_notify = inputs_.size() >= minimum_batch_size_;
220 | }
221 |
222 | if (should_notify) {
223 | // If a GetInputs operation is blocked, wake it up.
224 | full_batch_or_cancelled_cond_var_.notify_one();
225 | }
226 | }
227 |
228 | void Batcher::GetInputs(OpKernelContext* context, OpOutputList* output_list) {
229 | CancellationManager* cm = context->cancellation_manager();
230 | CancellationToken token = cm->get_cancellation_token();
231 |
232 | bool is_cancelled_or_cancelling = !cm->RegisterCallback(
233 | token, [this]() { full_batch_or_cancelled_cond_var_.notify_all(); });
234 |
235 | mutex_lock l(mu_);
236 | std::cv_status status = std::cv_status::no_timeout;
237 |
238 |   // Wait while the input list has fewer samples than `minimum_batch_size` (or
239 |   // is empty, once a timeout has occurred), unless the operation is cancelled
240 |   // or the batcher is closed.
241 | while (((status == std::cv_status::timeout && inputs_.empty()) ||
242 | (status == std::cv_status::no_timeout &&
243 | inputs_.size() < minimum_batch_size_)) &&
244 | !is_cancelled_or_cancelling && !is_closed_) {
245 | // Using a timeout to make sure the operation always completes after a while
246 |     // when there aren't enough samples and for the unlikely case where the
247 | // operation is being cancelled between checking if it has been cancelled
248 | // and calling wait_for().
249 | if (timeout_) {
250 | status = full_batch_or_cancelled_cond_var_.wait_for(l, *timeout_);
251 | } else {
252 | // Timeout is only used to check for cancellation as described in the
253 | // comment above.
254 | full_batch_or_cancelled_cond_var_.wait_for(
255 | l, std::chrono::milliseconds(100));
256 | }
257 | is_cancelled_or_cancelling = cm->IsCancelled();
258 | }
259 |
260 | if (is_closed_) {
261 | context->SetStatus(errors::Cancelled("Batcher is closed"));
262 | } else if (is_cancelled_or_cancelling) {
263 | context->SetStatus(errors::Cancelled("GetInputs operation was cancelled"));
264 | } else {
265 | GetInputsInternal(context, output_list);
266 | }
267 |
268 | if (!context->status().ok()) {
269 | CancelAndClose(context);
270 | }
271 | }
272 |
273 | void Batcher::GetInputsInternal(OpKernelContext* context,
274 | OpOutputList* output_list) {
275 |   int64 batch_size = std::min<int64>(inputs_.size(), maximum_batch_size_);
276 | size_t num_tensors = inputs_.front().input_list().size();
277 |
278 | // Allocate output tensors.
279 |   std::vector<Tensor*> output_tensors(num_tensors);
280 | for (size_t i = 0; i < num_tensors; ++i) {
281 | TensorShape shape = inputs_.front().input_list()[i].shape();
282 | OP_REQUIRES(
283 | context, shape.dim_size(0) == 1,
284 | errors::InvalidArgument("Batcher requires batch size 1 but was ",
285 | shape.dim_size(0)));
286 | shape.set_dim(0, batch_size);
287 |
288 | OP_REQUIRES_OK(context,
289 | output_list->allocate(i, shape, &output_tensors[i]));
290 | }
291 |
292 | auto work = [this, &context, &output_tensors, num_tensors](
293 | int64 start, int64 end) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
294 | for (int64 j = start; j < end; ++j) {
295 | for (size_t i = 0; i < num_tensors; ++i) {
296 | OP_REQUIRES(context,
297 | inputs_[0].input_list()[i].shape() ==
298 | inputs_[j].input_list()[i].shape(),
299 | errors::InvalidArgument(
300 |                       "Shapes of inputs must be equal. Shapes observed: ",
301 | inputs_[0].input_list()[i].shape().DebugString(), ", ",
302 | inputs_[j].input_list()[i].shape().DebugString()));
303 |
304 | OP_REQUIRES_OK(context,
305 | tensorflow::batch_util::CopyElementToSlice(
306 | inputs_[j].input_list()[i], output_tensors[i], j));
307 | }
308 | }
309 | };
310 |
311 | auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
312 | Shard(worker_threads->num_threads, worker_threads->workers, batch_size, 10,
313 | work);
314 |
315 | // New unique computation id.
316 | int64 new_computation_id = curr_computation_id_++;
317 | Tensor* computation_id_t = nullptr;
318 | OP_REQUIRES_OK(context,
319 | context->allocate_output("computation_id", TensorShape({}),
320 | &computation_id_t));
321 |   computation_id_t->scalar<int64>()() = new_computation_id;
322 |
323 | // Move the batch of inputs into a list for the new computation.
324 | auto iter = std::make_move_iterator(inputs_.begin());
325 | being_computed_.emplace(new_computation_id,
326 |                           std::vector<Input>{iter, iter + batch_size});
327 | inputs_.erase(inputs_.begin(), inputs_.begin() + batch_size);
328 | }
329 |
330 | void Batcher::SetOutputs(OpKernelContext* context,
331 | const OpInputList& input_list, int64 computation_id) {
332 | mutex_lock l(mu_);
333 | SetOutputsInternal(context, input_list, computation_id);
334 | if (!context->status().ok()) {
335 | CancelAndClose(context);
336 | }
337 | }
338 |
339 | void Batcher::SetOutputsInternal(OpKernelContext* context,
340 | const OpInputList& input_list,
341 | int64 computation_id) {
342 | OP_REQUIRES(context, !is_closed_, errors::Cancelled("Batcher is closed"));
343 |
344 | auto search = being_computed_.find(computation_id);
345 | OP_REQUIRES(
346 | context, search != being_computed_.end(),
347 | errors::InvalidArgument("Invalid computation id. Id: ", computation_id));
348 | auto& computation_input_list = search->second;
349 | int64 expected_batch_size = computation_input_list.size();
350 |
351 | for (const Tensor& tensor : input_list) {
352 | OP_REQUIRES(
353 | context, tensor.shape().dims() > 0,
354 | errors::InvalidArgument(
355 | "Output shape must have a batch dimension. Shape observed: ",
356 | tensor.shape().DebugString()));
357 | OP_REQUIRES(
358 | context, tensor.shape().dim_size(0) == expected_batch_size,
359 | errors::InvalidArgument("Output shape must have the same batch "
360 | "dimension as the input batch size. Expected: ",
361 | expected_batch_size,
362 | " Observed: ", tensor.shape().dim_size(0)));
363 | }
364 |
365 | auto work = [this, &input_list, &context, &computation_input_list](
366 | int64 start, int64 end) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
367 | for (int64 j = start; j < end; ++j) {
368 | Input& input = computation_input_list[j];
369 |
370 | for (size_t i = 0; i < input_list.size(); ++i) {
371 | TensorShape shape = input_list[i].shape();
372 | shape.set_dim(0, 1);
373 |
374 | Tensor* output_tensor;
375 | OP_REQUIRES_OK(context, input.context()->allocate_output(
376 | i, shape, &output_tensor));
377 |
378 | OP_REQUIRES_OK(context, tensorflow::batch_util::CopySliceToElement(
379 | input_list[i], output_tensor, j));
380 | }
381 |
382 | input.Done();
383 | }
384 | };
385 |
386 | auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
387 | Shard(worker_threads->num_threads, worker_threads->workers,
388 | expected_batch_size, 50000, work);
389 |
390 | being_computed_.erase(computation_id);
391 | }
392 |
393 | void Batcher::Close(OpKernelContext* context) {
394 | {
395 | mutex_lock l(mu_);
396 | CancelAndClose(context);
397 | }
398 |
399 | // Cancel all running GetInputs operations.
400 | full_batch_or_cancelled_cond_var_.notify_all();
401 | }
402 |
403 | void Batcher::CancelInput(Batcher::Input* input) {
404 | // Some may already have had their outputs set and the callback called so
405 | // they should be skipped.
406 | if (!input->is_valid()) {
407 | return;
408 | }
409 |
410 | input->context()->CtxFailure(errors::Cancelled("Compute was cancelled"));
411 | input->Done();
412 | }
413 |
414 | void Batcher::CancelAndClose(OpKernelContext* context) {
415 | // Something went wrong or the batcher was requested to close. All the waiting
416 | // Compute ops should be cancelled.
417 |
418 | if (is_closed_) {
419 | return;
420 | }
421 |
422 | for (auto& input : inputs_) {
423 | CancelInput(&input);
424 | }
425 | for (auto& p : being_computed_) {
426 | for (auto& input : p.second) {
427 | CancelInput(&input);
428 | }
429 | }
430 | is_closed_ = true; // Causes future Compute operations to be cancelled.
431 | }
432 |
433 | class BatcherHandleOp : public ResourceOpKernel<Batcher> {
434 | public:
435 | explicit BatcherHandleOp(OpKernelConstruction* context)
436 |       : ResourceOpKernel<Batcher>(context) {
437 | OP_REQUIRES_OK(
438 | context, context->GetAttr("minimum_batch_size", &minimum_batch_size_));
439 | OP_REQUIRES_OK(
440 | context, context->GetAttr("maximum_batch_size", &maximum_batch_size_));
441 | OP_REQUIRES_OK(context, context->GetAttr("timeout_ms", &timeout_ms_));
442 | }
443 |
444 | private:
445 | Status CreateResource(Batcher** ret) override EXCLUSIVE_LOCKS_REQUIRED(mu_) {
446 |     gtl::optional<std::chrono::milliseconds> timeout;
447 | if (timeout_ms_ != -1) {
448 | timeout = std::chrono::milliseconds(timeout_ms_);
449 | }
450 | *ret = new Batcher(minimum_batch_size_, maximum_batch_size_, timeout);
451 | return Status::OK();
452 | }
453 |
454 | int32 minimum_batch_size_;
455 | int32 maximum_batch_size_;
456 | int32 timeout_ms_;
457 |
458 | TF_DISALLOW_COPY_AND_ASSIGN(BatcherHandleOp);
459 | };
460 |
461 | class ComputeOp : public AsyncOpKernel {
462 | public:
463 | explicit ComputeOp(OpKernelConstruction* context) : AsyncOpKernel(context) {}
464 |
465 | void ComputeAsync(OpKernelContext* context, DoneCallback callback) override {
466 | Batcher* batcher;
467 | OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
468 | &batcher));
469 |
470 | OpInputList input_list;
471 | OP_REQUIRES_OK(context, context->input_list("input_list", &input_list));
472 |
473 | batcher->Compute(context, input_list, std::move(callback));
474 | }
475 |
476 | private:
477 | TF_DISALLOW_COPY_AND_ASSIGN(ComputeOp);
478 | };
479 |
480 | class GetInputsOp : public OpKernel {
481 | public:
482 | explicit GetInputsOp(OpKernelConstruction* context) : OpKernel(context) {}
483 |
484 | void Compute(OpKernelContext* context) override {
485 | Batcher* batcher;
486 | OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
487 | &batcher));
488 |
489 | OpOutputList output_list;
490 | OP_REQUIRES_OK(context, context->output_list("output_list", &output_list));
491 |
492 | batcher->GetInputs(context, &output_list);
493 | }
494 |
495 | private:
496 | TF_DISALLOW_COPY_AND_ASSIGN(GetInputsOp);
497 | };
498 |
499 | class SetOutputsOp : public OpKernel {
500 | public:
501 | explicit SetOutputsOp(OpKernelConstruction* context) : OpKernel(context) {}
502 |
503 | void Compute(OpKernelContext* context) override {
504 | Batcher* batcher;
505 | OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
506 | &batcher));
507 |
508 | OpInputList input_list;
509 | OP_REQUIRES_OK(context, context->input_list("input_list", &input_list));
510 |
511 | const Tensor* computation_id;
512 | OP_REQUIRES_OK(context, context->input("computation_id", &computation_id));
513 |
514 |     batcher->SetOutputs(context, input_list, computation_id->scalar<int64>()());
515 | }
516 |
517 | private:
518 | TF_DISALLOW_COPY_AND_ASSIGN(SetOutputsOp);
519 | };
520 |
521 | class CloseOp : public OpKernel {
522 | public:
523 | explicit CloseOp(OpKernelConstruction* context) : OpKernel(context) {}
524 |
525 | void Compute(OpKernelContext* context) override {
526 | Batcher* batcher;
527 | OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
528 | &batcher));
529 |
530 | batcher->Close(context);
531 | }
532 |
533 | private:
534 | TF_DISALLOW_COPY_AND_ASSIGN(CloseOp);
535 | };
536 |
537 | REGISTER_KERNEL_BUILDER(Name("Batcher").Device(DEVICE_CPU), BatcherHandleOp);
538 |
539 | REGISTER_KERNEL_BUILDER(Name("BatcherCompute").Device(DEVICE_CPU), ComputeOp);
540 |
541 | REGISTER_KERNEL_BUILDER(Name("BatcherGetInputs").Device(DEVICE_CPU),
542 | GetInputsOp);
543 |
544 | REGISTER_KERNEL_BUILDER(Name("BatcherSetOutputs").Device(DEVICE_CPU),
545 | SetOutputsOp);
546 |
547 | REGISTER_KERNEL_BUILDER(Name("BatcherClose").Device(DEVICE_CPU), CloseOp);
548 |
549 | } // namespace
550 | } // namespace tensorflow
551 |
--------------------------------------------------------------------------------
/dmlab30.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utilities for DMLab-30."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 |
23 | import numpy as np
24 | import tensorflow as tf
25 |
26 |
27 | LEVEL_MAPPING = collections.OrderedDict([
28 | ('rooms_collect_good_objects_train', 'rooms_collect_good_objects_test'),
29 | ('rooms_exploit_deferred_effects_train',
30 | 'rooms_exploit_deferred_effects_test'),
31 | ('rooms_select_nonmatching_object', 'rooms_select_nonmatching_object'),
32 | ('rooms_watermaze', 'rooms_watermaze'),
33 | ('rooms_keys_doors_puzzle', 'rooms_keys_doors_puzzle'),
34 | ('language_select_described_object', 'language_select_described_object'),
35 | ('language_select_located_object', 'language_select_located_object'),
36 | ('language_execute_random_task', 'language_execute_random_task'),
37 | ('language_answer_quantitative_question',
38 | 'language_answer_quantitative_question'),
39 | ('lasertag_one_opponent_small', 'lasertag_one_opponent_small'),
40 | ('lasertag_three_opponents_small', 'lasertag_three_opponents_small'),
41 | ('lasertag_one_opponent_large', 'lasertag_one_opponent_large'),
42 | ('lasertag_three_opponents_large', 'lasertag_three_opponents_large'),
43 | ('natlab_fixed_large_map', 'natlab_fixed_large_map'),
44 | ('natlab_varying_map_regrowth', 'natlab_varying_map_regrowth'),
45 | ('natlab_varying_map_randomized', 'natlab_varying_map_randomized'),
46 | ('skymaze_irreversible_path_hard', 'skymaze_irreversible_path_hard'),
47 | ('skymaze_irreversible_path_varied', 'skymaze_irreversible_path_varied'),
48 | ('psychlab_arbitrary_visuomotor_mapping',
49 | 'psychlab_arbitrary_visuomotor_mapping'),
50 | ('psychlab_continuous_recognition', 'psychlab_continuous_recognition'),
51 | ('psychlab_sequential_comparison', 'psychlab_sequential_comparison'),
52 | ('psychlab_visual_search', 'psychlab_visual_search'),
53 | ('explore_object_locations_small', 'explore_object_locations_small'),
54 | ('explore_object_locations_large', 'explore_object_locations_large'),
55 | ('explore_obstructed_goals_small', 'explore_obstructed_goals_small'),
56 | ('explore_obstructed_goals_large', 'explore_obstructed_goals_large'),
57 | ('explore_goal_locations_small', 'explore_goal_locations_small'),
58 | ('explore_goal_locations_large', 'explore_goal_locations_large'),
59 | ('explore_object_rewards_few', 'explore_object_rewards_few'),
60 | ('explore_object_rewards_many', 'explore_object_rewards_many'),
61 | ])
62 |
63 | HUMAN_SCORES = {
64 | 'rooms_collect_good_objects_test': 10,
65 | 'rooms_exploit_deferred_effects_test': 85.65,
66 | 'rooms_select_nonmatching_object': 65.9,
67 | 'rooms_watermaze': 54,
68 | 'rooms_keys_doors_puzzle': 53.8,
69 | 'language_select_described_object': 389.5,
70 | 'language_select_located_object': 280.7,
71 | 'language_execute_random_task': 254.05,
72 | 'language_answer_quantitative_question': 184.5,
73 | 'lasertag_one_opponent_small': 12.65,
74 | 'lasertag_three_opponents_small': 18.55,
75 | 'lasertag_one_opponent_large': 18.6,
76 | 'lasertag_three_opponents_large': 31.5,
77 | 'natlab_fixed_large_map': 36.9,
78 | 'natlab_varying_map_regrowth': 24.45,
79 | 'natlab_varying_map_randomized': 42.35,
80 | 'skymaze_irreversible_path_hard': 100,
81 | 'skymaze_irreversible_path_varied': 100,
82 | 'psychlab_arbitrary_visuomotor_mapping': 58.75,
83 | 'psychlab_continuous_recognition': 58.3,
84 | 'psychlab_sequential_comparison': 39.5,
85 | 'psychlab_visual_search': 78.5,
86 | 'explore_object_locations_small': 74.45,
87 | 'explore_object_locations_large': 65.65,
88 | 'explore_obstructed_goals_small': 206,
89 | 'explore_obstructed_goals_large': 119.5,
90 | 'explore_goal_locations_small': 267.5,
91 | 'explore_goal_locations_large': 194.5,
92 | 'explore_object_rewards_few': 77.7,
93 | 'explore_object_rewards_many': 106.7,
94 | }
95 |
96 | RANDOM_SCORES = {
97 | 'rooms_collect_good_objects_test': 0.073,
98 | 'rooms_exploit_deferred_effects_test': 8.501,
99 | 'rooms_select_nonmatching_object': 0.312,
100 | 'rooms_watermaze': 4.065,
101 | 'rooms_keys_doors_puzzle': 4.135,
102 | 'language_select_described_object': -0.07,
103 | 'language_select_located_object': 1.929,
104 | 'language_execute_random_task': -5.913,
105 | 'language_answer_quantitative_question': -0.33,
106 | 'lasertag_one_opponent_small': -0.224,
107 | 'lasertag_three_opponents_small': -0.214,
108 | 'lasertag_one_opponent_large': -0.083,
109 | 'lasertag_three_opponents_large': -0.102,
110 | 'natlab_fixed_large_map': 2.173,
111 | 'natlab_varying_map_regrowth': 2.989,
112 | 'natlab_varying_map_randomized': 7.346,
113 | 'skymaze_irreversible_path_hard': 0.1,
114 | 'skymaze_irreversible_path_varied': 14.4,
115 | 'psychlab_arbitrary_visuomotor_mapping': 0.163,
116 | 'psychlab_continuous_recognition': 0.224,
117 | 'psychlab_sequential_comparison': 0.129,
118 | 'psychlab_visual_search': 0.085,
119 | 'explore_object_locations_small': 3.575,
120 | 'explore_object_locations_large': 4.673,
121 | 'explore_obstructed_goals_small': 6.76,
122 | 'explore_obstructed_goals_large': 2.61,
123 | 'explore_goal_locations_small': 7.66,
124 | 'explore_goal_locations_large': 3.14,
125 | 'explore_object_rewards_few': 2.073,
126 | 'explore_object_rewards_many': 2.438,
127 | }
128 |
129 | ALL_LEVELS = frozenset([
130 | 'rooms_collect_good_objects_train',
131 | 'rooms_collect_good_objects_test',
132 | 'rooms_exploit_deferred_effects_train',
133 | 'rooms_exploit_deferred_effects_test',
134 | 'rooms_select_nonmatching_object',
135 | 'rooms_watermaze',
136 | 'rooms_keys_doors_puzzle',
137 | 'language_select_described_object',
138 | 'language_select_located_object',
139 | 'language_execute_random_task',
140 | 'language_answer_quantitative_question',
141 | 'lasertag_one_opponent_small',
142 | 'lasertag_three_opponents_small',
143 | 'lasertag_one_opponent_large',
144 | 'lasertag_three_opponents_large',
145 | 'natlab_fixed_large_map',
146 | 'natlab_varying_map_regrowth',
147 | 'natlab_varying_map_randomized',
148 | 'skymaze_irreversible_path_hard',
149 | 'skymaze_irreversible_path_varied',
150 | 'psychlab_arbitrary_visuomotor_mapping',
151 | 'psychlab_continuous_recognition',
152 | 'psychlab_sequential_comparison',
153 | 'psychlab_visual_search',
154 | 'explore_object_locations_small',
155 | 'explore_object_locations_large',
156 | 'explore_obstructed_goals_small',
157 | 'explore_obstructed_goals_large',
158 | 'explore_goal_locations_small',
159 | 'explore_goal_locations_large',
160 | 'explore_object_rewards_few',
161 | 'explore_object_rewards_many',
162 | ])
163 |
164 |
165 | def _transform_level_returns(level_returns):
166 | """Converts training level names to test level names."""
167 | new_level_returns = {}
168 | for level_name, returns in level_returns.iteritems():
169 | new_level_returns[LEVEL_MAPPING.get(level_name, level_name)] = returns
170 |
171 | test_set = set(LEVEL_MAPPING.values())
172 | diff = test_set - set(new_level_returns.keys())
173 | if diff:
174 | raise ValueError('Missing levels: %s' % list(diff))
175 |
176 | for level_name, returns in new_level_returns.iteritems():
177 | if level_name in test_set:
178 | if not returns:
179 | raise ValueError('Missing returns for level: \'%s\': ' % level_name)
180 | else:
181 | tf.logging.info('Skipping level %s for calculation.', level_name)
182 |
183 | return new_level_returns
184 |
185 |
186 | def compute_human_normalized_score(level_returns, per_level_cap):
187 | """Computes human normalized score.
188 |
189 |   Levels that have different training and test versions will use the returns
190 | for the training level to calculate the score. E.g.
191 | 'rooms_collect_good_objects_train' will be used for
192 |   'rooms_collect_good_objects_test'. All returns for levels not in DMLab-30
193 | will be ignored.
194 |
195 | Args:
196 | level_returns: A dictionary from level to list of episode returns.
197 | per_level_cap: A percentage cap (e.g. 100.) on the per level human
198 | normalized score. If None, no cap is applied.
199 |
200 | Returns:
201 | A float with the human normalized score in percentage.
202 |
203 | Raises:
204 | ValueError: If a level is missing from `level_returns` or has no returns.
205 | """
206 | new_level_returns = _transform_level_returns(level_returns)
207 |
208 | def human_normalized_score(level_name, returns):
209 | score = np.mean(returns)
210 | human = HUMAN_SCORES[level_name]
211 | random = RANDOM_SCORES[level_name]
212 | human_normalized_score = (score - random) / (human - random) * 100
213 | if per_level_cap is not None:
214 | human_normalized_score = min(human_normalized_score, per_level_cap)
215 | return human_normalized_score
216 |
217 | return np.mean(
218 | [human_normalized_score(k, v) for k, v in new_level_returns.items()])
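
# A minimal usage sketch (hypothetical returns, for illustration only). Every
# level in LEVEL_MAPPING needs at least one episode return; each level's score is
# (mean(returns) - random) / (human - random) * 100, optionally capped, and the
# final score is the mean over the 30 test levels:
#
#   level_returns = {name: [10.0] for name in LEVEL_MAPPING}
#   score = compute_human_normalized_score(level_returns, per_level_cap=100.)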
219 |
--------------------------------------------------------------------------------
/dynamic_batching.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Dynamic batching."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import functools
22 |
23 | import tensorflow as tf
24 |
25 | batcher_ops = tf.load_op_library('./batcher.so')
26 |
27 | nest = tf.contrib.framework.nest
28 |
29 |
30 | class _Batcher(object):
31 | """A thin layer around the Batcher TensorFlow operations.
32 |
33 | It shares some of the interface with queues (close(), name) to be able to use
34 | it correctly as the input to a QueueRunner.
35 | """
36 |
37 | def __init__(self, minimum_batch_size, maximum_batch_size, timeout_ms):
38 | self._handle = batcher_ops.batcher(minimum_batch_size, maximum_batch_size,
39 | timeout_ms or -1)
40 |
41 | @property
42 | def name(self):
43 | return 'batcher'
44 |
45 | def get_inputs(self, input_dtypes):
46 | return batcher_ops.batcher_get_inputs(self._handle, input_dtypes)
47 |
48 | def set_outputs(self, flat_result, computation_id):
49 | return batcher_ops.batcher_set_outputs(self._handle, flat_result,
50 | computation_id)
51 |
52 | def compute(self, flat_args, output_dtypes):
53 | return batcher_ops.batcher_compute(self._handle, flat_args, output_dtypes)
54 |
55 | def close(self, cancel_pending_enqueues=False, name=None):
56 | del cancel_pending_enqueues
57 | return batcher_ops.batcher_close(self._handle, name=name)
58 |
59 |
60 | def batch_fn(f):
61 | """See `batch_fn_with_options` for details."""
62 | return batch_fn_with_options()(f)
63 |
64 |
65 | def batch_fn_with_options(minimum_batch_size=1, maximum_batch_size=1024,
66 | timeout_ms=100):
67 | """Python decorator that automatically batches computations.
68 |
69 | When the decorated function is called, it creates an operation that adds the
70 | inputs to a queue, waits until the computation is done, and returns the
71 | tensors. The inputs must be nests (see `tf.contrib.framework.nest`) and the
72 | first dimension of each tensor in the nest must have size 1.
73 |
74 | It adds a QueueRunner that asynchronously keeps fetching batches of data,
75 | computes the results and pushes the results back to the caller.
76 |
77 | Example usage:
78 |
79 | @dynamic_batching.batch_fn_with_options(
80 | minimum_batch_size=10, timeout_ms=100)
81 | def fn(a, b):
82 | return a + b
83 |
84 | output0 = fn(tf.constant([1]), tf.constant([2])) # Will be batched with the
85 | # next call.
86 | output1 = fn(tf.constant([3]), tf.constant([4]))
87 |
88 |   Note: gradients are currently not supported.
89 |   Note: if minimum_batch_size == maximum_batch_size and timeout_ms is None, then
90 |   the batch size of the input arguments will be set statically. Otherwise, it
91 |   will be None.
92 |
93 | Args:
94 | minimum_batch_size: The minimum batch size before processing starts.
95 | maximum_batch_size: The maximum batch size.
96 | timeout_ms: Milliseconds after a batch of samples is requested before it is
97 | processed, even if the batch size is smaller than `minimum_batch_size`. If
98 | None, there is no timeout.
99 |
100 | Returns:
101 | The decorator.
102 | """
103 |
104 | def decorator(f):
105 | """Decorator."""
106 | batcher = [None]
107 | batched_output = [None]
108 |
109 | @functools.wraps(f)
110 | def wrapper(*args):
111 | """Wrapper."""
112 |
113 | flat_args = [tf.convert_to_tensor(arg) for arg in nest.flatten(args)]
114 |
115 | if batcher[0] is None:
116 |         # Remove control dependencies, which is necessary when the function is
117 |         # called inside loops, etc.
118 | with tf.control_dependencies(None):
119 | input_dtypes = [t.dtype for t in flat_args]
120 | batcher[0] = _Batcher(minimum_batch_size, maximum_batch_size,
121 | timeout_ms)
122 |
123 | # Compute in batches using a queue runner.
124 |
125 | if minimum_batch_size == maximum_batch_size and timeout_ms is None:
126 | batch_size = minimum_batch_size
127 | else:
128 | batch_size = None
129 |
130 | # Dequeue batched input.
131 | inputs, computation_id = batcher[0].get_inputs(input_dtypes)
132 | nest.map_structure(
133 | lambda i, a: i.set_shape([batch_size] + a.shape.as_list()[1:]),
134 | inputs, flat_args)
135 |
136 | # Compute result.
137 | result = f(*nest.pack_sequence_as(args, inputs))
138 | batched_output[0] = result
139 | flat_result = nest.flatten(result)
140 |
141 | # Insert results back into batcher.
142 | set_op = batcher[0].set_outputs(flat_result, computation_id)
143 |
144 | tf.train.add_queue_runner(tf.train.QueueRunner(batcher[0], [set_op]))
145 |
146 | # Insert inputs into input queue.
147 | flat_result = batcher[0].compute(
148 | flat_args,
149 | [t.dtype for t in nest.flatten(batched_output[0])])
150 |
151 | # Restore structure and shapes.
152 | result = nest.pack_sequence_as(batched_output[0], flat_result)
153 | static_batch_size = nest.flatten(args)[0].shape[0]
154 |
155 | nest.map_structure(
156 | lambda t, b: t.set_shape([static_batch_size] + b.shape[1:].as_list()),
157 | result, batched_output[0])
158 | return result
159 |
160 | return wrapper
161 |
162 | return decorator
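
# A minimal end-to-end sketch (mirrors dynamic_batching_test.py). Results are only
# produced once queue runners have been started, since a QueueRunner drives the
# get_inputs -> f -> set_outputs loop in the background:
#
#   @batch_fn
#   def f(a, b):
#     return a + b
#
#   output = f(tf.constant([1]), tf.constant([2]))
#   with tf.Session() as session:
#     tf.train.start_queue_runners()
#     print(session.run(output))  # [3]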
163 |
--------------------------------------------------------------------------------
/dynamic_batching_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests dynamic_batching.py."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import datetime
22 | from multiprocessing import pool
23 | import time
24 |
25 | import dynamic_batching
26 |
27 | import tensorflow as tf
28 |
29 | from six.moves import range
30 |
31 |
32 | _SLEEP_TIME = 1.0
33 |
34 |
35 | class DynamicBatchingTest(tf.test.TestCase):
36 |
37 | def test_one(self):
38 | with self.test_session() as session:
39 | @dynamic_batching.batch_fn
40 | def f(a, b):
41 | batch_size = tf.shape(a)[0]
42 | return a + b, tf.tile([batch_size], [batch_size])
43 |
44 | output = f(tf.constant([[1, 3]]), tf.constant([2]))
45 |
46 | tf.train.start_queue_runners()
47 |
48 | result, batch_size = session.run(output)
49 |
50 | self.assertAllEqual([[3, 5]], result)
51 | self.assertAllEqual([1], batch_size)
52 |
53 | def test_two(self):
54 | with self.test_session() as session:
55 | @dynamic_batching.batch_fn
56 | def f(a, b):
57 | batch_size = tf.shape(a)[0]
58 | return a + b, tf.tile([batch_size], [batch_size])
59 |
60 | output0 = f(tf.constant([1]), tf.constant([2]))
61 | output1 = f(tf.constant([2]), tf.constant([3]))
62 |
63 | tp = pool.ThreadPool(2)
64 | f0 = tp.apply_async(session.run, [output0])
65 | f1 = tp.apply_async(session.run, [output1])
66 |
67 | # Make sure both inputs are in the batcher before starting it.
68 | time.sleep(_SLEEP_TIME)
69 |
70 | tf.train.start_queue_runners()
71 |
72 | result0, batch_size0 = f0.get()
73 | result1, batch_size1 = f1.get()
74 |
75 | self.assertAllEqual([3], result0)
76 | self.assertAllEqual([2], batch_size0)
77 | self.assertAllEqual([5], result1)
78 | self.assertAllEqual([2], batch_size1)
79 |
80 | def test_many_small(self):
81 | with self.test_session() as session:
82 | @dynamic_batching.batch_fn
83 | def f(a, b):
84 | return a + b
85 |
86 | outputs = []
87 | for i in range(200):
88 | outputs.append(f(tf.fill([1, 5], i), tf.fill([1, 5], i)))
89 |
90 | tf.train.start_queue_runners()
91 |
92 | tp = pool.ThreadPool(10)
93 | futures = []
94 | for output in outputs:
95 | futures.append(tp.apply_async(session.run, [output]))
96 |
97 | for i, future in enumerate(futures):
98 | result = future.get()
99 | self.assertAllEqual([[i * 2] * 5], result)
100 |
101 | def test_input_batch_size_should_be_one(self):
102 | with self.test_session() as session:
103 | @dynamic_batching.batch_fn
104 | def f(a):
105 | return a
106 |
107 | output = f(tf.constant([1, 2]))
108 |
109 | coord = tf.train.Coordinator()
110 | tf.train.start_queue_runners(coord=coord)
111 |
112 | with self.assertRaises(tf.errors.CancelledError):
113 | session.run(output)
114 |
115 | with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
116 | 'requires batch size 1'):
117 | coord.join()
118 |
119 | def test_run_after_error_should_be_cancelled(self):
120 | with self.test_session() as session:
121 |
122 | @dynamic_batching.batch_fn
123 | def f(a):
124 | return a
125 |
126 | output = f(tf.constant([1, 2]))
127 |
128 | coord = tf.train.Coordinator()
129 | tf.train.start_queue_runners(coord=coord)
130 |
131 | with self.assertRaises(tf.errors.CancelledError):
132 | session.run(output)
133 |
134 | with self.assertRaises(tf.errors.CancelledError):
135 | session.run(output)
136 |
137 | def test_input_shapes_should_be_equal(self):
138 | with self.test_session() as session:
139 |
140 | @dynamic_batching.batch_fn
141 | def f(a, b):
142 | return a + b
143 |
144 | output0 = f(tf.constant([1]), tf.constant([2]))
145 | output1 = f(tf.constant([[2]]), tf.constant([3]))
146 |
147 | tp = pool.ThreadPool(2)
148 | f0 = tp.apply_async(session.run, [output0])
149 | f1 = tp.apply_async(session.run, [output1])
150 |
151 | time.sleep(_SLEEP_TIME)
152 |
153 | coord = tf.train.Coordinator()
154 | tf.train.start_queue_runners(coord=coord)
155 |
156 | with self.assertRaises(tf.errors.CancelledError):
157 | f0.get()
158 | f1.get()
159 |
160 | with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
161 | 'Shapes of inputs much be equal'):
162 | coord.join()
163 |
164 | def test_output_must_have_batch_dimension(self):
165 | with self.test_session() as session:
166 | @dynamic_batching.batch_fn
167 | def f(_):
168 | return tf.constant(1)
169 |
170 | output = f(tf.constant([1]))
171 |
172 | coord = tf.train.Coordinator()
173 | tf.train.start_queue_runners(coord=coord)
174 |
175 | with self.assertRaises(tf.errors.CancelledError):
176 | session.run(output)
177 |
178 | with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
179 | 'Output shape must have a batch dimension'):
180 | coord.join()
181 |
182 | def test_output_must_have_same_batch_dimension_size_as_input(self):
183 | with self.test_session() as session:
184 | @dynamic_batching.batch_fn
185 | def f(_):
186 | return tf.constant([1, 2, 3, 4])
187 |
188 | output = f(tf.constant([1]))
189 |
190 | coord = tf.train.Coordinator()
191 | tf.train.start_queue_runners(coord=coord)
192 |
193 | with self.assertRaises(tf.errors.CancelledError):
194 | session.run(output)
195 |
196 | with self.assertRaisesRegexp(
197 | tf.errors.InvalidArgumentError,
198 | 'Output shape must have the same batch dimension as the input batch '
199 | 'size. Expected: 1 Observed: 4'):
200 | coord.join()
201 |
202 | def test_get_inputs_cancelled(self):
203 | with tf.Graph().as_default():
204 |
205 | @dynamic_batching.batch_fn
206 | def f(a):
207 | return a
208 |
209 | f(tf.constant([1]))
210 |
211 | # Intentionally using tf.Session() instead of self.test_session() to have
212 | # control over closing the session. test_session() is a cached session.
213 | with tf.Session():
214 | coord = tf.train.Coordinator()
215 | tf.train.start_queue_runners(coord=coord)
216 | # Sleep to make sure the queue runner has started the first run call.
217 | time.sleep(_SLEEP_TIME)
218 |
219 | # Session closed.
220 | with self.assertRaisesRegexp(tf.errors.CancelledError,
221 | 'GetInputs operation was cancelled'):
222 | coord.join()
223 |
224 | def test_batcher_closed(self):
225 | with tf.Graph().as_default():
226 | @dynamic_batching.batch_fn
227 | def f(a):
228 | return a
229 |
230 | f(tf.constant([1]))
231 |
232 | # Intentionally using tf.Session() instead of self.test_session() to have
233 | # control over closing the session. test_session() is a cached session.
234 | with tf.Session():
235 | coord = tf.train.Coordinator()
236 | tf.train.start_queue_runners(coord=coord)
237 | time.sleep(_SLEEP_TIME)
238 | coord.request_stop() # Calls close operation.
239 | coord.join()
240 | # Session closed.
241 |
242 | def test_minimum_batch_size(self):
243 | with self.test_session() as session:
244 | @dynamic_batching.batch_fn_with_options(
245 | minimum_batch_size=2, timeout_ms=1000)
246 | def f(a, b):
247 | batch_size = tf.shape(a)[0]
248 | return a + b, tf.tile([batch_size], [batch_size])
249 |
250 | output = f(tf.constant([[1, 3]]), tf.constant([2]))
251 |
252 | tf.train.start_queue_runners()
253 |
254 | start = datetime.datetime.now()
255 | session.run(output)
256 | duration = datetime.datetime.now() - start
257 |
258 | # There should have been a timeout here because only one sample was added
259 | # and the minimum batch size is 2.
260 | self.assertLessEqual(.9, duration.total_seconds())
261 | self.assertGreaterEqual(1.5, duration.total_seconds())
262 |
263 | outputs = [
264 | f(tf.constant([[1, 3]]), tf.constant([2])),
265 | f(tf.constant([[1, 3]]), tf.constant([2]))
266 | ]
267 |
268 | start = datetime.datetime.now()
269 | (_, batch_size), _ = session.run(outputs)
270 | duration = datetime.datetime.now() - start
271 |
272 | # The outputs should be executed immediately because two samples are
273 | # added.
274 | self.assertGreaterEqual(.5, duration.total_seconds())
275 | self.assertEqual(2, batch_size)
276 |
277 | def test_maximum_batch_size(self):
278 | with self.test_session() as session:
279 | @dynamic_batching.batch_fn_with_options(maximum_batch_size=2)
280 | def f(a, b):
281 | batch_size = tf.shape(a)[0]
282 | return a + b, tf.tile([batch_size], [batch_size])
283 |
284 | outputs = [
285 | f(tf.constant([1]), tf.constant([2])),
286 | f(tf.constant([1]), tf.constant([2])),
287 | f(tf.constant([1]), tf.constant([2])),
288 | f(tf.constant([1]), tf.constant([2])),
289 | f(tf.constant([1]), tf.constant([2])),
290 | ]
291 |
292 | tf.train.start_queue_runners()
293 |
294 | results = session.run(outputs)
295 |
296 | for value, batch_size in results:
297 | self.assertEqual(3, value)
298 | self.assertGreaterEqual(2, batch_size)
299 |
300 | def test_static_shape(self):
301 | assertions_triggered = [0]
302 |
303 | @dynamic_batching.batch_fn_with_options(minimum_batch_size=1,
304 | maximum_batch_size=2)
305 | def f0(a):
306 | self.assertEqual(None, a.shape[0].value)
307 | assertions_triggered[0] += 1
308 | return a
309 |
310 | @dynamic_batching.batch_fn_with_options(minimum_batch_size=2,
311 | maximum_batch_size=2)
312 | def f1(a):
313 | # Even though minimum_batch_size and maximum_batch_size are equal, the
314 |       # timeout can cause a batch smaller than minimum_batch_size.
315 | self.assertEqual(None, a.shape[0].value)
316 | assertions_triggered[0] += 1
317 | return a
318 |
319 | @dynamic_batching.batch_fn_with_options(minimum_batch_size=2,
320 | maximum_batch_size=2,
321 | timeout_ms=None)
322 | def f2(a):
323 | # When timeout is disabled and minimum/maximum batch size are equal, the
324 | # shape is statically known.
325 | self.assertEqual(2, a.shape[0].value)
326 | assertions_triggered[0] += 1
327 | return a
328 |
329 | f0(tf.constant([1]))
330 | f1(tf.constant([1]))
331 | f2(tf.constant([1]))
332 | self.assertEqual(3, assertions_triggered[0])
333 |
334 | def test_out_of_order_execution1(self):
335 | with self.test_session() as session:
336 | batcher = dynamic_batching._Batcher(minimum_batch_size=1,
337 | maximum_batch_size=1,
338 | timeout_ms=None)
339 |
340 | tp = pool.ThreadPool(10)
341 | r0 = tp.apply_async(session.run, batcher.compute([[1]], [tf.int32]))
342 | (input0,), computation_id0 = session.run(batcher.get_inputs([tf.int32]))
343 | r1 = tp.apply_async(session.run, batcher.compute([[2]], [tf.int32]))
344 | (input1,), computation_id1 = session.run(batcher.get_inputs([tf.int32]))
345 |
346 | self.assertAllEqual([1], input0)
347 | self.assertAllEqual([2], input1)
348 |
349 | session.run(batcher.set_outputs([input0 + 42], computation_id0))
350 | session.run(batcher.set_outputs([input1 + 42], computation_id1))
351 |
352 | self.assertAllEqual([43], r0.get())
353 | self.assertAllEqual([44], r1.get())
354 |
355 | def test_out_of_order_execution2(self):
356 | with self.test_session() as session:
357 | batcher = dynamic_batching._Batcher(minimum_batch_size=1,
358 | maximum_batch_size=1,
359 | timeout_ms=None)
360 |
361 | tp = pool.ThreadPool(10)
362 | r0 = tp.apply_async(session.run, batcher.compute([[1]], [tf.int32]))
363 | (input0,), computation_id0 = session.run(batcher.get_inputs([tf.int32]))
364 | r1 = tp.apply_async(session.run, batcher.compute([[2]], [tf.int32]))
365 | (input1,), computation_id1 = session.run(batcher.get_inputs([tf.int32]))
366 |
367 | self.assertAllEqual([1], input0)
368 | self.assertAllEqual([2], input1)
369 |
370 |       # These two runs are swapped relative to test_out_of_order_execution1.
371 | session.run(batcher.set_outputs([input1 + 42], computation_id1))
372 | session.run(batcher.set_outputs([input0 + 42], computation_id0))
373 |
374 | self.assertAllEqual([43], r0.get())
375 | self.assertAllEqual([44], r1.get())
376 |
377 | def test_invalid_computation_id(self):
378 | with self.test_session() as session:
379 | batcher = dynamic_batching._Batcher(minimum_batch_size=1,
380 | maximum_batch_size=1,
381 | timeout_ms=None)
382 |
383 | tp = pool.ThreadPool(10)
384 | tp.apply_async(session.run, batcher.compute([[1]], [tf.int32]))
385 | (input0,), _ = session.run(batcher.get_inputs([tf.int32]))
386 |
387 | self.assertAllEqual([1], input0)
388 |
389 | with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
390 | 'Invalid computation id'):
391 | session.run(batcher.set_outputs([input0], 42))
392 |
393 | def test_op_shape(self):
394 | with self.test_session():
395 | batcher = dynamic_batching._Batcher(minimum_batch_size=1,
396 | maximum_batch_size=1,
397 | timeout_ms=None)
398 |
399 | _, computation_id = batcher.get_inputs([tf.int32])
400 |
401 | self.assertEqual([], computation_id.shape)
402 |
403 |
404 | class DynamicBatchingBenchmarks(tf.test.Benchmark):
405 |
406 | def benchmark_batching_small(self):
407 | with tf.Session() as session:
408 | @dynamic_batching.batch_fn
409 | def f(a, b):
410 | return a + b
411 |
412 | outputs = []
413 | for _ in range(1000):
414 | outputs.append(f(tf.ones([1, 10]), tf.ones([1, 10])))
415 | op_to_benchmark = tf.group(*outputs)
416 |
417 | tf.train.start_queue_runners()
418 |
419 | self.run_op_benchmark(
420 | name='batching_many_small',
421 | sess=session,
422 | op_or_tensor=op_to_benchmark,
423 | burn_iters=10,
424 | min_iters=50)
425 |
426 | def benchmark_batching_large(self):
427 | with tf.Session() as session:
428 | @dynamic_batching.batch_fn
429 | def f(a, b):
430 | return a + b
431 |
432 | outputs = []
433 | for _ in range(1000):
434 | outputs.append(f(tf.ones([1, 100000]), tf.ones([1, 100000])))
435 | op_to_benchmark = tf.group(*outputs)
436 |
437 | tf.train.start_queue_runners()
438 |
439 | self.run_op_benchmark(
440 | name='batching_many_large',
441 | sess=session,
442 | op_or_tensor=op_to_benchmark,
443 | burn_iters=10,
444 | min_iters=50)
445 |
446 |
447 | if __name__ == '__main__':
448 | tf.test.main()
449 |
--------------------------------------------------------------------------------
/environments.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Environments and environment helper classes."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import os.path
23 |
24 | import numpy as np
25 | import tensorflow as tf
26 |
27 | import deepmind_lab
28 |
29 |
30 | nest = tf.contrib.framework.nest
31 |
32 |
33 | class LocalLevelCache(object):
34 | """Local level cache."""
35 |
36 | def __init__(self, cache_dir='/tmp/level_cache'):
37 | self._cache_dir = cache_dir
38 | tf.gfile.MakeDirs(cache_dir)
39 |
40 | def fetch(self, key, pk3_path):
41 | path = os.path.join(self._cache_dir, key)
42 | if tf.gfile.Exists(path):
43 | tf.gfile.Copy(path, pk3_path, overwrite=True)
44 | return True
45 | return False
46 |
47 | def write(self, key, pk3_path):
48 | path = os.path.join(self._cache_dir, key)
49 | if not tf.gfile.Exists(path):
50 | tf.gfile.Copy(pk3_path, path)
51 |
52 |
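# Each action is a raw DeepMind Lab action tuple. The seven components are
# assumed to follow the standard DeepMind Lab action spec order: look
# left/right, look down/up, strafe left/right, move back/forward, fire, jump
# and crouch.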
53 | DEFAULT_ACTION_SET = (
54 | (0, 0, 0, 1, 0, 0, 0), # Forward
55 | (0, 0, 0, -1, 0, 0, 0), # Backward
56 | (0, 0, -1, 0, 0, 0, 0), # Strafe Left
57 | (0, 0, 1, 0, 0, 0, 0), # Strafe Right
58 | (-20, 0, 0, 0, 0, 0, 0), # Look Left
59 | (20, 0, 0, 0, 0, 0, 0), # Look Right
60 | (-20, 0, 0, 1, 0, 0, 0), # Look Left + Forward
61 | (20, 0, 0, 1, 0, 0, 0), # Look Right + Forward
62 | (0, 0, 0, 0, 1, 0, 0), # Fire.
63 | )
64 |
65 |
66 | class PyProcessDmLab(object):
67 | """DeepMind Lab wrapper for PyProcess."""
68 |
69 | def __init__(self, level, config, num_action_repeats, seed,
70 | runfiles_path=None, level_cache=None):
71 | self._num_action_repeats = num_action_repeats
72 | self._random_state = np.random.RandomState(seed=seed)
73 | if runfiles_path:
74 | deepmind_lab.set_runfiles_path(runfiles_path)
75 | config = {k: str(v) for k, v in config.iteritems()}
76 | self._observation_spec = ['RGB_INTERLEAVED', 'INSTR']
77 | self._env = deepmind_lab.Lab(
78 | level=level,
79 | observations=self._observation_spec,
80 | config=config,
81 | level_cache=level_cache,
82 | )
83 |
84 | def _reset(self):
85 | self._env.reset(seed=self._random_state.randint(0, 2 ** 31 - 1))
86 |
87 | def _observation(self):
88 | d = self._env.observations()
89 | return [d[k] for k in self._observation_spec]
90 |
91 | def initial(self):
92 | self._reset()
93 | return self._observation()
94 |
95 | def step(self, action):
96 | reward = self._env.step(action, num_steps=self._num_action_repeats)
97 | done = np.array(not self._env.is_running())
98 | if done:
99 | self._reset()
100 | observation = self._observation()
101 | reward = np.array(reward, dtype=np.float32)
102 | return reward, done, observation
103 |
104 | def close(self):
105 | self._env.close()
106 |
107 | @staticmethod
108 | def _tensor_specs(method_name, unused_kwargs, constructor_kwargs):
109 | """Returns a nest of `TensorSpec` with the method's output specification."""
110 | width = constructor_kwargs['config'].get('width', 320)
111 | height = constructor_kwargs['config'].get('height', 240)
112 |
113 | observation_spec = [
114 | tf.contrib.framework.TensorSpec([height, width, 3], tf.uint8),
115 | tf.contrib.framework.TensorSpec([], tf.string),
116 | ]
117 |
118 | if method_name == 'initial':
119 | return observation_spec
120 | elif method_name == 'step':
121 | return (
122 | tf.contrib.framework.TensorSpec([], tf.float32),
123 | tf.contrib.framework.TensorSpec([], tf.bool),
124 | observation_spec,
125 | )
126 |
127 |
128 | StepOutputInfo = collections.namedtuple('StepOutputInfo',
129 | 'episode_return episode_step')
130 | StepOutput = collections.namedtuple('StepOutput',
131 | 'reward info done observation')
132 |
133 |
134 | class FlowEnvironment(object):
135 | """An environment that returns a new state for every modifying method.
136 |
137 | The environment returns a new environment state for every modifying action and
138 | forces previous actions to be completed first. Similar to `flow` for
139 | `TensorArray`.
140 | """
141 |
142 | def __init__(self, env):
143 | """Initializes the environment.
144 |
145 | Args:
146 | env: An environment with `initial()` and `step(action)` methods where
147 | `initial` returns the initial observations and `step` takes an action
148 | and returns a tuple of (reward, done, observation). `observation`
149 | should be the observation after the step is taken. If `done` is
150 | True, the observation should be the first observation in the next
151 | episode.
152 | """
153 | self._env = env
154 |
155 | def initial(self):
156 | """Returns the initial output and initial state.
157 |
158 | Returns:
159 | A tuple of (`StepOutput`, environment state). The environment state should
160 | be passed in to the next invocation of `step` and should not be used in
161 |       any other way. The reward and transition type in the `StepOutput` are the
162 |       reward/transition type that led to the observation in `StepOutput`.
163 | """
164 | with tf.name_scope('flow_environment_initial'):
165 | initial_reward = tf.constant(0.)
166 | initial_info = StepOutputInfo(tf.constant(0.), tf.constant(0))
167 | initial_done = tf.constant(True)
168 | initial_observation = self._env.initial()
169 |
170 | initial_output = StepOutput(
171 | initial_reward,
172 | initial_info,
173 | initial_done,
174 | initial_observation)
175 |
176 | # Control dependency to make sure the next step can't be taken before the
177 | # initial output has been read from the environment.
178 | with tf.control_dependencies(nest.flatten(initial_output)):
179 | initial_flow = tf.constant(0, dtype=tf.int64)
180 | initial_state = (initial_flow, initial_info)
181 | return initial_output, initial_state
182 |
183 | def step(self, action, state):
184 | """Takes a step in the environment.
185 |
186 | Args:
187 | action: An action tensor suitable for the underlying environment.
188 | state: The environment state from the last step or initial state.
189 |
190 | Returns:
191 | A tuple of (`StepOutput`, environment state). The environment state should
192 | be passed in to the next invocation of `step` and should not be used in
193 | any other way. On episode end (i.e. `done` is True), the returned reward
194 | should be included in the sum of rewards for the ending episode and not
195 | part of the next episode.
196 | """
197 | with tf.name_scope('flow_environment_step'):
198 | flow, info = nest.map_structure(tf.convert_to_tensor, state)
199 |
200 | # Make sure the previous step has been executed before running the next
201 | # step.
202 | with tf.control_dependencies([flow]):
203 | reward, done, observation = self._env.step(action)
204 |
205 | with tf.control_dependencies(nest.flatten(observation)):
206 | new_flow = tf.add(flow, 1)
207 |
208 | # When done, include the reward in the output info but not in the
209 | # state for the next step.
210 | new_info = StepOutputInfo(info.episode_return + reward,
211 | info.episode_step + 1)
212 | new_state = new_flow, nest.map_structure(
213 | lambda a, b: tf.where(done, a, b),
214 | StepOutputInfo(tf.constant(0.), tf.constant(0)),
215 | new_info)
216 |
217 | output = StepOutput(reward, new_info, done, observation)
218 | return output, new_state
219 |
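# Illustrative sketch (not executed) of threading the flow state through
# successive steps; `make_env()` and `action` are placeholders for an
# underlying environment and an action tensor:
#
#   flow_env = FlowEnvironment(make_env())
#   output0, state0 = flow_env.initial()
#   output1, state1 = flow_env.step(action, state0)
#   output2, state2 = flow_env.step(action, state1)
#
# Each returned state must be fed to the next `step` call; the control
# dependencies inside `step` then force the underlying environment steps to
# run in order.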
--------------------------------------------------------------------------------
/experiment.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Importance Weighted Actor-Learner Architectures."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import contextlib
23 | import functools
24 | import os
25 | import sys
26 |
27 | import dmlab30
28 | import environments
29 | import numpy as np
30 | import py_process
31 | import sonnet as snt
32 | import tensorflow as tf
33 | import vtrace
34 |
35 | try:
36 | import dynamic_batching
37 | except tf.errors.NotFoundError:
38 | tf.logging.warning('Running without dynamic batching.')
39 |
40 | from six.moves import range
41 |
42 |
43 | nest = tf.contrib.framework.nest
44 |
45 | flags = tf.app.flags
46 | FLAGS = tf.app.flags.FLAGS
47 |
48 | flags.DEFINE_string('logdir', '/tmp/agent', 'TensorFlow log directory.')
49 | flags.DEFINE_enum('mode', 'train', ['train', 'test'], 'Training or test mode.')
50 |
51 | # Flags used for testing.
52 | flags.DEFINE_integer('test_num_episodes', 10, 'Number of episodes per level.')
53 |
54 | # Flags used for distributed training.
55 | flags.DEFINE_integer('task', -1, 'Task id. Use -1 for local training.')
56 | flags.DEFINE_enum('job_name', 'learner', ['learner', 'actor'],
57 | 'Job name. Ignored when task is set to -1.')
58 |
59 | # Training.
60 | flags.DEFINE_integer('total_environment_frames', int(1e9),
61 | 'Total environment frames to train for.')
62 | flags.DEFINE_integer('num_actors', 4, 'Number of actors.')
63 | flags.DEFINE_integer('batch_size', 2, 'Batch size for training.')
64 | flags.DEFINE_integer('unroll_length', 100, 'Unroll length in agent steps.')
65 | flags.DEFINE_integer('num_action_repeats', 4, 'Number of action repeats.')
66 | flags.DEFINE_integer('seed', 1, 'Random seed.')
67 |
68 | # Loss settings.
69 | flags.DEFINE_float('entropy_cost', 0.00025, 'Entropy cost/multiplier.')
70 | flags.DEFINE_float('baseline_cost', .5, 'Baseline cost/multiplier.')
71 | flags.DEFINE_float('discounting', .99, 'Discounting factor.')
72 | flags.DEFINE_enum('reward_clipping', 'abs_one', ['abs_one', 'soft_asymmetric'],
73 | 'Reward clipping.')
74 |
75 | # Environment settings.
76 | flags.DEFINE_string(
77 | 'dataset_path', '',
78 | 'Path to dataset needed for psychlab_*, see '
79 | 'https://github.com/deepmind/lab/tree/master/data/brady_konkle_oliva2008')
80 | flags.DEFINE_string('level_name', 'explore_goal_locations_small',
81 | '''Level name or \'dmlab30\' for the full DmLab-30 suite '''
82 | '''with levels assigned round robin to the actors.''')
83 | flags.DEFINE_integer('width', 96, 'Width of observation.')
84 | flags.DEFINE_integer('height', 72, 'Height of observation.')
85 |
86 | # Optimizer settings.
87 | flags.DEFINE_float('learning_rate', 0.00048, 'Learning rate.')
88 | flags.DEFINE_float('decay', .99, 'RMSProp optimizer decay.')
89 | flags.DEFINE_float('momentum', 0., 'RMSProp momentum.')
90 | flags.DEFINE_float('epsilon', .1, 'RMSProp epsilon.')
91 |
92 |
93 | # Structure to be sent from actors to learner.
94 | ActorOutput = collections.namedtuple(
95 | 'ActorOutput', 'level_name agent_state env_outputs agent_outputs')
96 | AgentOutput = collections.namedtuple('AgentOutput',
97 | 'action policy_logits baseline')
98 |
99 |
100 | def is_single_machine():
101 | return FLAGS.task == -1
102 |
103 |
104 | class Agent(snt.RNNCore):
105 | """Agent with ResNet."""
106 |
107 | def __init__(self, num_actions):
108 | super(Agent, self).__init__(name='agent')
109 |
110 | self._num_actions = num_actions
111 |
112 | with self._enter_variable_scope():
113 | self._core = tf.contrib.rnn.LSTMBlockCell(256)
114 |
115 | def initial_state(self, batch_size):
116 | return self._core.zero_state(batch_size, tf.float32)
117 |
118 | def _instruction(self, instruction):
119 | # Split string.
120 | splitted = tf.string_split(instruction)
121 | dense = tf.sparse_tensor_to_dense(splitted, default_value='')
122 | length = tf.reduce_sum(tf.to_int32(tf.not_equal(dense, '')), axis=1)
123 |
124 | # To int64 hash buckets. Small risk of having collisions. Alternatively, a
125 | # vocabulary can be used.
126 | num_hash_buckets = 1000
127 | buckets = tf.string_to_hash_bucket_fast(dense, num_hash_buckets)
128 |
129 | # Embed the instruction. Embedding size 20 seems to be enough.
130 | embedding_size = 20
131 | embedding = snt.Embed(num_hash_buckets, embedding_size)(buckets)
132 |
133 | # Pad to make sure there is at least one output.
134 | padding = tf.to_int32(tf.equal(tf.shape(embedding)[1], 0))
135 | embedding = tf.pad(embedding, [[0, 0], [0, padding], [0, 0]])
136 |
137 | core = tf.contrib.rnn.LSTMBlockCell(64, name='language_lstm')
138 | output, _ = tf.nn.dynamic_rnn(core, embedding, length, dtype=tf.float32)
139 |
140 |     # Return the last valid output (reverse_sequence moves it to index 0).
141 | return tf.reverse_sequence(output, length, seq_axis=1)[:, 0]
142 |
143 | def _torso(self, input_):
144 | last_action, env_output = input_
145 | reward, _, _, (frame, instruction) = env_output
146 |
147 | # Convert to floats.
148 | frame = tf.to_float(frame)
149 |
150 | frame /= 255
151 | with tf.variable_scope('convnet'):
152 | conv_out = frame
153 | for i, (num_ch, num_blocks) in enumerate([(16, 2), (32, 2), (32, 2)]):
154 | # Downscale.
155 | conv_out = snt.Conv2D(num_ch, 3, stride=1, padding='SAME')(conv_out)
156 | conv_out = tf.nn.pool(
157 | conv_out,
158 | window_shape=[3, 3],
159 | pooling_type='MAX',
160 | padding='SAME',
161 | strides=[2, 2])
162 |
163 | # Residual block(s).
164 | for j in range(num_blocks):
165 | with tf.variable_scope('residual_%d_%d' % (i, j)):
166 | block_input = conv_out
167 | conv_out = tf.nn.relu(conv_out)
168 | conv_out = snt.Conv2D(num_ch, 3, stride=1, padding='SAME')(conv_out)
169 | conv_out = tf.nn.relu(conv_out)
170 | conv_out = snt.Conv2D(num_ch, 3, stride=1, padding='SAME')(conv_out)
171 | conv_out += block_input
172 |
173 | conv_out = tf.nn.relu(conv_out)
174 | conv_out = snt.BatchFlatten()(conv_out)
175 |
176 | conv_out = snt.Linear(256)(conv_out)
177 | conv_out = tf.nn.relu(conv_out)
178 |
179 | instruction_out = self._instruction(instruction)
180 |
181 | # Append clipped last reward and one hot last action.
182 | clipped_reward = tf.expand_dims(tf.clip_by_value(reward, -1, 1), -1)
183 | one_hot_last_action = tf.one_hot(last_action, self._num_actions)
184 | return tf.concat(
185 | [conv_out, clipped_reward, one_hot_last_action, instruction_out],
186 | axis=1)
187 |
188 | def _head(self, core_output):
189 | policy_logits = snt.Linear(self._num_actions, name='policy_logits')(
190 | core_output)
191 | baseline = tf.squeeze(snt.Linear(1, name='baseline')(core_output), axis=-1)
192 |
193 | # Sample an action from the policy.
194 | new_action = tf.multinomial(policy_logits, num_samples=1,
195 | output_dtype=tf.int32)
196 | new_action = tf.squeeze(new_action, 1, name='new_action')
197 |
198 | return AgentOutput(new_action, policy_logits, baseline)
199 |
200 | def _build(self, input_, core_state):
201 | action, env_output = input_
202 | actions, env_outputs = nest.map_structure(lambda t: tf.expand_dims(t, 0),
203 | (action, env_output))
204 | outputs, core_state = self.unroll(actions, env_outputs, core_state)
205 | return nest.map_structure(lambda t: tf.squeeze(t, 0), outputs), core_state
206 |
207 | @snt.reuse_variables
208 | def unroll(self, actions, env_outputs, core_state):
209 | _, _, done, _ = env_outputs
210 |
211 | torso_outputs = snt.BatchApply(self._torso)((actions, env_outputs))
212 |
213 | # Note, in this implementation we can't use CuDNN RNN to speed things up due
214 | # to the state reset. This can be XLA-compiled (LSTMBlockCell needs to be
215 | # changed to implement snt.LSTMCell).
216 | initial_core_state = self._core.zero_state(tf.shape(actions)[1], tf.float32)
217 | core_output_list = []
218 | for input_, d in zip(tf.unstack(torso_outputs), tf.unstack(done)):
219 | # If the episode ended, the core state should be reset before the next.
220 | core_state = nest.map_structure(functools.partial(tf.where, d),
221 | initial_core_state, core_state)
222 | core_output, core_state = self._core(input_, core_state)
223 | core_output_list.append(core_output)
224 |
225 | return snt.BatchApply(self._head)(tf.stack(core_output_list)), core_state
226 |
227 |
228 | def build_actor(agent, env, level_name, action_set):
229 | """Builds the actor loop."""
230 | # Initial values.
231 | initial_env_output, initial_env_state = env.initial()
232 | initial_agent_state = agent.initial_state(1)
233 | initial_action = tf.zeros([1], dtype=tf.int32)
234 | dummy_agent_output, _ = agent(
235 | (initial_action,
236 | nest.map_structure(lambda t: tf.expand_dims(t, 0), initial_env_output)),
237 | initial_agent_state)
238 | initial_agent_output = nest.map_structure(
239 | lambda t: tf.zeros(t.shape, t.dtype), dummy_agent_output)
240 |
241 | # All state that needs to persist across training iterations. This includes
242 | # the last environment output, agent state and last agent output. These
243 | # variables should never go on the parameter servers.
244 | def create_state(t):
245 | # Creates a unique variable scope to ensure the variable name is unique.
246 | with tf.variable_scope(None, default_name='state'):
247 | return tf.get_local_variable(t.op.name, initializer=t, use_resource=True)
248 |
249 | persistent_state = nest.map_structure(
250 | create_state, (initial_env_state, initial_env_output, initial_agent_state,
251 | initial_agent_output))
252 |
253 | def step(input_, unused_i):
254 | """Steps through the agent and the environment."""
255 | env_state, env_output, agent_state, agent_output = input_
256 |
257 | # Run agent.
258 | action = agent_output[0]
259 | batched_env_output = nest.map_structure(lambda t: tf.expand_dims(t, 0),
260 | env_output)
261 | agent_output, agent_state = agent((action, batched_env_output), agent_state)
262 |
263 | # Convert action index to the native action.
264 | action = agent_output[0][0]
265 | raw_action = tf.gather(action_set, action)
266 |
267 | env_output, env_state = env.step(raw_action, env_state)
268 |
269 | return env_state, env_output, agent_state, agent_output
270 |
271 | # Run the unroll. `read_value()` is needed to make sure later usage will
272 | # return the first values and not a new snapshot of the variables.
273 | first_values = nest.map_structure(lambda v: v.read_value(), persistent_state)
274 | _, first_env_output, first_agent_state, first_agent_output = first_values
275 |
276 | # Use scan to apply `step` multiple times, therefore unrolling the agent
277 | # and environment interaction for `FLAGS.unroll_length`. `tf.scan` forwards
278 | # the output of each call of `step` as input of the subsequent call of `step`.
279 | # The unroll sequence is initialized with the agent and environment states
280 | # and outputs as stored at the end of the previous unroll.
281 | # `output` stores lists of all states and outputs stacked along the entire
282 | # unroll. Note that the initial states and outputs (fed through `initializer`)
283 | # are not in `output` and will need to be added manually later.
284 | output = tf.scan(step, tf.range(FLAGS.unroll_length), first_values)
285 | _, env_outputs, _, agent_outputs = output
286 |
287 | # Update persistent state with the last output from the loop.
288 | assign_ops = nest.map_structure(lambda v, t: v.assign(t[-1]),
289 | persistent_state, output)
290 |
291 | # The control dependency ensures that the final agent and environment states
292 | # and outputs are stored in `persistent_state` (to initialize next unroll).
293 | with tf.control_dependencies(nest.flatten(assign_ops)):
294 | # Remove the batch dimension from the agent state/output.
295 | first_agent_state = nest.map_structure(lambda t: t[0], first_agent_state)
296 | first_agent_output = nest.map_structure(lambda t: t[0], first_agent_output)
297 | agent_outputs = nest.map_structure(lambda t: t[:, 0], agent_outputs)
298 |
299 | # Concatenate first output and the unroll along the time dimension.
300 | full_agent_outputs, full_env_outputs = nest.map_structure(
301 | lambda first, rest: tf.concat([[first], rest], 0),
302 | (first_agent_output, first_env_output), (agent_outputs, env_outputs))
303 |
304 | output = ActorOutput(
305 | level_name=level_name, agent_state=first_agent_state,
306 | env_outputs=full_env_outputs, agent_outputs=full_agent_outputs)
307 |
308 | # No backpropagation should be done here.
309 | return nest.map_structure(tf.stop_gradient, output)
310 |
311 |
312 | def compute_baseline_loss(advantages):
313 | # Loss for the baseline, summed over the time dimension.
314 | # Multiply by 0.5 to match the standard update rule:
315 | # d(loss) / d(baseline) = advantage
316 | return .5 * tf.reduce_sum(tf.square(advantages))
317 |
318 |
319 | def compute_entropy_loss(logits):
320 | policy = tf.nn.softmax(logits)
321 | log_policy = tf.nn.log_softmax(logits)
322 | entropy_per_timestep = tf.reduce_sum(-policy * log_policy, axis=-1)
323 | return -tf.reduce_sum(entropy_per_timestep)
324 |
325 |
326 | def compute_policy_gradient_loss(logits, actions, advantages):
327 | cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
328 | labels=actions, logits=logits)
329 | advantages = tf.stop_gradient(advantages)
330 | policy_gradient_loss_per_timestep = cross_entropy * advantages
331 | return tf.reduce_sum(policy_gradient_loss_per_timestep)
332 |
333 |
334 | def build_learner(agent, agent_state, env_outputs, agent_outputs):
335 | """Builds the learner loop.
336 |
337 | Args:
338 | agent: A snt.RNNCore module outputting `AgentOutput` named tuples, with an
339 | `unroll` call for computing the outputs for a whole trajectory.
340 | agent_state: The initial agent state for each sequence in the batch.
341 | env_outputs: A `StepOutput` namedtuple where each field is of shape
342 | [T+1, ...].
343 | agent_outputs: An `AgentOutput` namedtuple where each field is of shape
344 | [T+1, ...].
345 |
346 | Returns:
347 |     A tuple of (done, infos, num_environment_frames) where evaluating
348 |     the num_environment_frames tensor also performs a training update.
349 | """
350 | learner_outputs, _ = agent.unroll(agent_outputs.action, env_outputs,
351 | agent_state)
352 |
353 | # Use last baseline value (from the value function) to bootstrap.
354 | bootstrap_value = learner_outputs.baseline[-1]
355 |
356 | # At this point, the environment outputs at time step `t` are the inputs that
357 | # lead to the learner_outputs at time step `t`. After the following shifting,
358 |   # the actions in agent_outputs and learner_outputs at time step `t` are what
359 |   # lead to the environment outputs at time step `t`.
360 | agent_outputs = nest.map_structure(lambda t: t[1:], agent_outputs)
361 | rewards, infos, done, _ = nest.map_structure(
362 | lambda t: t[1:], env_outputs)
363 | learner_outputs = nest.map_structure(lambda t: t[:-1], learner_outputs)
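  # For example, with T+1 = 3: learner_outputs[:-1] keeps indices (0, 1) while
  # env_outputs[1:] keeps indices (1, 2), so the action computed from
  # observation i is paired with the reward/observation it produced at i+1.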
364 |
365 | if FLAGS.reward_clipping == 'abs_one':
366 | clipped_rewards = tf.clip_by_value(rewards, -1, 1)
367 | elif FLAGS.reward_clipping == 'soft_asymmetric':
368 | squeezed = tf.tanh(rewards / 5.0)
369 | # Negative rewards are given less weight than positive rewards.
370 | clipped_rewards = tf.where(rewards < 0, .3 * squeezed, squeezed) * 5.
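    # Equivalent to 5 * tanh(r / 5) for r >= 0 and 1.5 * tanh(r / 5) for r < 0.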
371 |
372 | discounts = tf.to_float(~done) * FLAGS.discounting
373 |
374 | # Compute V-trace returns and weights.
375 | # Note, this is put on the CPU because it's faster than on GPU. It can be
376 | # improved further with XLA-compilation or with a custom TensorFlow operation.
377 | with tf.device('/cpu'):
378 | vtrace_returns = vtrace.from_logits(
379 | behaviour_policy_logits=agent_outputs.policy_logits,
380 | target_policy_logits=learner_outputs.policy_logits,
381 | actions=agent_outputs.action,
382 | discounts=discounts,
383 | rewards=clipped_rewards,
384 | values=learner_outputs.baseline,
385 | bootstrap_value=bootstrap_value)
386 |
387 | # Compute loss as a weighted sum of the baseline loss, the policy gradient
388 | # loss and an entropy regularization term.
389 | total_loss = compute_policy_gradient_loss(
390 | learner_outputs.policy_logits, agent_outputs.action,
391 | vtrace_returns.pg_advantages)
392 | total_loss += FLAGS.baseline_cost * compute_baseline_loss(
393 | vtrace_returns.vs - learner_outputs.baseline)
394 | total_loss += FLAGS.entropy_cost * compute_entropy_loss(
395 | learner_outputs.policy_logits)
396 |
397 | # Optimization
398 | num_env_frames = tf.train.get_global_step()
399 | learning_rate = tf.train.polynomial_decay(FLAGS.learning_rate, num_env_frames,
400 | FLAGS.total_environment_frames, 0)
401 | optimizer = tf.train.RMSPropOptimizer(learning_rate, FLAGS.decay,
402 | FLAGS.momentum, FLAGS.epsilon)
403 | train_op = optimizer.minimize(total_loss)
404 |
405 | # Merge updating the network and environment frames into a single tensor.
406 | with tf.control_dependencies([train_op]):
407 | num_env_frames_and_train = num_env_frames.assign_add(
408 | FLAGS.batch_size * FLAGS.unroll_length * FLAGS.num_action_repeats)
409 |
410 | # Adding a few summaries.
411 | tf.summary.scalar('learning_rate', learning_rate)
412 | tf.summary.scalar('total_loss', total_loss)
413 | tf.summary.histogram('action', agent_outputs.action)
414 |
415 | return done, infos, num_env_frames_and_train
416 |
417 |
418 | def create_environment(level_name, seed, is_test=False):
419 | """Creates an environment wrapped in a `FlowEnvironment`."""
420 | if level_name in dmlab30.ALL_LEVELS:
421 | level_name = 'contributed/dmlab30/' + level_name
422 |
423 |   # Note, you may want to use a level cache to speed up compilation of
424 | # environment maps. See the documentation for the Python interface of DeepMind
425 | # Lab.
426 | config = {
427 | 'width': FLAGS.width,
428 | 'height': FLAGS.height,
429 | 'datasetPath': FLAGS.dataset_path,
430 | 'logLevel': 'WARN',
431 | }
432 | if is_test:
433 | config['allowHoldOutLevels'] = 'true'
434 |     # Mixer seed for evaluation, see
435 | # https://github.com/deepmind/lab/blob/master/docs/users/python_api.md
436 | config['mixerSeed'] = 0x600D5EED
437 | p = py_process.PyProcess(environments.PyProcessDmLab, level_name, config,
438 | FLAGS.num_action_repeats, seed)
439 | return environments.FlowEnvironment(p.proxy)
440 |
441 |
442 | @contextlib.contextmanager
443 | def pin_global_variables(device):
444 | """Pins global variables to the specified device."""
445 | def getter(getter, *args, **kwargs):
446 | var_collections = kwargs.get('collections', None)
447 | if var_collections is None:
448 | var_collections = [tf.GraphKeys.GLOBAL_VARIABLES]
449 | if tf.GraphKeys.GLOBAL_VARIABLES in var_collections:
450 | with tf.device(device):
451 | return getter(*args, **kwargs)
452 | else:
453 | return getter(*args, **kwargs)
454 |
455 | with tf.variable_scope('', custom_getter=getter) as vs:
456 | yield vs
457 |
458 |
459 | def train(action_set, level_names):
460 | """Train."""
461 |
462 | if is_single_machine():
463 | local_job_device = ''
464 | shared_job_device = ''
465 | is_actor_fn = lambda i: True
466 | is_learner = True
467 | global_variable_device = '/gpu'
468 | server = tf.train.Server.create_local_server()
469 | filters = []
470 | else:
471 | local_job_device = '/job:%s/task:%d' % (FLAGS.job_name, FLAGS.task)
472 | shared_job_device = '/job:learner/task:0'
473 | is_actor_fn = lambda i: FLAGS.job_name == 'actor' and i == FLAGS.task
474 | is_learner = FLAGS.job_name == 'learner'
475 |
476 |     # Placing the variables on CPU makes it cheaper to send them to all the
477 |     # actors. Continually copying the variables from the GPU is slow.
478 | global_variable_device = shared_job_device + '/cpu'
479 | cluster = tf.train.ClusterSpec({
480 | 'actor': ['localhost:%d' % (8001 + i) for i in range(FLAGS.num_actors)],
481 | 'learner': ['localhost:8000']
482 | })
483 | server = tf.train.Server(cluster, job_name=FLAGS.job_name,
484 | task_index=FLAGS.task)
485 | filters = [shared_job_device, local_job_device]
486 |
487 | # Only used to find the actor output structure.
488 | with tf.Graph().as_default():
489 | agent = Agent(len(action_set))
490 | env = create_environment(level_names[0], seed=1)
491 | structure = build_actor(agent, env, level_names[0], action_set)
492 | flattened_structure = nest.flatten(structure)
493 | dtypes = [t.dtype for t in flattened_structure]
494 | shapes = [t.shape.as_list() for t in flattened_structure]
495 |
496 | with tf.Graph().as_default(), \
497 | tf.device(local_job_device + '/cpu'), \
498 | pin_global_variables(global_variable_device):
499 | tf.set_random_seed(FLAGS.seed) # Makes initialization deterministic.
500 |
501 | # Create Queue and Agent on the learner.
502 | with tf.device(shared_job_device):
503 | queue = tf.FIFOQueue(1, dtypes, shapes, shared_name='buffer')
504 | agent = Agent(len(action_set))
505 |
506 | if is_single_machine() and 'dynamic_batching' in sys.modules:
507 | # For single machine training, we use dynamic batching for improved GPU
508 | # utilization. The semantics of single machine training are slightly
509 | # different from the distributed setting because within a single unroll
510 | # of an environment, the actions may be computed using different weights
511 | # if an update happens within the unroll.
512 | old_build = agent._build
513 | @dynamic_batching.batch_fn
514 | def build(*args):
515 | with tf.device('/gpu'):
516 | return old_build(*args)
517 | tf.logging.info('Using dynamic batching.')
518 | agent._build = build
519 |
520 | # Build actors and ops to enqueue their output.
521 | enqueue_ops = []
522 | for i in range(FLAGS.num_actors):
523 | if is_actor_fn(i):
524 | level_name = level_names[i % len(level_names)]
525 | tf.logging.info('Creating actor %d with level %s', i, level_name)
526 | env = create_environment(level_name, seed=i + 1)
527 | actor_output = build_actor(agent, env, level_name, action_set)
528 | with tf.device(shared_job_device):
529 | enqueue_ops.append(queue.enqueue(nest.flatten(actor_output)))
530 |
531 | # If running in a single machine setup, run actors with QueueRunners
532 | # (separate threads).
533 | if is_learner and enqueue_ops:
534 | tf.train.add_queue_runner(tf.train.QueueRunner(queue, enqueue_ops))
535 |
536 | # Build learner.
537 | if is_learner:
538 | # Create global step, which is the number of environment frames processed.
539 | tf.get_variable(
540 | 'num_environment_frames',
541 | initializer=tf.zeros_initializer(),
542 | shape=[],
543 | dtype=tf.int64,
544 | trainable=False,
545 | collections=[tf.GraphKeys.GLOBAL_STEP, tf.GraphKeys.GLOBAL_VARIABLES])
546 |
547 | # Create batch (time major) and recreate structure.
548 | dequeued = queue.dequeue_many(FLAGS.batch_size)
549 | dequeued = nest.pack_sequence_as(structure, dequeued)
550 |
551 | def make_time_major(s):
552 | return nest.map_structure(
553 | lambda t: tf.transpose(t, [1, 0] + list(range(t.shape.ndims))[2:]), s)
554 |
555 | dequeued = dequeued._replace(
556 | env_outputs=make_time_major(dequeued.env_outputs),
557 | agent_outputs=make_time_major(dequeued.agent_outputs))
558 |
559 | with tf.device('/gpu'):
560 | # Using StagingArea allows us to prepare the next batch and send it to
561 |       # the GPU while we're performing a training step. This adds up to one
562 |       # step of policy lag.
563 | flattened_output = nest.flatten(dequeued)
564 | area = tf.contrib.staging.StagingArea(
565 | [t.dtype for t in flattened_output],
566 | [t.shape for t in flattened_output])
567 | stage_op = area.put(flattened_output)
568 |
569 | data_from_actors = nest.pack_sequence_as(structure, area.get())
570 |
571 | # Unroll agent on sequence, create losses and update ops.
572 | output = build_learner(agent, data_from_actors.agent_state,
573 | data_from_actors.env_outputs,
574 | data_from_actors.agent_outputs)
575 |
576 | # Create MonitoredSession (to run the graph, checkpoint and log).
577 | tf.logging.info('Creating MonitoredSession, is_chief %s', is_learner)
578 | config = tf.ConfigProto(allow_soft_placement=True, device_filters=filters)
579 | with tf.train.MonitoredTrainingSession(
580 | server.target,
581 | is_chief=is_learner,
582 | checkpoint_dir=FLAGS.logdir,
583 | save_checkpoint_secs=600,
584 | save_summaries_secs=30,
585 | log_step_count_steps=50000,
586 | config=config,
587 | hooks=[py_process.PyProcessHook()]) as session:
588 |
589 | if is_learner:
590 | # Logging.
591 | level_returns = {level_name: [] for level_name in level_names}
592 | summary_writer = tf.summary.FileWriterCache.get(FLAGS.logdir)
593 |
594 | # Prepare data for first run.
595 | session.run_step_fn(
596 | lambda step_context: step_context.session.run(stage_op))
597 |
598 | # Execute learning and track performance.
599 | num_env_frames_v = 0
600 | while num_env_frames_v < FLAGS.total_environment_frames:
601 | level_names_v, done_v, infos_v, num_env_frames_v, _ = session.run(
602 | (data_from_actors.level_name,) + output + (stage_op,))
603 | level_names_v = np.repeat([level_names_v], done_v.shape[0], 0)
604 |
605 | for level_name, episode_return, episode_step in zip(
606 | level_names_v[done_v],
607 | infos_v.episode_return[done_v],
608 | infos_v.episode_step[done_v]):
609 | episode_frames = episode_step * FLAGS.num_action_repeats
610 |
611 | tf.logging.info('Level: %s Episode return: %f',
612 | level_name, episode_return)
613 |
614 | summary = tf.summary.Summary()
615 | summary.value.add(tag=level_name + '/episode_return',
616 | simple_value=episode_return)
617 | summary.value.add(tag=level_name + '/episode_frames',
618 | simple_value=episode_frames)
619 | summary_writer.add_summary(summary, num_env_frames_v)
620 |
621 | if FLAGS.level_name == 'dmlab30':
622 | level_returns[level_name].append(episode_return)
623 |
624 | if (FLAGS.level_name == 'dmlab30' and
625 | min(map(len, level_returns.values())) >= 1):
626 | no_cap = dmlab30.compute_human_normalized_score(level_returns,
627 | per_level_cap=None)
628 | cap_100 = dmlab30.compute_human_normalized_score(level_returns,
629 | per_level_cap=100)
630 | summary = tf.summary.Summary()
631 | summary.value.add(
632 | tag='dmlab30/training_no_cap', simple_value=no_cap)
633 | summary.value.add(
634 | tag='dmlab30/training_cap_100', simple_value=cap_100)
635 | summary_writer.add_summary(summary, num_env_frames_v)
636 |
637 | # Clear level scores.
638 | level_returns = {level_name: [] for level_name in level_names}
639 |
640 | else:
641 | # Execute actors (they just need to enqueue their output).
642 | while True:
643 | session.run(enqueue_ops)
644 |
645 |
646 | def test(action_set, level_names):
647 | """Test."""
648 |
649 | level_returns = {level_name: [] for level_name in level_names}
650 | with tf.Graph().as_default():
651 | agent = Agent(len(action_set))
652 | outputs = {}
653 | for level_name in level_names:
654 | env = create_environment(level_name, seed=1, is_test=True)
655 | outputs[level_name] = build_actor(agent, env, level_name, action_set)
656 |
657 | with tf.train.SingularMonitoredSession(
658 | checkpoint_dir=FLAGS.logdir,
659 | hooks=[py_process.PyProcessHook()]) as session:
660 | for level_name in level_names:
661 | tf.logging.info('Testing level: %s', level_name)
662 | while True:
663 | done_v, infos_v = session.run((
664 | outputs[level_name].env_outputs.done,
665 | outputs[level_name].env_outputs.info
666 | ))
667 | returns = level_returns[level_name]
668 | returns.extend(infos_v.episode_return[1:][done_v[1:]])
669 |
670 | if len(returns) >= FLAGS.test_num_episodes:
671 | tf.logging.info('Mean episode return: %f', np.mean(returns))
672 | break
673 |
674 | if FLAGS.level_name == 'dmlab30':
675 | no_cap = dmlab30.compute_human_normalized_score(level_returns,
676 | per_level_cap=None)
677 | cap_100 = dmlab30.compute_human_normalized_score(level_returns,
678 | per_level_cap=100)
679 | tf.logging.info('No cap.: %f Cap 100: %f', no_cap, cap_100)
680 |
681 |
682 | def main(_):
683 | tf.logging.set_verbosity(tf.logging.INFO)
684 |
685 | action_set = environments.DEFAULT_ACTION_SET
686 | if FLAGS.level_name == 'dmlab30' and FLAGS.mode == 'train':
687 | level_names = dmlab30.LEVEL_MAPPING.keys()
688 | elif FLAGS.level_name == 'dmlab30' and FLAGS.mode == 'test':
689 | level_names = dmlab30.LEVEL_MAPPING.values()
690 | else:
691 | level_names = [FLAGS.level_name]
692 |
693 | if FLAGS.mode == 'train':
694 | train(action_set, level_names)
695 | else:
696 | test(action_set, level_names)
697 |
698 |
699 | if __name__ == '__main__':
700 | tf.app.run()
701 |
--------------------------------------------------------------------------------
/py_process.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """PyProcess.
16 |
17 | This file includes utilities for running code in separate Python processes as
18 | part of a TensorFlow graph. It is similar to tf.py_func, but the code is run in
19 | separate processes to avoid the GIL.
20 |
21 | Example:
22 |
23 | class Zeros(object):
24 |
25 | def __init__(self, dim0):
26 | self._dim0 = dim0
27 |
28 | def compute(self, dim1):
29 | return np.zeros([self._dim0, dim1], dtype=np.int32)
30 |
31 | @staticmethod
32 | def _tensor_specs(method_name, kwargs, constructor_kwargs):
33 | dim0 = constructor_kwargs['dim0']
34 | dim1 = kwargs['dim1']
35 | if method_name == 'compute':
36 | return tf.contrib.framework.TensorSpec([dim0, dim1], tf.int32)
37 |
38 | with tf.Graph().as_default():
39 | p = py_process.PyProcess(Zeros, 1)
40 | result = p.proxy.compute(2)
41 |
42 | with tf.train.SingularMonitoredSession(
43 | hooks=[py_process.PyProcessHook()]) as session:
44 | print(session.run(result)) # Prints [[0, 0]].
45 | """
46 |
47 | from __future__ import absolute_import
48 | from __future__ import division
49 | from __future__ import print_function
50 |
51 | import multiprocessing
52 |
53 | import tensorflow as tf
54 |
55 | from tensorflow.python.util import function_utils
56 |
57 |
58 | nest = tf.contrib.framework.nest
59 |
60 |
61 | class _TFProxy(object):
62 | """A proxy that creates TensorFlow operations for each method call to a
63 | separate process."""
64 |
65 | def __init__(self, type_, constructor_kwargs):
66 | self._type = type_
67 | self._constructor_kwargs = constructor_kwargs
68 |
69 | def __getattr__(self, name):
70 | def call(*args):
71 | kwargs = dict(
72 | zip(function_utils.fn_args(getattr(self._type, name))[1:], args))
73 | specs = self._type._tensor_specs(name, kwargs, self._constructor_kwargs)
74 |
75 | if specs is None:
76 | raise ValueError(
77 | 'No tensor specifications were provided for: %s' % name)
78 |
79 | flat_dtypes = nest.flatten(nest.map_structure(lambda s: s.dtype, specs))
80 | flat_shapes = nest.flatten(nest.map_structure(lambda s: s.shape, specs))
81 |
82 | def py_call(*args):
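        # `args` arrives as (method_name, *tensor_values) via the tf.py_func
        # call below; the worker process dispatches on the method name in
        # _worker_fn.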
83 | try:
84 | self._out.send(args)
85 | result = self._out.recv()
86 | if isinstance(result, Exception):
87 | raise result
88 | if result is not None:
89 | return result
90 | except Exception as e:
91 | if isinstance(e, IOError):
92 | raise StopIteration() # Clean exit.
93 | else:
94 | raise
95 |
96 | result = tf.py_func(py_call, (name,) + tuple(args), flat_dtypes,
97 | name=name)
98 |
99 | if isinstance(result, tf.Operation):
100 | return result
101 |
102 | for t, shape in zip(result, flat_shapes):
103 | t.set_shape(shape)
104 | return nest.pack_sequence_as(specs, result)
105 | return call
106 |
107 | def _start(self):
108 | self._out, in_ = multiprocessing.Pipe()
109 | self._process = multiprocessing.Process(
110 | target=self._worker_fn,
111 | args=(self._type, self._constructor_kwargs, in_))
112 | self._process.start()
113 | result = self._out.recv()
114 |
115 | if isinstance(result, Exception):
116 | raise result
117 |
118 | def _close(self, session):
119 | try:
120 | self._out.send(None)
121 | self._out.close()
122 | except IOError:
123 | pass
124 | self._process.join()
125 |
126 | def _worker_fn(self, type_, constructor_kwargs, in_):
127 | try:
128 | o = type_(**constructor_kwargs)
129 |
130 | in_.send(None) # Ready.
131 |
132 | while True:
133 | # Receive request.
134 | serialized = in_.recv()
135 |
136 | if serialized is None:
137 | if hasattr(o, 'close'):
138 | o.close()
139 | in_.close()
140 | return
141 |
142 | method_name = str(serialized[0])
143 | inputs = serialized[1:]
144 |
145 | # Compute result.
146 | results = getattr(o, method_name)(*inputs)
147 | if results is not None:
148 | results = nest.flatten(results)
149 |
150 | # Respond.
151 | in_.send(results)
152 | except Exception as e:
153 | if 'o' in locals() and hasattr(o, 'close'):
154 | try:
155 | o.close()
156 | except:
157 | pass
158 | in_.send(e)
159 |
160 |
161 | class PyProcess(object):
162 | COLLECTION = 'py_process_processes'
163 |
164 | def __init__(self, type_, *constructor_args, **constructor_kwargs):
165 | self._type = type_
166 | self._constructor_kwargs = dict(
167 | zip(function_utils.fn_args(type_.__init__)[1:], constructor_args))
168 | self._constructor_kwargs.update(constructor_kwargs)
169 |
170 | tf.add_to_collection(PyProcess.COLLECTION, self)
171 |
172 | self._proxy = _TFProxy(type_, self._constructor_kwargs)
173 |
174 | @property
175 | def proxy(self):
176 | """A proxy that creates TensorFlow operations for each method call."""
177 | return self._proxy
178 |
179 | def close(self, session):
180 | self._proxy._close(session)
181 |
182 | def start(self):
183 | self._proxy._start()
184 |
185 |
186 | class PyProcessHook(tf.train.SessionRunHook):
187 | """A MonitoredSession hook that starts and stops PyProcess instances."""
188 |
189 | def begin(self):
190 | tf.logging.info('Starting all processes.')
191 | tp = multiprocessing.pool.ThreadPool()
192 | tp.map(lambda p: p.start(), tf.get_collection(PyProcess.COLLECTION))
193 | tp.close()
194 | tp.join()
195 | tf.logging.info('All processes started.')
196 |
197 | def end(self, session):
198 | tf.logging.info('Closing all processes.')
199 | tp = multiprocessing.pool.ThreadPool()
200 | tp.map(lambda p: p.close(session), tf.get_collection(PyProcess.COLLECTION))
201 | tp.close()
202 | tp.join()
203 | tf.logging.info('All processes closed.')
204 |
--------------------------------------------------------------------------------
/py_process_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests py_process.py."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tempfile
22 | import time
23 |
24 | import numpy as np
25 | import py_process
26 | import tensorflow as tf
27 |
28 | from six.moves import range
29 |
30 |
31 | class PyProcessTest(tf.test.TestCase):
32 |
33 | def test_small(self):
34 |
35 | class Example(object):
36 |
37 | def __init__(self, a):
38 | self._a = a
39 |
40 | def inc(self):
41 | self._a += 1
42 |
43 | def compute(self, b):
44 | return np.array(self._a + b, dtype=np.int32)
45 |
46 | @staticmethod
47 | def _tensor_specs(method_name, unused_args, unused_constructor_kwargs):
48 | if method_name == 'compute':
49 | return tf.contrib.framework.TensorSpec([], tf.int32)
50 | elif method_name == 'inc':
51 | return ()
52 |
53 | with tf.Graph().as_default():
54 | p = py_process.PyProcess(Example, 1)
55 | inc = p.proxy.inc()
56 | compute = p.proxy.compute(2)
57 |
58 | with tf.train.SingularMonitoredSession(
59 | hooks=[py_process.PyProcessHook()]) as session:
60 | self.assertTrue(isinstance(inc, tf.Operation))
61 | session.run(inc)
62 |
63 | self.assertEqual([], compute.shape)
64 | self.assertEqual(4, session.run(compute))
65 |
66 | def test_threading(self):
67 |
68 | class Example(object):
69 |
70 | def __init__(self):
71 | pass
72 |
73 | def wait(self):
74 | time.sleep(.2)
75 | return None
76 |
77 | @staticmethod
78 | def _tensor_specs(method_name, unused_args, unused_constructor_kwargs):
79 | if method_name == 'wait':
80 |           return ()
81 |
82 | with tf.Graph().as_default():
83 | p = py_process.PyProcess(Example)
84 | wait = p.proxy.wait()
85 |
86 | hook = py_process.PyProcessHook()
87 | with tf.train.SingularMonitoredSession(hooks=[hook]) as session:
88 |
89 | def run():
90 | with self.assertRaises(tf.errors.OutOfRangeError):
91 | session.run(wait)
92 |
93 | t = self.checkedThread(target=run)
94 | t.start()
95 | time.sleep(.1)
96 |       t.join()
97 |
98 | def test_args(self):
99 |
100 | class Example(object):
101 |
102 | def __init__(self, dim0):
103 | self._dim0 = dim0
104 |
105 | def compute(self, dim1):
106 | return np.zeros([self._dim0, dim1], dtype=np.int32)
107 |
108 | @staticmethod
109 | def _tensor_specs(method_name, kwargs, constructor_kwargs):
110 | dim0 = constructor_kwargs['dim0']
111 | dim1 = kwargs['dim1']
112 | if method_name == 'compute':
113 | return tf.contrib.framework.TensorSpec([dim0, dim1], tf.int32)
114 |
115 | with tf.Graph().as_default():
116 | p = py_process.PyProcess(Example, 1)
117 | result = p.proxy.compute(2)
118 |
119 | with tf.train.SingularMonitoredSession(
120 | hooks=[py_process.PyProcessHook()]) as session:
121 | self.assertEqual([1, 2], result.shape)
122 | self.assertAllEqual([[0, 0]], session.run(result))
123 |
124 | def test_error_handling_constructor(self):
125 |
126 | class Example(object):
127 |
128 | def __init__(self):
129 | raise ValueError('foo')
130 |
131 | def something(self):
132 | pass
133 |
134 | @staticmethod
135 | def _tensor_specs(method_name, unused_kwargs, unused_constructor_kwargs):
136 | if method_name == 'something':
137 | return ()
138 |
139 | with tf.Graph().as_default():
140 | py_process.PyProcess(Example, 1)
141 |
142 | with self.assertRaisesRegexp(Exception, 'foo'):
143 | with tf.train.SingularMonitoredSession(
144 | hooks=[py_process.PyProcessHook()]):
145 | pass
146 |
147 | def test_error_handling_method(self):
148 |
149 | class Example(object):
150 |
151 | def __init__(self):
152 | pass
153 |
154 | def something(self):
155 | raise ValueError('foo')
156 |
157 | @staticmethod
158 | def _tensor_specs(method_name, unused_kwargs, unused_constructor_kwargs):
159 | if method_name == 'something':
160 | return ()
161 |
162 | with tf.Graph().as_default():
163 | p = py_process.PyProcess(Example, 1)
164 | result = p.proxy.something()
165 |
166 | with tf.train.SingularMonitoredSession(
167 | hooks=[py_process.PyProcessHook()]) as session:
168 | with self.assertRaisesRegexp(Exception, 'foo'):
169 | session.run(result)
170 |
171 | def test_close(self):
172 | with tempfile.NamedTemporaryFile() as tmp:
173 | class Example(object):
174 |
175 | def __init__(self, filename):
176 | self._filename = filename
177 |
178 | def close(self):
179 | with tf.gfile.Open(self._filename, 'w') as f:
180 | f.write('was_closed')
181 |
182 | with tf.Graph().as_default():
183 | py_process.PyProcess(Example, tmp.name)
184 |
185 | with tf.train.SingularMonitoredSession(
186 | hooks=[py_process.PyProcessHook()]):
187 | pass
188 |
189 | self.assertEqual('was_closed', tmp.read())
190 |
191 | def test_close_on_error(self):
192 | with tempfile.NamedTemporaryFile() as tmp:
193 |
194 | class Example(object):
195 |
196 | def __init__(self, filename):
197 | self._filename = filename
198 |
199 | def something(self):
200 | raise ValueError('foo')
201 |
202 | def close(self):
203 | with tf.gfile.Open(self._filename, 'w') as f:
204 | f.write('was_closed')
205 |
206 | @staticmethod
207 | def _tensor_specs(method_name, unused_kwargs,
208 | unused_constructor_kwargs):
209 | if method_name == 'something':
210 | return ()
211 |
212 | with tf.Graph().as_default():
213 | p = py_process.PyProcess(Example, tmp.name)
214 | result = p.proxy.something()
215 |
216 | with tf.train.SingularMonitoredSession(
217 | hooks=[py_process.PyProcessHook()]) as session:
218 | with self.assertRaisesRegexp(Exception, 'foo'):
219 | session.run(result)
220 |
221 | self.assertEqual('was_closed', tmp.read())
222 |
223 |
224 | class PyProcessBenchmarks(tf.test.Benchmark):
225 |
226 | class Example(object):
227 |
228 | def __init__(self):
229 | self._result = np.random.randint(0, 256, (72, 96, 3), np.uint8)
230 |
231 | def compute(self, unused_a):
232 | return self._result
233 |
234 | @staticmethod
235 | def _tensor_specs(method_name, unused_args, unused_constructor_kwargs):
236 | if method_name == 'compute':
237 | return tf.contrib.framework.TensorSpec([72, 96, 3], tf.uint8)
238 |
239 | def benchmark_one(self):
240 | with tf.Graph().as_default():
241 | p = py_process.PyProcess(PyProcessBenchmarks.Example)
242 | compute = p.proxy.compute(2)
243 |
244 | with tf.train.SingularMonitoredSession(
245 | hooks=[py_process.PyProcessHook()]) as session:
246 |
247 | self.run_op_benchmark(
248 | name='process_one',
249 | sess=session,
250 | op_or_tensor=compute,
251 | burn_iters=10,
252 | min_iters=5000)
253 |
254 | def benchmark_many(self):
255 | with tf.Graph().as_default():
256 | ps = [
257 | py_process.PyProcess(PyProcessBenchmarks.Example) for _ in range(200)
258 | ]
259 | compute_ops = [p.proxy.compute(2) for p in ps]
260 | compute = tf.group(*compute_ops)
261 |
262 | with tf.train.SingularMonitoredSession(
263 | hooks=[py_process.PyProcessHook()]) as session:
264 |
265 | self.run_op_benchmark(
266 | name='process_many',
267 | sess=session,
268 | op_or_tensor=compute,
269 | burn_iters=10,
270 | min_iters=500)
271 |
272 |
273 | if __name__ == '__main__':
274 | tf.test.main()
275 |
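276 | # The benchmarks above are picked up by TensorFlow's benchmark runner rather
277 | # than by the regular unit-test runner. Assuming the standard `--benchmarks`
278 | # flag that `tf.test.main()` honours, one way to run them is:
279 | #
280 | #   python py_process_test.py --benchmarks=PyProcessBenchmarks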
--------------------------------------------------------------------------------
/vtrace.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Functions to compute V-trace off-policy actor critic targets.
16 |
17 | For details and theory see:
18 |
19 | "IMPALA: Scalable Distributed Deep-RL with
20 | Importance Weighted Actor-Learner Architectures"
21 | by Espeholt, Soyer, Munos et al.
22 |
23 | See https://arxiv.org/abs/1802.01561 for the full paper.
24 | """
25 |
26 | from __future__ import absolute_import
27 | from __future__ import division
28 | from __future__ import print_function
29 |
30 | import collections
31 |
32 | import tensorflow as tf
33 |
34 | nest = tf.contrib.framework.nest
35 |
36 |
37 | VTraceFromLogitsReturns = collections.namedtuple(
38 | 'VTraceFromLogitsReturns',
39 | ['vs', 'pg_advantages', 'log_rhos',
40 | 'behaviour_action_log_probs', 'target_action_log_probs'])
41 |
42 | VTraceReturns = collections.namedtuple('VTraceReturns', 'vs pg_advantages')
43 |
44 |
45 | def log_probs_from_logits_and_actions(policy_logits, actions):
46 | """Computes action log-probs from policy logits and actions.
47 |
48 | In the notation used throughout documentation and comments, T refers to the
49 | time dimension ranging from 0 to T-1. B refers to the batch size and
50 | NUM_ACTIONS refers to the number of actions.
51 |
52 | Args:
53 | policy_logits: A float32 tensor of shape [T, B, NUM_ACTIONS] with
54 | un-normalized log-probabilities parameterizing a softmax policy.
55 | actions: An int32 tensor of shape [T, B] with actions.
56 |
57 | Returns:
58 | A float32 tensor of shape [T, B] corresponding to the sampling log
59 | probability of the chosen action w.r.t. the policy.
60 | """
61 | policy_logits = tf.convert_to_tensor(policy_logits, dtype=tf.float32)
62 | actions = tf.convert_to_tensor(actions, dtype=tf.int32)
63 |
64 | policy_logits.shape.assert_has_rank(3)
65 | actions.shape.assert_has_rank(2)
66 |
67 | return -tf.nn.sparse_softmax_cross_entropy_with_logits(
68 | logits=policy_logits, labels=actions)
69 |
70 |
71 | def from_logits(
72 | behaviour_policy_logits, target_policy_logits, actions,
73 | discounts, rewards, values, bootstrap_value,
74 | clip_rho_threshold=1.0, clip_pg_rho_threshold=1.0,
75 | name='vtrace_from_logits'):
76 | r"""V-trace for softmax policies.
77 |
78 | Calculates V-trace actor critic targets for softmax polices as described in
79 |
80 | "IMPALA: Scalable Distributed Deep-RL with
81 | Importance Weighted Actor-Learner Architectures"
82 | by Espeholt, Soyer, Munos et al.
83 |
84 | Target policy refers to the policy we are interested in improving and
85 | behaviour policy refers to the policy that generated the given
86 | rewards and actions.
87 |
88 | In the notation used throughout documentation and comments, T refers to the
89 | time dimension ranging from 0 to T-1. B refers to the batch size and
90 | NUM_ACTIONS refers to the number of actions.
91 |
92 | Args:
93 | behaviour_policy_logits: A float32 tensor of shape [T, B, NUM_ACTIONS] with
94 | un-normalized log-probabilities parametrizing the softmax behaviour
95 | policy.
96 | target_policy_logits: A float32 tensor of shape [T, B, NUM_ACTIONS] with
97 | un-normalized log-probabilities parametrizing the softmax target policy.
98 | actions: An int32 tensor of shape [T, B] of actions sampled from the
99 | behaviour policy.
100 | discounts: A float32 tensor of shape [T, B] with the discount encountered
101 | when following the behaviour policy.
102 | rewards: A float32 tensor of shape [T, B] with the rewards generated by
103 | following the behaviour policy.
104 | values: A float32 tensor of shape [T, B] with the value function estimates
105 | wrt. the target policy.
106 |     bootstrap_value: A float32 tensor of shape [B] with the value function
107 |       estimate at time T.
108 | clip_rho_threshold: A scalar float32 tensor with the clipping threshold for
109 | importance weights (rho) when calculating the baseline targets (vs).
110 | rho^bar in the paper.
111 | clip_pg_rho_threshold: A scalar float32 tensor with the clipping threshold
112 | on rho_s in \rho_s \delta log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)).
113 | name: The name scope that all V-trace operations will be created in.
114 |
115 | Returns:
116 | A `VTraceFromLogitsReturns` namedtuple with the following fields:
117 | vs: A float32 tensor of shape [T, B]. Can be used as target to train a
118 | baseline (V(x_t) - vs_t)^2.
119 |     pg_advantages: A float32 tensor of shape [T, B]. Can be used as an
120 | estimate of the advantage in the calculation of policy gradients.
121 | log_rhos: A float32 tensor of shape [T, B] containing the log importance
122 | sampling weights (log rhos).
123 | behaviour_action_log_probs: A float32 tensor of shape [T, B] containing
124 | behaviour policy action log probabilities (log \mu(a_t)).
125 | target_action_log_probs: A float32 tensor of shape [T, B] containing
126 |       target policy action log probabilities (log \pi(a_t)).
127 | """
128 | behaviour_policy_logits = tf.convert_to_tensor(
129 | behaviour_policy_logits, dtype=tf.float32)
130 | target_policy_logits = tf.convert_to_tensor(
131 | target_policy_logits, dtype=tf.float32)
132 | actions = tf.convert_to_tensor(actions, dtype=tf.int32)
133 |
134 | # Make sure tensor ranks are as expected.
135 |   # The rest will be checked by from_importance_weights.
136 | behaviour_policy_logits.shape.assert_has_rank(3)
137 | target_policy_logits.shape.assert_has_rank(3)
138 | actions.shape.assert_has_rank(2)
139 |
140 | with tf.name_scope(name, values=[
141 | behaviour_policy_logits, target_policy_logits, actions,
142 | discounts, rewards, values, bootstrap_value]):
143 | target_action_log_probs = log_probs_from_logits_and_actions(
144 | target_policy_logits, actions)
145 | behaviour_action_log_probs = log_probs_from_logits_and_actions(
146 | behaviour_policy_logits, actions)
147 | log_rhos = target_action_log_probs - behaviour_action_log_probs
148 | vtrace_returns = from_importance_weights(
149 | log_rhos=log_rhos,
150 | discounts=discounts,
151 | rewards=rewards,
152 | values=values,
153 | bootstrap_value=bootstrap_value,
154 | clip_rho_threshold=clip_rho_threshold,
155 | clip_pg_rho_threshold=clip_pg_rho_threshold)
156 | return VTraceFromLogitsReturns(
157 | log_rhos=log_rhos,
158 | behaviour_action_log_probs=behaviour_action_log_probs,
159 | target_action_log_probs=target_action_log_probs,
160 | **vtrace_returns._asdict()
161 | )
162 |
163 |
164 | def from_importance_weights(
165 | log_rhos, discounts, rewards, values, bootstrap_value,
166 | clip_rho_threshold=1.0, clip_pg_rho_threshold=1.0,
167 | name='vtrace_from_importance_weights'):
168 | r"""V-trace from log importance weights.
169 |
170 | Calculates V-trace actor critic targets as described in
171 |
172 | "IMPALA: Scalable Distributed Deep-RL with
173 | Importance Weighted Actor-Learner Architectures"
174 | by Espeholt, Soyer, Munos et al.
175 |
176 | In the notation used throughout documentation and comments, T refers to the
177 | time dimension ranging from 0 to T-1. B refers to the batch size and
178 | NUM_ACTIONS refers to the number of actions. This code also supports the
179 | case where all tensors have the same number of additional dimensions, e.g.,
180 | `rewards` is [T, B, C], `values` is [T, B, C], `bootstrap_value` is [B, C].
181 |
182 | Args:
183 |     log_rhos: A float32 tensor of shape [T, B] representing the log
184 | importance sampling weights, i.e.
185 | log(target_policy(a) / behaviour_policy(a)). V-trace performs operations
186 | on rhos in log-space for numerical stability.
187 | discounts: A float32 tensor of shape [T, B] with discounts encountered when
188 | following the behaviour policy.
189 | rewards: A float32 tensor of shape [T, B] containing rewards generated by
190 | following the behaviour policy.
191 | values: A float32 tensor of shape [T, B] with the value function estimates
192 | wrt. the target policy.
193 |     bootstrap_value: A float32 tensor of shape [B] with the value function
194 |       estimate at time T.
195 | clip_rho_threshold: A scalar float32 tensor with the clipping threshold for
196 | importance weights (rho) when calculating the baseline targets (vs).
197 | rho^bar in the paper. If None, no clipping is applied.
198 | clip_pg_rho_threshold: A scalar float32 tensor with the clipping threshold
199 | on rho_s in \rho_s \delta log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)). If
200 | None, no clipping is applied.
201 | name: The name scope that all V-trace operations will be created in.
202 |
203 | Returns:
204 | A VTraceReturns namedtuple (vs, pg_advantages) where:
205 | vs: A float32 tensor of shape [T, B]. Can be used as target to
206 | train a baseline (V(x_t) - vs_t)^2.
207 | pg_advantages: A float32 tensor of shape [T, B]. Can be used as the
208 | advantage in the calculation of policy gradients.
209 | """
210 | log_rhos = tf.convert_to_tensor(log_rhos, dtype=tf.float32)
211 | discounts = tf.convert_to_tensor(discounts, dtype=tf.float32)
212 | rewards = tf.convert_to_tensor(rewards, dtype=tf.float32)
213 | values = tf.convert_to_tensor(values, dtype=tf.float32)
214 | bootstrap_value = tf.convert_to_tensor(bootstrap_value, dtype=tf.float32)
215 | if clip_rho_threshold is not None:
216 | clip_rho_threshold = tf.convert_to_tensor(clip_rho_threshold,
217 | dtype=tf.float32)
218 | if clip_pg_rho_threshold is not None:
219 | clip_pg_rho_threshold = tf.convert_to_tensor(clip_pg_rho_threshold,
220 | dtype=tf.float32)
221 |
222 | # Make sure tensor ranks are consistent.
223 | rho_rank = log_rhos.shape.ndims # Usually 2.
224 | values.shape.assert_has_rank(rho_rank)
225 | bootstrap_value.shape.assert_has_rank(rho_rank - 1)
226 | discounts.shape.assert_has_rank(rho_rank)
227 | rewards.shape.assert_has_rank(rho_rank)
228 | if clip_rho_threshold is not None:
229 | clip_rho_threshold.shape.assert_has_rank(0)
230 | if clip_pg_rho_threshold is not None:
231 | clip_pg_rho_threshold.shape.assert_has_rank(0)
232 |
233 | with tf.name_scope(name, values=[
234 | log_rhos, discounts, rewards, values, bootstrap_value]):
235 | rhos = tf.exp(log_rhos)
236 | if clip_rho_threshold is not None:
237 | clipped_rhos = tf.minimum(clip_rho_threshold, rhos, name='clipped_rhos')
238 | else:
239 | clipped_rhos = rhos
240 |
241 | cs = tf.minimum(1.0, rhos, name='cs')
242 | # Append bootstrapped value to get [v1, ..., v_t+1]
243 | values_t_plus_1 = tf.concat(
244 | [values[1:], tf.expand_dims(bootstrap_value, 0)], axis=0)
245 | deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values)
246 |
247 | sequences = (discounts, cs, deltas)
248 | # V-trace vs are calculated through a scan from the back to the beginning
249 | # of the given trajectory.
250 | def scanfunc(acc, sequence_item):
251 | discount_t, c_t, delta_t = sequence_item
252 | return delta_t + discount_t * c_t * acc
253 |
254 | initial_values = tf.zeros_like(bootstrap_value)
255 | vs_minus_v_xs = tf.scan(
256 | fn=scanfunc,
257 | elems=sequences,
258 | initializer=initial_values,
259 | parallel_iterations=1,
260 | back_prop=False,
261 | reverse=True, # Computation starts from the back.
262 | name='scan')
263 |
264 | # Add V(x_s) to get v_s.
265 | vs = tf.add(vs_minus_v_xs, values, name='vs')
266 |
267 | # Advantage for policy gradient.
268 | vs_t_plus_1 = tf.concat([
269 | vs[1:], tf.expand_dims(bootstrap_value, 0)], axis=0)
270 | if clip_pg_rho_threshold is not None:
271 | clipped_pg_rhos = tf.minimum(clip_pg_rho_threshold, rhos,
272 | name='clipped_pg_rhos')
273 | else:
274 | clipped_pg_rhos = rhos
275 | pg_advantages = (
276 | clipped_pg_rhos * (rewards + discounts * vs_t_plus_1 - values))
277 |
278 | # Make sure no gradients backpropagated through the returned values.
279 | return VTraceReturns(vs=tf.stop_gradient(vs),
280 | pg_advantages=tf.stop_gradient(pg_advantages))
281 |
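282 | 
283 | # The scan in from_importance_weights evaluates the backward recursion from the
284 | # paper:
285 | #   vs_s = V(x_s) + delta_s + discount_s * c_s * (vs_{s+1} - V(x_{s+1})),
286 | # where delta_s = clipped_rho_s * (r_s + discount_s * V(x_{s+1}) - V(x_s)),
287 | # c_s = min(1, rho_s) and the recursion is anchored at vs_T = bootstrap_value.
288 | #
289 | # Minimal sketch of calling from_importance_weights directly; the shapes and
290 | # numbers below are illustrative only (T=3, B=2, on-policy, all values zero).
291 | if __name__ == '__main__':
292 |   import numpy as np
293 | 
294 |   example_returns = from_importance_weights(
295 |       log_rhos=np.zeros([3, 2], dtype=np.float32),  # On-policy: rho = 1.
296 |       discounts=np.full([3, 2], 0.9, dtype=np.float32),
297 |       rewards=np.ones([3, 2], dtype=np.float32),
298 |       values=np.zeros([3, 2], dtype=np.float32),
299 |       bootstrap_value=np.zeros([2], dtype=np.float32))
300 | 
301 |   with tf.Session() as session:
302 |     vs, pg_advantages = session.run(example_returns)
303 |     # With rho = 1 and V = 0 everywhere, vs is the discounted return per step:
304 |     # 1 + 0.9 + 0.81 = 2.71, then 1.9, then 1.0 for each batch element.
305 |     print(vs)
306 |     print(pg_advantages)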
--------------------------------------------------------------------------------
/vtrace_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for V-trace.
16 |
17 | For details and theory see:
18 |
19 | "IMPALA: Scalable Distributed Deep-RL with
20 | Importance Weighted Actor-Learner Architectures"
21 | by Espeholt, Soyer, Munos et al.
22 | """
23 |
24 | from __future__ import absolute_import
25 | from __future__ import division
26 | from __future__ import print_function
27 |
28 | from absl.testing import parameterized
29 | import numpy as np
30 | import tensorflow as tf
31 | import vtrace
32 |
33 |
34 | def _shaped_arange(*shape):
35 | """Runs np.arange, converts to float and reshapes."""
36 | return np.arange(np.prod(shape), dtype=np.float32).reshape(*shape)
37 |
38 |
39 | def _softmax(logits):
40 | """Applies softmax non-linearity on inputs."""
41 | return np.exp(logits) / np.sum(np.exp(logits), axis=-1, keepdims=True)
42 |
43 |
44 | def _ground_truth_calculation(discounts, log_rhos, rewards, values,
45 | bootstrap_value, clip_rho_threshold,
46 | clip_pg_rho_threshold):
47 | """Calculates the ground truth for V-trace in Python/Numpy."""
48 | vs = []
49 | seq_len = len(discounts)
50 | rhos = np.exp(log_rhos)
51 | cs = np.minimum(rhos, 1.0)
52 | clipped_rhos = rhos
53 | if clip_rho_threshold:
54 | clipped_rhos = np.minimum(rhos, clip_rho_threshold)
55 | clipped_pg_rhos = rhos
56 | if clip_pg_rho_threshold:
57 | clipped_pg_rhos = np.minimum(rhos, clip_pg_rho_threshold)
58 |
59 | # This is a very inefficient way to calculate the V-trace ground truth.
60 | # We calculate it this way because it is close to the mathematical notation of
61 | # V-trace.
62 | # v_s = V(x_s)
63 | # + \sum^{T-1}_{t=s} \gamma^{t-s}
64 | # * \prod_{i=s}^{t-1} c_i
65 | # * \rho_t (r_t + \gamma V(x_{t+1}) - V(x_t))
66 |   # Note that when we take the product over c_i, we use the slice `s:t`:
67 |   # the paper's product includes index `t-1`, while Python slicing excludes it.
68 | # Also note that np.prod([]) == 1.
69 | values_t_plus_1 = np.concatenate([values, bootstrap_value[None, :]], axis=0)
70 | for s in range(seq_len):
71 | v_s = np.copy(values[s]) # Very important copy.
72 | for t in range(s, seq_len):
73 | v_s += (
74 |             np.prod(discounts[s:t], axis=0) *
75 |             np.prod(cs[s:t], axis=0) * clipped_rhos[t] *
76 | (rewards[t] + discounts[t] * values_t_plus_1[t + 1] - values[t]))
77 | vs.append(v_s)
78 | vs = np.stack(vs, axis=0)
79 | pg_advantages = (
80 | clipped_pg_rhos * (rewards + discounts * np.concatenate(
81 | [vs[1:], bootstrap_value[None, :]], axis=0) - values))
82 |
83 | return vtrace.VTraceReturns(vs=vs, pg_advantages=pg_advantages)
84 |
85 |
86 | class LogProbsFromLogitsAndActionsTest(tf.test.TestCase,
87 | parameterized.TestCase):
88 |
89 | @parameterized.named_parameters(('Batch1', 1), ('Batch2', 2))
90 | def test_log_probs_from_logits_and_actions(self, batch_size):
91 | """Tests log_probs_from_logits_and_actions."""
92 | seq_len = 7
93 | num_actions = 3
94 |
95 | policy_logits = _shaped_arange(seq_len, batch_size, num_actions) + 10
96 | actions = np.random.randint(
97 | 0, num_actions, size=(seq_len, batch_size), dtype=np.int32)
98 |
99 | action_log_probs_tensor = vtrace.log_probs_from_logits_and_actions(
100 | policy_logits, actions)
101 |
102 | # Ground Truth
103 | # Using broadcasting to create a mask that indexes action logits
104 | action_index_mask = actions[..., None] == np.arange(num_actions)
105 |
106 | def index_with_mask(array, mask):
107 | return array[mask].reshape(*array.shape[:-1])
108 |
109 | # Note: Normally log(softmax) is not a good idea because it's not
110 | # numerically stable. However, in this test we have well-behaved values.
111 | ground_truth_v = index_with_mask(
112 | np.log(_softmax(policy_logits)), action_index_mask)
113 |
114 | with self.test_session() as session:
115 | self.assertAllClose(ground_truth_v, session.run(action_log_probs_tensor))
116 |
117 |
118 | class VtraceTest(tf.test.TestCase, parameterized.TestCase):
119 |
120 | @parameterized.named_parameters(('Batch1', 1), ('Batch5', 5))
121 | def test_vtrace(self, batch_size):
122 | """Tests V-trace against ground truth data calculated in python."""
123 | seq_len = 5
124 |
125 | # Create log_rhos such that rho will span from near-zero to above the
126 | # clipping thresholds. In particular, calculate log_rhos in [-2.5, 2.5),
127 | # so that rho is in approx [0.08, 12.2).
128 | log_rhos = _shaped_arange(seq_len, batch_size) / (batch_size * seq_len)
129 | log_rhos = 5 * (log_rhos - 0.5) # [0.0, 1.0) -> [-2.5, 2.5).
130 | values = {
131 | 'log_rhos': log_rhos,
132 | # T, B where B_i: [0.9 / (i+1)] * T
133 | 'discounts':
134 | np.array([[0.9 / (b + 1)
135 | for b in range(batch_size)]
136 | for _ in range(seq_len)]),
137 | 'rewards':
138 | _shaped_arange(seq_len, batch_size),
139 | 'values':
140 | _shaped_arange(seq_len, batch_size) / batch_size,
141 | 'bootstrap_value':
142 | _shaped_arange(batch_size) + 1.0,
143 | 'clip_rho_threshold':
144 | 3.7,
145 | 'clip_pg_rho_threshold':
146 | 2.2,
147 | }
148 |
149 | output = vtrace.from_importance_weights(**values)
150 |
151 | with self.test_session() as session:
152 | output_v = session.run(output)
153 |
154 | ground_truth_v = _ground_truth_calculation(**values)
155 | for a, b in zip(ground_truth_v, output_v):
156 | self.assertAllClose(a, b)
157 |
158 | @parameterized.named_parameters(('Batch1', 1), ('Batch2', 2))
159 | def test_vtrace_from_logits(self, batch_size):
160 | """Tests V-trace calculated from logits."""
161 | seq_len = 5
162 | num_actions = 3
163 | clip_rho_threshold = None # No clipping.
164 | clip_pg_rho_threshold = None # No clipping.
165 |
166 | # Intentionally leaving shapes unspecified to test if V-trace can
167 | # deal with that.
168 | placeholders = {
169 | # T, B, NUM_ACTIONS
170 | 'behaviour_policy_logits':
171 | tf.placeholder(dtype=tf.float32, shape=[None, None, None]),
172 | # T, B, NUM_ACTIONS
173 | 'target_policy_logits':
174 | tf.placeholder(dtype=tf.float32, shape=[None, None, None]),
175 | 'actions':
176 | tf.placeholder(dtype=tf.int32, shape=[None, None]),
177 | 'discounts':
178 | tf.placeholder(dtype=tf.float32, shape=[None, None]),
179 | 'rewards':
180 | tf.placeholder(dtype=tf.float32, shape=[None, None]),
181 | 'values':
182 | tf.placeholder(dtype=tf.float32, shape=[None, None]),
183 | 'bootstrap_value':
184 | tf.placeholder(dtype=tf.float32, shape=[None]),
185 | }
186 |
187 | from_logits_output = vtrace.from_logits(
188 | clip_rho_threshold=clip_rho_threshold,
189 | clip_pg_rho_threshold=clip_pg_rho_threshold,
190 | **placeholders)
191 |
192 | target_log_probs = vtrace.log_probs_from_logits_and_actions(
193 | placeholders['target_policy_logits'], placeholders['actions'])
194 | behaviour_log_probs = vtrace.log_probs_from_logits_and_actions(
195 | placeholders['behaviour_policy_logits'], placeholders['actions'])
196 | log_rhos = target_log_probs - behaviour_log_probs
197 | ground_truth = (log_rhos, behaviour_log_probs, target_log_probs)
198 |
199 | values = {
200 | 'behaviour_policy_logits':
201 | _shaped_arange(seq_len, batch_size, num_actions),
202 | 'target_policy_logits':
203 | _shaped_arange(seq_len, batch_size, num_actions),
204 | 'actions':
205 | np.random.randint(0, num_actions - 1, size=(seq_len, batch_size)),
206 | 'discounts':
207 | np.array( # T, B where B_i: [0.9 / (i+1)] * T
208 | [[0.9 / (b + 1)
209 | for b in range(batch_size)]
210 | for _ in range(seq_len)]),
211 | 'rewards':
212 | _shaped_arange(seq_len, batch_size),
213 | 'values':
214 | _shaped_arange(seq_len, batch_size) / batch_size,
215 | 'bootstrap_value':
216 | _shaped_arange(batch_size) + 1.0, # B
217 | }
218 |
219 | feed_dict = {placeholders[k]: v for k, v in values.items()}
220 | with self.test_session() as session:
221 | from_logits_output_v = session.run(
222 | from_logits_output, feed_dict=feed_dict)
223 | (ground_truth_log_rhos, ground_truth_behaviour_action_log_probs,
224 | ground_truth_target_action_log_probs) = session.run(
225 | ground_truth, feed_dict=feed_dict)
226 |
227 | # Calculate V-trace using the ground truth logits.
228 | from_iw = vtrace.from_importance_weights(
229 | log_rhos=ground_truth_log_rhos,
230 | discounts=values['discounts'],
231 | rewards=values['rewards'],
232 | values=values['values'],
233 | bootstrap_value=values['bootstrap_value'],
234 | clip_rho_threshold=clip_rho_threshold,
235 | clip_pg_rho_threshold=clip_pg_rho_threshold)
236 |
237 | with self.test_session() as session:
238 | from_iw_v = session.run(from_iw)
239 |
240 | self.assertAllClose(from_iw_v.vs, from_logits_output_v.vs)
241 | self.assertAllClose(from_iw_v.pg_advantages,
242 | from_logits_output_v.pg_advantages)
243 | self.assertAllClose(ground_truth_behaviour_action_log_probs,
244 | from_logits_output_v.behaviour_action_log_probs)
245 | self.assertAllClose(ground_truth_target_action_log_probs,
246 | from_logits_output_v.target_action_log_probs)
247 | self.assertAllClose(ground_truth_log_rhos, from_logits_output_v.log_rhos)
248 |
249 | def test_higher_rank_inputs_for_importance_weights(self):
250 | """Checks support for additional dimensions in inputs."""
251 | placeholders = {
252 | 'log_rhos': tf.placeholder(dtype=tf.float32, shape=[None, None, 1]),
253 | 'discounts': tf.placeholder(dtype=tf.float32, shape=[None, None, 1]),
254 | 'rewards': tf.placeholder(dtype=tf.float32, shape=[None, None, 42]),
255 | 'values': tf.placeholder(dtype=tf.float32, shape=[None, None, 42]),
256 | 'bootstrap_value': tf.placeholder(dtype=tf.float32, shape=[None, 42])
257 | }
258 | output = vtrace.from_importance_weights(**placeholders)
259 | self.assertEqual(output.vs.shape.as_list()[-1], 42)
260 |
261 | def test_inconsistent_rank_inputs_for_importance_weights(self):
262 | """Test one of many possible errors in shape of inputs."""
263 | placeholders = {
264 | 'log_rhos': tf.placeholder(dtype=tf.float32, shape=[None, None, 1]),
265 | 'discounts': tf.placeholder(dtype=tf.float32, shape=[None, None, 1]),
266 | 'rewards': tf.placeholder(dtype=tf.float32, shape=[None, None, 42]),
267 | 'values': tf.placeholder(dtype=tf.float32, shape=[None, None, 42]),
268 | # Should be [None, 42].
269 | 'bootstrap_value': tf.placeholder(dtype=tf.float32, shape=[None])
270 | }
271 | with self.assertRaisesRegexp(ValueError, 'must have rank 2'):
272 | vtrace.from_importance_weights(**placeholders)
273 |
274 |
275 | if __name__ == '__main__':
276 | tf.test.main()
277 |
--------------------------------------------------------------------------------