├── CONTRIBUTING.md ├── Dockerfile ├── LICENSE ├── README.md ├── batcher.cc ├── dmlab30.py ├── dynamic_batching.py ├── dynamic_batching_test.py ├── environments.py ├── experiment.py ├── py_process.py ├── py_process_test.py ├── vtrace.py └── vtrace_test.py /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution, 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows [Google's Open Source Community 28 | Guidelines](https://opensource.google.com/conduct/). 29 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:18.04 2 | 3 | # Install dependencies. 4 | # g++ (v. 5.4) does not work: https://github.com/tensorflow/tensorflow/issues/13308 5 | RUN apt-get update && apt-get install -y \ 6 | curl \ 7 | zip \ 8 | unzip \ 9 | software-properties-common \ 10 | pkg-config \ 11 | g++-4.8 \ 12 | zlib1g-dev \ 13 | python \ 14 | lua5.1 \ 15 | liblua5.1-0-dev \ 16 | libffi-dev \ 17 | gettext \ 18 | freeglut3 \ 19 | libsdl2-dev \ 20 | libosmesa6-dev \ 21 | libglu1-mesa \ 22 | libglu1-mesa-dev \ 23 | python-dev \ 24 | build-essential \ 25 | git \ 26 | python-setuptools \ 27 | python-pip \ 28 | libjpeg-dev 29 | 30 | # Install bazel 31 | RUN echo "deb [arch=amd64] http://storage.googleapis.com/bazel-apt stable jdk1.8" | \ 32 | tee /etc/apt/sources.list.d/bazel.list && \ 33 | curl https://bazel.build/bazel-release.pub.gpg | \ 34 | apt-key add - && \ 35 | apt-get update && apt-get install -y bazel 36 | 37 | # Install TensorFlow and other dependencies 38 | RUN pip install tensorflow==1.9.0 dm-sonnet==1.23 39 | 40 | # Build and install DeepMind Lab pip package. 
41 | # We explicitly set the Numpy path as shown here: 42 | # https://github.com/deepmind/lab/blob/master/docs/users/build.md 43 | RUN NP_INC="$(python -c 'import numpy as np; print(np.get_include())[5:]')" && \ 44 | git clone https://github.com/deepmind/lab.git && \ 45 | cd lab && \ 46 | sed -i 's@hdrs = glob(\[@hdrs = glob(["'"$NP_INC"'/\*\*/*.h", @g' python.BUILD && \ 47 | sed -i 's@includes = \[@includes = ["'"$NP_INC"'", @g' python.BUILD && \ 48 | bazel build -c opt python/pip_package:build_pip_package && \ 49 | pip install wheel && \ 50 | ./bazel-bin/python/pip_package/build_pip_package /tmp/dmlab_pkg && \ 51 | pip install /tmp/dmlab_pkg/DeepMind_Lab-1.0-py2-none-any.whl --force-reinstall 52 | 53 | # Install dataset (from https://github.com/deepmind/lab/tree/master/data/brady_konkle_oliva2008) 54 | RUN mkdir dataset && \ 55 | cd dataset && \ 56 | pip install Pillow && \ 57 | curl -sS https://raw.githubusercontent.com/deepmind/lab/master/data/brady_konkle_oliva2008/README.md | \ 58 | tr '\n' '\r' | \ 59 | sed -e 's/.*```sh\(.*\)```.*/\1/' | \ 60 | tr '\r' '\n' | \ 61 | bash 62 | 63 | # Clone. 64 | RUN git clone https://github.com/deepmind/scalable_agent.git 65 | WORKDIR scalable_agent 66 | 67 | # Build dynamic batching module. 68 | RUN TF_INC="$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_include())')" && \ 69 | TF_LIB="$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_lib())')" && \ 70 | g++-4.8 -std=c++11 -shared batcher.cc -o batcher.so -fPIC -I $TF_INC -O2 -D_GLIBCXX_USE_CXX11_ABI=0 -L$TF_LIB -ltensorflow_framework 71 | 72 | # Run tests. 73 | RUN python py_process_test.py 74 | RUN python dynamic_batching_test.py 75 | RUN python vtrace_test.py 76 | 77 | # Run. 78 | CMD ["sh", "-c", "python experiment.py --total_environment_frames=10000 --dataset_path=../dataset && python experiment.py --mode=test --test_num_episodes=5"] 79 | 80 | # Docker commands: 81 | # docker rm scalable_agent -v 82 | # docker build -t scalable_agent . 83 | # docker run --name scalable_agent scalable_agent 84 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 
30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures 2 | 3 | This repository contains an implementation of "Importance Weighted Actor-Learner 4 | Architectures", along with a *dynamic batching* module. This is not an 5 | officially supported Google product. 6 | 7 | For a detailed description of the architecture please read [our paper][arxiv]. 8 | Please cite the paper if you use the code from this repository in your work. 9 | 10 | ### Bibtex 11 | 12 | ``` 13 | @inproceedings{impala2018, 14 | title={IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures}, 15 | author={Espeholt, Lasse and Soyer, Hubert and Munos, Remi and Simonyan, Karen and Mnih, Volodymir and Ward, Tom and Doron, Yotam and Firoiu, Vlad and Harley, Tim and Dunning, Iain and others}, 16 | booktitle={Proceedings of the International Conference on Machine Learning (ICML)}, 17 | year={2018} 18 | } 19 | ``` 20 | 21 | ## Running the Code 22 | 23 | ### Prerequisites 24 | 25 | [TensorFlow][tensorflow] >=1.9.0-dev20180530, the environment 26 | [DeepMind Lab][deepmind_lab] and the neural network library 27 | [DeepMind Sonnet][sonnet]. Although we use [DeepMind Lab][deepmind_lab] in this 28 | release, the agent has been successfully applied to other domains such as 29 | [Atari][arxiv], [Street View][learning_nav] and has been modified to 30 | [generate images][generate_images]. 31 | 32 | We include a [Dockerfile][dockerfile] that serves as a reference for the 33 | prerequisites and commands needed to run the code. 34 | 35 | ### Single Machine Training on a Single Level 36 | 37 | Training on `explore_goal_locations_small`. Most runs should end up with average 38 | episode returns around 200 or around 250 after 1B frames. 39 | 40 | ```sh 41 | python experiment.py --num_actors=48 --batch_size=32 42 | ``` 43 | 44 | Adjust the number of actors (i.e. number of environments) and batch size to 45 | match the size of the machine it runs on. A single actor, including DeepMind 46 | Lab, requires a few hundred MB of RAM. 47 | 48 | ### Distributed Training on DMLab-30 49 | 50 | Training on the full [DMLab-30][dmlab30]. Across 10 runs with different seeds 51 | but identical hyperparameters, we observed between 45 and 50 capped human 52 | normalized training score with different seeds (`--seed=[seed]`). Test scores 53 | are usually an absolute of ~2% lower. 54 | 55 | #### Learner 56 | 57 | ```sh 58 | python experiment.py --job_name=learner --task=0 --num_actors=150 \ 59 | --level_name=dmlab30 --batch_size=32 --entropy_cost=0.0033391318945337044 \ 60 | --learning_rate=0.00031866995608948655 \ 61 | --total_environment_frames=10000000000 --reward_clipping=soft_asymmetric 62 | ``` 63 | 64 | #### Actor(s) 65 | 66 | ```sh 67 | for i in $(seq 0 149); do 68 | python experiment.py --job_name=actor --task=$i \ 69 | --num_actors=150 --level_name=dmlab30 --dataset_path=[...] & 70 | done; 71 | wait 72 | ``` 73 | 74 | #### Test Score 75 | 76 | ```sh 77 | python experiment.py --mode=test --level_name=dmlab30 --dataset_path=[...] 
\ 78 | --test_num_episodes=10 79 | ``` 80 | 81 | [arxiv]: https://arxiv.org/abs/1802.01561 82 | [deepmind_lab]: https://github.com/deepmind/lab 83 | [sonnet]: https://github.com/deepmind/sonnet 84 | [learning_nav]: https://arxiv.org/abs/1804.00168 85 | [generate_images]: https://deepmind.com/blog/learning-to-generate-images/ 86 | [tensorflow]: https://github.com/tensorflow/tensorflow 87 | [dockerfile]: Dockerfile 88 | [dmlab30]: https://github.com/deepmind/lab/tree/master/game_scripts/levels/contributed/dmlab30 89 | -------------------------------------------------------------------------------- /batcher.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // https://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | // TensorFlow operations for dynamic batching. 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | #include "tensorflow/core/framework/op.h" 23 | #include "tensorflow/core/framework/op_kernel.h" 24 | #include "tensorflow/core/framework/resource_op_kernel.h" 25 | #include "tensorflow/core/framework/shape_inference.h" 26 | #include "tensorflow/core/lib/gtl/flatmap.h" 27 | #include "tensorflow/core/lib/gtl/optional.h" 28 | #include "tensorflow/core/lib/strings/strcat.h" 29 | #include "tensorflow/core/util/batch_util.h" 30 | #include "tensorflow/core/util/work_sharder.h" 31 | 32 | namespace tensorflow { 33 | namespace { 34 | 35 | REGISTER_OP("Batcher") 36 | .Output("handle: resource") 37 | .Attr("minimum_batch_size: int") 38 | .Attr("maximum_batch_size: int") 39 | .Attr("timeout_ms: int") 40 | .Attr("container: string = ''") 41 | .Attr("shared_name: string = ''") 42 | .SetIsStateful() 43 | .SetShapeFn(shape_inference::ScalarShape) 44 | .Doc(R"doc( 45 | A Batcher which batches up computations into the same batch. 46 | )doc"); 47 | 48 | REGISTER_OP("BatcherCompute") 49 | .Input("handle: resource") 50 | .Input("input_list: Tinput_list") 51 | .Attr("Tinput_list: list(type) >= 1") 52 | .Attr("Toutput_list: list(type) >= 1") 53 | .Output("output_list: Toutput_list") 54 | .SetShapeFn(shape_inference::UnknownShape) 55 | .Doc(R"doc( 56 | Puts the input into the computation queue, waits and returns the result. 57 | )doc"); 58 | 59 | REGISTER_OP("BatcherGetInputs") 60 | .Input("handle: resource") 61 | .Attr("Toutput_list: list(type) >= 1") 62 | .Output("output_list: Toutput_list") 63 | .Output("computation_id: int64") 64 | .SetShapeFn([](shape_inference::InferenceContext* c) { 65 | for (int i = 0; i < c->num_outputs() - 1; ++i) { 66 | c->set_output(i, c->UnknownShape()); 67 | } 68 | return c->set_output("computation_id", {c->Scalar()}); 69 | }) 70 | .Doc(R"doc( 71 | Gets a batch of inputs to compute the results of. 
72 | )doc"); 73 | 74 | REGISTER_OP("BatcherSetOutputs") 75 | .Input("handle: resource") 76 | .Input("input_list: Tinput_list") 77 | .Input("computation_id: int64") 78 | .Attr("Tinput_list: list(type) >= 1") 79 | .SetShapeFn(shape_inference::UnknownShape) 80 | .Doc(R"doc( 81 | Sets the outputs of a batch for the function. 82 | )doc"); 83 | 84 | REGISTER_OP("BatcherClose") 85 | .Input("handle: resource") 86 | .SetShapeFn(shape_inference::NoOutputs) 87 | .Doc(R"doc( 88 | Closes the batcher and cancels all pending batcher operations. 89 | )doc"); 90 | 91 | class Batcher : public ResourceBase { 92 | public: 93 | using DoneCallback = AsyncOpKernel::DoneCallback; 94 | 95 | Batcher(int32 minimum_batch_size, int32 maximum_batch_size, 96 | gtl::optional timeout) 97 | : ResourceBase(), 98 | curr_computation_id_(0), 99 | is_closed_(false), 100 | minimum_batch_size_(minimum_batch_size), 101 | maximum_batch_size_(maximum_batch_size), 102 | timeout_(std::move(timeout)) {} 103 | 104 | string DebugString() override { 105 | mutex_lock l(mu_); 106 | return strings::StrCat("Batcher with ", inputs_.size(), " waiting inputs."); 107 | } 108 | 109 | void Compute(OpKernelContext* context, const OpInputList& input_list, 110 | DoneCallback callback); 111 | 112 | void GetInputs(OpKernelContext* context, OpOutputList* output_list); 113 | 114 | void SetOutputs(OpKernelContext* context, const OpInputList& input_list, 115 | int64 computation_id); 116 | 117 | void Close(OpKernelContext* context); 118 | 119 | private: 120 | class Input { 121 | public: 122 | Input(OpKernelContext* context, const OpInputList& input_list, 123 | DoneCallback callback) 124 | : context_(context), 125 | input_list_(input_list), 126 | callback_(std::move(callback)) {} 127 | 128 | // Moveable but not copyable. 129 | Input(Input&& rhs) 130 | : Input(rhs.context_, rhs.input_list_, std::move(rhs.callback_)) { 131 | rhs.context_ = nullptr; // Mark invalid. 132 | } 133 | 134 | Input& operator=(Input&& rhs) { 135 | this->context_ = rhs.context_; 136 | this->input_list_ = rhs.input_list_; 137 | this->callback_ = std::move(rhs.callback_); 138 | rhs.context_ = nullptr; // Mark invalid. 139 | return *this; 140 | } 141 | 142 | OpKernelContext* context() const { 143 | CHECK(is_valid()); 144 | return context_; 145 | } 146 | 147 | const OpInputList& input_list() const { 148 | CHECK(is_valid()); 149 | return input_list_; 150 | } 151 | 152 | bool is_valid() const { return context_ != nullptr; } 153 | 154 | void Done() { 155 | CHECK(is_valid()); 156 | 157 | // After callback is called, context_, input_list_ and callback_ becomes 158 | // invalid and shouldn't be used. 159 | context_ = nullptr; 160 | callback_(); 161 | } 162 | 163 | private: 164 | // Not owned. 165 | OpKernelContext* context_; 166 | OpInputList input_list_; 167 | DoneCallback callback_; 168 | }; 169 | 170 | void CancelInput(Input* input) EXCLUSIVE_LOCKS_REQUIRED(mu_); 171 | 172 | void GetInputsInternal(OpKernelContext* context, OpOutputList* output_list) 173 | EXCLUSIVE_LOCKS_REQUIRED(mu_); 174 | 175 | void SetOutputsInternal(OpKernelContext* context, 176 | const OpInputList& input_list, int64 computation_id) 177 | EXCLUSIVE_LOCKS_REQUIRED(mu_); 178 | 179 | // Cancels all pending Compute ops and marks the batcher closed. 
180 | void CancelAndClose(OpKernelContext* context) EXCLUSIVE_LOCKS_REQUIRED(mu_); 181 | 182 | mutex mu_; 183 | condition_variable full_batch_or_cancelled_cond_var_; 184 | 185 | // A counter of all batched computations that have been started that is used 186 | // to create a unique id for each batched computation. 187 | int64 curr_computation_id_ GUARDED_BY(mu_); 188 | 189 | // Inputs waiting to be computed. 190 | std::deque inputs_ GUARDED_BY(mu_); 191 | 192 | // Batches that are currently being computed. Maps computation_id to a batch 193 | // of inputs. 194 | gtl::FlatMap> being_computed_ GUARDED_BY(mu_); 195 | 196 | // Whether the Batcher has been closed (happens when there is an error or 197 | // Close() has been called.) 198 | bool is_closed_ GUARDED_BY(mu_); 199 | const int32 minimum_batch_size_; 200 | const int32 maximum_batch_size_; 201 | const gtl::optional timeout_; 202 | 203 | TF_DISALLOW_COPY_AND_ASSIGN(Batcher); 204 | }; 205 | 206 | void Batcher::Compute(OpKernelContext* context, const OpInputList& input_list, 207 | DoneCallback callback) { 208 | bool should_notify; 209 | 210 | { 211 | mutex_lock l(mu_); 212 | 213 | OP_REQUIRES_ASYNC(context, !is_closed_, 214 | errors::Cancelled("Batcher is closed"), callback); 215 | 216 | // Add the inputs to the list of inputs. 217 | inputs_.emplace_back(context, input_list, std::move(callback)); 218 | 219 | should_notify = inputs_.size() >= minimum_batch_size_; 220 | } 221 | 222 | if (should_notify) { 223 | // If a GetInputs operation is blocked, wake it up. 224 | full_batch_or_cancelled_cond_var_.notify_one(); 225 | } 226 | } 227 | 228 | void Batcher::GetInputs(OpKernelContext* context, OpOutputList* output_list) { 229 | CancellationManager* cm = context->cancellation_manager(); 230 | CancellationToken token = cm->get_cancellation_token(); 231 | 232 | bool is_cancelled_or_cancelling = !cm->RegisterCallback( 233 | token, [this]() { full_batch_or_cancelled_cond_var_.notify_all(); }); 234 | 235 | mutex_lock l(mu_); 236 | std::cv_status status = std::cv_status::no_timeout; 237 | 238 | // Wait for data if the input list has fewer samples than `minimum_batch_size` 239 | // (or non-empty when a timeout has occurred), for cancellation of the 240 | // operation or for the batcher to be closed. 241 | while (((status == std::cv_status::timeout && inputs_.empty()) || 242 | (status == std::cv_status::no_timeout && 243 | inputs_.size() < minimum_batch_size_)) && 244 | !is_cancelled_or_cancelling && !is_closed_) { 245 | // Using a timeout to make sure the operation always completes after a while 246 | // when there isn't enough samples and for the unlikely case where the 247 | // operation is being cancelled between checking if it has been cancelled 248 | // and calling wait_for(). 249 | if (timeout_) { 250 | status = full_batch_or_cancelled_cond_var_.wait_for(l, *timeout_); 251 | } else { 252 | // Timeout is only used to check for cancellation as described in the 253 | // comment above. 
254 | full_batch_or_cancelled_cond_var_.wait_for( 255 | l, std::chrono::milliseconds(100)); 256 | } 257 | is_cancelled_or_cancelling = cm->IsCancelled(); 258 | } 259 | 260 | if (is_closed_) { 261 | context->SetStatus(errors::Cancelled("Batcher is closed")); 262 | } else if (is_cancelled_or_cancelling) { 263 | context->SetStatus(errors::Cancelled("GetInputs operation was cancelled")); 264 | } else { 265 | GetInputsInternal(context, output_list); 266 | } 267 | 268 | if (!context->status().ok()) { 269 | CancelAndClose(context); 270 | } 271 | } 272 | 273 | void Batcher::GetInputsInternal(OpKernelContext* context, 274 | OpOutputList* output_list) { 275 | int64 batch_size = std::min(inputs_.size(), maximum_batch_size_); 276 | size_t num_tensors = inputs_.front().input_list().size(); 277 | 278 | // Allocate output tensors. 279 | std::vector output_tensors(num_tensors); 280 | for (size_t i = 0; i < num_tensors; ++i) { 281 | TensorShape shape = inputs_.front().input_list()[i].shape(); 282 | OP_REQUIRES( 283 | context, shape.dim_size(0) == 1, 284 | errors::InvalidArgument("Batcher requires batch size 1 but was ", 285 | shape.dim_size(0))); 286 | shape.set_dim(0, batch_size); 287 | 288 | OP_REQUIRES_OK(context, 289 | output_list->allocate(i, shape, &output_tensors[i])); 290 | } 291 | 292 | auto work = [this, &context, &output_tensors, num_tensors]( 293 | int64 start, int64 end) EXCLUSIVE_LOCKS_REQUIRED(mu_) { 294 | for (int64 j = start; j < end; ++j) { 295 | for (size_t i = 0; i < num_tensors; ++i) { 296 | OP_REQUIRES(context, 297 | inputs_[0].input_list()[i].shape() == 298 | inputs_[j].input_list()[i].shape(), 299 | errors::InvalidArgument( 300 | "Shapes of inputs much be equal. Shapes observed: ", 301 | inputs_[0].input_list()[i].shape().DebugString(), ", ", 302 | inputs_[j].input_list()[i].shape().DebugString())); 303 | 304 | OP_REQUIRES_OK(context, 305 | tensorflow::batch_util::CopyElementToSlice( 306 | inputs_[j].input_list()[i], output_tensors[i], j)); 307 | } 308 | } 309 | }; 310 | 311 | auto worker_threads = context->device()->tensorflow_cpu_worker_threads(); 312 | Shard(worker_threads->num_threads, worker_threads->workers, batch_size, 10, 313 | work); 314 | 315 | // New unique computation id. 316 | int64 new_computation_id = curr_computation_id_++; 317 | Tensor* computation_id_t = nullptr; 318 | OP_REQUIRES_OK(context, 319 | context->allocate_output("computation_id", TensorShape({}), 320 | &computation_id_t)); 321 | computation_id_t->scalar()() = new_computation_id; 322 | 323 | // Move the batch of inputs into a list for the new computation. 324 | auto iter = std::make_move_iterator(inputs_.begin()); 325 | being_computed_.emplace(new_computation_id, 326 | std::vector{iter, iter + batch_size}); 327 | inputs_.erase(inputs_.begin(), inputs_.begin() + batch_size); 328 | } 329 | 330 | void Batcher::SetOutputs(OpKernelContext* context, 331 | const OpInputList& input_list, int64 computation_id) { 332 | mutex_lock l(mu_); 333 | SetOutputsInternal(context, input_list, computation_id); 334 | if (!context->status().ok()) { 335 | CancelAndClose(context); 336 | } 337 | } 338 | 339 | void Batcher::SetOutputsInternal(OpKernelContext* context, 340 | const OpInputList& input_list, 341 | int64 computation_id) { 342 | OP_REQUIRES(context, !is_closed_, errors::Cancelled("Batcher is closed")); 343 | 344 | auto search = being_computed_.find(computation_id); 345 | OP_REQUIRES( 346 | context, search != being_computed_.end(), 347 | errors::InvalidArgument("Invalid computation id. 
Id: ", computation_id)); 348 | auto& computation_input_list = search->second; 349 | int64 expected_batch_size = computation_input_list.size(); 350 | 351 | for (const Tensor& tensor : input_list) { 352 | OP_REQUIRES( 353 | context, tensor.shape().dims() > 0, 354 | errors::InvalidArgument( 355 | "Output shape must have a batch dimension. Shape observed: ", 356 | tensor.shape().DebugString())); 357 | OP_REQUIRES( 358 | context, tensor.shape().dim_size(0) == expected_batch_size, 359 | errors::InvalidArgument("Output shape must have the same batch " 360 | "dimension as the input batch size. Expected: ", 361 | expected_batch_size, 362 | " Observed: ", tensor.shape().dim_size(0))); 363 | } 364 | 365 | auto work = [this, &input_list, &context, &computation_input_list]( 366 | int64 start, int64 end) EXCLUSIVE_LOCKS_REQUIRED(mu_) { 367 | for (int64 j = start; j < end; ++j) { 368 | Input& input = computation_input_list[j]; 369 | 370 | for (size_t i = 0; i < input_list.size(); ++i) { 371 | TensorShape shape = input_list[i].shape(); 372 | shape.set_dim(0, 1); 373 | 374 | Tensor* output_tensor; 375 | OP_REQUIRES_OK(context, input.context()->allocate_output( 376 | i, shape, &output_tensor)); 377 | 378 | OP_REQUIRES_OK(context, tensorflow::batch_util::CopySliceToElement( 379 | input_list[i], output_tensor, j)); 380 | } 381 | 382 | input.Done(); 383 | } 384 | }; 385 | 386 | auto worker_threads = context->device()->tensorflow_cpu_worker_threads(); 387 | Shard(worker_threads->num_threads, worker_threads->workers, 388 | expected_batch_size, 50000, work); 389 | 390 | being_computed_.erase(computation_id); 391 | } 392 | 393 | void Batcher::Close(OpKernelContext* context) { 394 | { 395 | mutex_lock l(mu_); 396 | CancelAndClose(context); 397 | } 398 | 399 | // Cancel all running GetInputs operations. 400 | full_batch_or_cancelled_cond_var_.notify_all(); 401 | } 402 | 403 | void Batcher::CancelInput(Batcher::Input* input) { 404 | // Some may already have had their outputs set and the callback called so 405 | // they should be skipped. 406 | if (!input->is_valid()) { 407 | return; 408 | } 409 | 410 | input->context()->CtxFailure(errors::Cancelled("Compute was cancelled")); 411 | input->Done(); 412 | } 413 | 414 | void Batcher::CancelAndClose(OpKernelContext* context) { 415 | // Something went wrong or the batcher was requested to close. All the waiting 416 | // Compute ops should be cancelled. 417 | 418 | if (is_closed_) { 419 | return; 420 | } 421 | 422 | for (auto& input : inputs_) { 423 | CancelInput(&input); 424 | } 425 | for (auto& p : being_computed_) { 426 | for (auto& input : p.second) { 427 | CancelInput(&input); 428 | } 429 | } 430 | is_closed_ = true; // Causes future Compute operations to be cancelled. 
431 | } 432 | 433 | class BatcherHandleOp : public ResourceOpKernel { 434 | public: 435 | explicit BatcherHandleOp(OpKernelConstruction* context) 436 | : ResourceOpKernel(context) { 437 | OP_REQUIRES_OK( 438 | context, context->GetAttr("minimum_batch_size", &minimum_batch_size_)); 439 | OP_REQUIRES_OK( 440 | context, context->GetAttr("maximum_batch_size", &maximum_batch_size_)); 441 | OP_REQUIRES_OK(context, context->GetAttr("timeout_ms", &timeout_ms_)); 442 | } 443 | 444 | private: 445 | Status CreateResource(Batcher** ret) override EXCLUSIVE_LOCKS_REQUIRED(mu_) { 446 | gtl::optional timeout; 447 | if (timeout_ms_ != -1) { 448 | timeout = std::chrono::milliseconds(timeout_ms_); 449 | } 450 | *ret = new Batcher(minimum_batch_size_, maximum_batch_size_, timeout); 451 | return Status::OK(); 452 | } 453 | 454 | int32 minimum_batch_size_; 455 | int32 maximum_batch_size_; 456 | int32 timeout_ms_; 457 | 458 | TF_DISALLOW_COPY_AND_ASSIGN(BatcherHandleOp); 459 | }; 460 | 461 | class ComputeOp : public AsyncOpKernel { 462 | public: 463 | explicit ComputeOp(OpKernelConstruction* context) : AsyncOpKernel(context) {} 464 | 465 | void ComputeAsync(OpKernelContext* context, DoneCallback callback) override { 466 | Batcher* batcher; 467 | OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 468 | &batcher)); 469 | 470 | OpInputList input_list; 471 | OP_REQUIRES_OK(context, context->input_list("input_list", &input_list)); 472 | 473 | batcher->Compute(context, input_list, std::move(callback)); 474 | } 475 | 476 | private: 477 | TF_DISALLOW_COPY_AND_ASSIGN(ComputeOp); 478 | }; 479 | 480 | class GetInputsOp : public OpKernel { 481 | public: 482 | explicit GetInputsOp(OpKernelConstruction* context) : OpKernel(context) {} 483 | 484 | void Compute(OpKernelContext* context) override { 485 | Batcher* batcher; 486 | OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 487 | &batcher)); 488 | 489 | OpOutputList output_list; 490 | OP_REQUIRES_OK(context, context->output_list("output_list", &output_list)); 491 | 492 | batcher->GetInputs(context, &output_list); 493 | } 494 | 495 | private: 496 | TF_DISALLOW_COPY_AND_ASSIGN(GetInputsOp); 497 | }; 498 | 499 | class SetOutputsOp : public OpKernel { 500 | public: 501 | explicit SetOutputsOp(OpKernelConstruction* context) : OpKernel(context) {} 502 | 503 | void Compute(OpKernelContext* context) override { 504 | Batcher* batcher; 505 | OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 506 | &batcher)); 507 | 508 | OpInputList input_list; 509 | OP_REQUIRES_OK(context, context->input_list("input_list", &input_list)); 510 | 511 | const Tensor* computation_id; 512 | OP_REQUIRES_OK(context, context->input("computation_id", &computation_id)); 513 | 514 | batcher->SetOutputs(context, input_list, computation_id->scalar()()); 515 | } 516 | 517 | private: 518 | TF_DISALLOW_COPY_AND_ASSIGN(SetOutputsOp); 519 | }; 520 | 521 | class CloseOp : public OpKernel { 522 | public: 523 | explicit CloseOp(OpKernelConstruction* context) : OpKernel(context) {} 524 | 525 | void Compute(OpKernelContext* context) override { 526 | Batcher* batcher; 527 | OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 528 | &batcher)); 529 | 530 | batcher->Close(context); 531 | } 532 | 533 | private: 534 | TF_DISALLOW_COPY_AND_ASSIGN(CloseOp); 535 | }; 536 | 537 | REGISTER_KERNEL_BUILDER(Name("Batcher").Device(DEVICE_CPU), BatcherHandleOp); 538 | 539 | 
REGISTER_KERNEL_BUILDER(Name("BatcherCompute").Device(DEVICE_CPU), ComputeOp); 540 | 541 | REGISTER_KERNEL_BUILDER(Name("BatcherGetInputs").Device(DEVICE_CPU), 542 | GetInputsOp); 543 | 544 | REGISTER_KERNEL_BUILDER(Name("BatcherSetOutputs").Device(DEVICE_CPU), 545 | SetOutputsOp); 546 | 547 | REGISTER_KERNEL_BUILDER(Name("BatcherClose").Device(DEVICE_CPU), CloseOp); 548 | 549 | } // namespace 550 | } // namespace tensorflow 551 | -------------------------------------------------------------------------------- /dmlab30.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities for DMLab-30.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | 23 | import numpy as np 24 | import tensorflow as tf 25 | 26 | 27 | LEVEL_MAPPING = collections.OrderedDict([ 28 | ('rooms_collect_good_objects_train', 'rooms_collect_good_objects_test'), 29 | ('rooms_exploit_deferred_effects_train', 30 | 'rooms_exploit_deferred_effects_test'), 31 | ('rooms_select_nonmatching_object', 'rooms_select_nonmatching_object'), 32 | ('rooms_watermaze', 'rooms_watermaze'), 33 | ('rooms_keys_doors_puzzle', 'rooms_keys_doors_puzzle'), 34 | ('language_select_described_object', 'language_select_described_object'), 35 | ('language_select_located_object', 'language_select_located_object'), 36 | ('language_execute_random_task', 'language_execute_random_task'), 37 | ('language_answer_quantitative_question', 38 | 'language_answer_quantitative_question'), 39 | ('lasertag_one_opponent_small', 'lasertag_one_opponent_small'), 40 | ('lasertag_three_opponents_small', 'lasertag_three_opponents_small'), 41 | ('lasertag_one_opponent_large', 'lasertag_one_opponent_large'), 42 | ('lasertag_three_opponents_large', 'lasertag_three_opponents_large'), 43 | ('natlab_fixed_large_map', 'natlab_fixed_large_map'), 44 | ('natlab_varying_map_regrowth', 'natlab_varying_map_regrowth'), 45 | ('natlab_varying_map_randomized', 'natlab_varying_map_randomized'), 46 | ('skymaze_irreversible_path_hard', 'skymaze_irreversible_path_hard'), 47 | ('skymaze_irreversible_path_varied', 'skymaze_irreversible_path_varied'), 48 | ('psychlab_arbitrary_visuomotor_mapping', 49 | 'psychlab_arbitrary_visuomotor_mapping'), 50 | ('psychlab_continuous_recognition', 'psychlab_continuous_recognition'), 51 | ('psychlab_sequential_comparison', 'psychlab_sequential_comparison'), 52 | ('psychlab_visual_search', 'psychlab_visual_search'), 53 | ('explore_object_locations_small', 'explore_object_locations_small'), 54 | ('explore_object_locations_large', 'explore_object_locations_large'), 55 | ('explore_obstructed_goals_small', 'explore_obstructed_goals_small'), 56 | ('explore_obstructed_goals_large', 'explore_obstructed_goals_large'), 57 | ('explore_goal_locations_small', 'explore_goal_locations_small'), 58 | 
('explore_goal_locations_large', 'explore_goal_locations_large'), 59 | ('explore_object_rewards_few', 'explore_object_rewards_few'), 60 | ('explore_object_rewards_many', 'explore_object_rewards_many'), 61 | ]) 62 | 63 | HUMAN_SCORES = { 64 | 'rooms_collect_good_objects_test': 10, 65 | 'rooms_exploit_deferred_effects_test': 85.65, 66 | 'rooms_select_nonmatching_object': 65.9, 67 | 'rooms_watermaze': 54, 68 | 'rooms_keys_doors_puzzle': 53.8, 69 | 'language_select_described_object': 389.5, 70 | 'language_select_located_object': 280.7, 71 | 'language_execute_random_task': 254.05, 72 | 'language_answer_quantitative_question': 184.5, 73 | 'lasertag_one_opponent_small': 12.65, 74 | 'lasertag_three_opponents_small': 18.55, 75 | 'lasertag_one_opponent_large': 18.6, 76 | 'lasertag_three_opponents_large': 31.5, 77 | 'natlab_fixed_large_map': 36.9, 78 | 'natlab_varying_map_regrowth': 24.45, 79 | 'natlab_varying_map_randomized': 42.35, 80 | 'skymaze_irreversible_path_hard': 100, 81 | 'skymaze_irreversible_path_varied': 100, 82 | 'psychlab_arbitrary_visuomotor_mapping': 58.75, 83 | 'psychlab_continuous_recognition': 58.3, 84 | 'psychlab_sequential_comparison': 39.5, 85 | 'psychlab_visual_search': 78.5, 86 | 'explore_object_locations_small': 74.45, 87 | 'explore_object_locations_large': 65.65, 88 | 'explore_obstructed_goals_small': 206, 89 | 'explore_obstructed_goals_large': 119.5, 90 | 'explore_goal_locations_small': 267.5, 91 | 'explore_goal_locations_large': 194.5, 92 | 'explore_object_rewards_few': 77.7, 93 | 'explore_object_rewards_many': 106.7, 94 | } 95 | 96 | RANDOM_SCORES = { 97 | 'rooms_collect_good_objects_test': 0.073, 98 | 'rooms_exploit_deferred_effects_test': 8.501, 99 | 'rooms_select_nonmatching_object': 0.312, 100 | 'rooms_watermaze': 4.065, 101 | 'rooms_keys_doors_puzzle': 4.135, 102 | 'language_select_described_object': -0.07, 103 | 'language_select_located_object': 1.929, 104 | 'language_execute_random_task': -5.913, 105 | 'language_answer_quantitative_question': -0.33, 106 | 'lasertag_one_opponent_small': -0.224, 107 | 'lasertag_three_opponents_small': -0.214, 108 | 'lasertag_one_opponent_large': -0.083, 109 | 'lasertag_three_opponents_large': -0.102, 110 | 'natlab_fixed_large_map': 2.173, 111 | 'natlab_varying_map_regrowth': 2.989, 112 | 'natlab_varying_map_randomized': 7.346, 113 | 'skymaze_irreversible_path_hard': 0.1, 114 | 'skymaze_irreversible_path_varied': 14.4, 115 | 'psychlab_arbitrary_visuomotor_mapping': 0.163, 116 | 'psychlab_continuous_recognition': 0.224, 117 | 'psychlab_sequential_comparison': 0.129, 118 | 'psychlab_visual_search': 0.085, 119 | 'explore_object_locations_small': 3.575, 120 | 'explore_object_locations_large': 4.673, 121 | 'explore_obstructed_goals_small': 6.76, 122 | 'explore_obstructed_goals_large': 2.61, 123 | 'explore_goal_locations_small': 7.66, 124 | 'explore_goal_locations_large': 3.14, 125 | 'explore_object_rewards_few': 2.073, 126 | 'explore_object_rewards_many': 2.438, 127 | } 128 | 129 | ALL_LEVELS = frozenset([ 130 | 'rooms_collect_good_objects_train', 131 | 'rooms_collect_good_objects_test', 132 | 'rooms_exploit_deferred_effects_train', 133 | 'rooms_exploit_deferred_effects_test', 134 | 'rooms_select_nonmatching_object', 135 | 'rooms_watermaze', 136 | 'rooms_keys_doors_puzzle', 137 | 'language_select_described_object', 138 | 'language_select_located_object', 139 | 'language_execute_random_task', 140 | 'language_answer_quantitative_question', 141 | 'lasertag_one_opponent_small', 142 | 'lasertag_three_opponents_small', 143 | 
'lasertag_one_opponent_large', 144 | 'lasertag_three_opponents_large', 145 | 'natlab_fixed_large_map', 146 | 'natlab_varying_map_regrowth', 147 | 'natlab_varying_map_randomized', 148 | 'skymaze_irreversible_path_hard', 149 | 'skymaze_irreversible_path_varied', 150 | 'psychlab_arbitrary_visuomotor_mapping', 151 | 'psychlab_continuous_recognition', 152 | 'psychlab_sequential_comparison', 153 | 'psychlab_visual_search', 154 | 'explore_object_locations_small', 155 | 'explore_object_locations_large', 156 | 'explore_obstructed_goals_small', 157 | 'explore_obstructed_goals_large', 158 | 'explore_goal_locations_small', 159 | 'explore_goal_locations_large', 160 | 'explore_object_rewards_few', 161 | 'explore_object_rewards_many', 162 | ]) 163 | 164 | 165 | def _transform_level_returns(level_returns): 166 | """Converts training level names to test level names.""" 167 | new_level_returns = {} 168 | for level_name, returns in level_returns.iteritems(): 169 | new_level_returns[LEVEL_MAPPING.get(level_name, level_name)] = returns 170 | 171 | test_set = set(LEVEL_MAPPING.values()) 172 | diff = test_set - set(new_level_returns.keys()) 173 | if diff: 174 | raise ValueError('Missing levels: %s' % list(diff)) 175 | 176 | for level_name, returns in new_level_returns.iteritems(): 177 | if level_name in test_set: 178 | if not returns: 179 | raise ValueError('Missing returns for level: \'%s\': ' % level_name) 180 | else: 181 | tf.logging.info('Skipping level %s for calculation.', level_name) 182 | 183 | return new_level_returns 184 | 185 | 186 | def compute_human_normalized_score(level_returns, per_level_cap): 187 | """Computes human normalized score. 188 | 189 | Levels that have different training and test versions, will use the returns 190 | for the training level to calculate the score. E.g. 191 | 'rooms_collect_good_objects_train' will be used for 192 | 'rooms_collect_good_objects_test'. All returns for levels not in DmLab-30 193 | will be ignored. 194 | 195 | Args: 196 | level_returns: A dictionary from level to list of episode returns. 197 | per_level_cap: A percentage cap (e.g. 100.) on the per level human 198 | normalized score. If None, no cap is applied. 199 | 200 | Returns: 201 | A float with the human normalized score in percentage. 202 | 203 | Raises: 204 | ValueError: If a level is missing from `level_returns` or has no returns. 205 | """ 206 | new_level_returns = _transform_level_returns(level_returns) 207 | 208 | def human_normalized_score(level_name, returns): 209 | score = np.mean(returns) 210 | human = HUMAN_SCORES[level_name] 211 | random = RANDOM_SCORES[level_name] 212 | human_normalized_score = (score - random) / (human - random) * 100 213 | if per_level_cap is not None: 214 | human_normalized_score = min(human_normalized_score, per_level_cap) 215 | return human_normalized_score 216 | 217 | return np.mean( 218 | [human_normalized_score(k, v) for k, v in new_level_returns.items()]) 219 | -------------------------------------------------------------------------------- /dynamic_batching.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
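Looking back at `compute_human_normalized_score` in dmlab30.py above: the sketch below (not part of the repository) shows how it is typically called. The per-level returns are fabricated so that every level sits exactly halfway between its random and human score, i.e. `(score - random) / (human - random) * 100 = 50` before the per-level cap; only the call pattern is meant to be illustrative.

```python
# Illustrative only: fabricated returns for every DMLab-30 test level.
import dmlab30

level_returns = {
    level: [(dmlab30.HUMAN_SCORES[level] + dmlab30.RANDOM_SCORES[level]) / 2.0]
    for level in dmlab30.LEVEL_MAPPING.values()
}

# Every level scores 50% human-normalized, the cap leaves that unchanged, so
# the mean over the 30 levels is ~50.0.
print(dmlab30.compute_human_normalized_score(level_returns, per_level_cap=100.))
```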
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Dynamic batching.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import functools 22 | 23 | import tensorflow as tf 24 | 25 | batcher_ops = tf.load_op_library('./batcher.so') 26 | 27 | nest = tf.contrib.framework.nest 28 | 29 | 30 | class _Batcher(object): 31 | """A thin layer around the Batcher TensorFlow operations. 32 | 33 | It shares some of the interface with queues (close(), name) to be able to use 34 | it correctly as the input to a QueueRunner. 35 | """ 36 | 37 | def __init__(self, minimum_batch_size, maximum_batch_size, timeout_ms): 38 | self._handle = batcher_ops.batcher(minimum_batch_size, maximum_batch_size, 39 | timeout_ms or -1) 40 | 41 | @property 42 | def name(self): 43 | return 'batcher' 44 | 45 | def get_inputs(self, input_dtypes): 46 | return batcher_ops.batcher_get_inputs(self._handle, input_dtypes) 47 | 48 | def set_outputs(self, flat_result, computation_id): 49 | return batcher_ops.batcher_set_outputs(self._handle, flat_result, 50 | computation_id) 51 | 52 | def compute(self, flat_args, output_dtypes): 53 | return batcher_ops.batcher_compute(self._handle, flat_args, output_dtypes) 54 | 55 | def close(self, cancel_pending_enqueues=False, name=None): 56 | del cancel_pending_enqueues 57 | return batcher_ops.batcher_close(self._handle, name=name) 58 | 59 | 60 | def batch_fn(f): 61 | """See `batch_fn_with_options` for details.""" 62 | return batch_fn_with_options()(f) 63 | 64 | 65 | def batch_fn_with_options(minimum_batch_size=1, maximum_batch_size=1024, 66 | timeout_ms=100): 67 | """Python decorator that automatically batches computations. 68 | 69 | When the decorated function is called, it creates an operation that adds the 70 | inputs to a queue, waits until the computation is done, and returns the 71 | tensors. The inputs must be nests (see `tf.contrib.framework.nest`) and the 72 | first dimension of each tensor in the nest must have size 1. 73 | 74 | It adds a QueueRunner that asynchronously keeps fetching batches of data, 75 | computes the results and pushes the results back to the caller. 76 | 77 | Example usage: 78 | 79 | @dynamic_batching.batch_fn_with_options( 80 | minimum_batch_size=10, timeout_ms=100) 81 | def fn(a, b): 82 | return a + b 83 | 84 | output0 = fn(tf.constant([1]), tf.constant([2])) # Will be batched with the 85 | # next call. 86 | output1 = fn(tf.constant([3]), tf.constant([4])) 87 | 88 | Note, gradients are currently not supported. 89 | Note, if minimum_batch_size == maximum_batch_size and timeout_ms=None, then 90 | the batch size of input arguments will be set statically. Otherwise, it will 91 | be None. 92 | 93 | Args: 94 | minimum_batch_size: The minimum batch size before processing starts. 95 | maximum_batch_size: The maximum batch size. 96 | timeout_ms: Milliseconds after a batch of samples is requested before it is 97 | processed, even if the batch size is smaller than `minimum_batch_size`. If 98 | None, there is no timeout. 99 | 100 | Returns: 101 | The decorator. 
102 | """ 103 | 104 | def decorator(f): 105 | """Decorator.""" 106 | batcher = [None] 107 | batched_output = [None] 108 | 109 | @functools.wraps(f) 110 | def wrapper(*args): 111 | """Wrapper.""" 112 | 113 | flat_args = [tf.convert_to_tensor(arg) for arg in nest.flatten(args)] 114 | 115 | if batcher[0] is None: 116 | # Remove control dependencies which is necessary when created in loops, 117 | # etc. 118 | with tf.control_dependencies(None): 119 | input_dtypes = [t.dtype for t in flat_args] 120 | batcher[0] = _Batcher(minimum_batch_size, maximum_batch_size, 121 | timeout_ms) 122 | 123 | # Compute in batches using a queue runner. 124 | 125 | if minimum_batch_size == maximum_batch_size and timeout_ms is None: 126 | batch_size = minimum_batch_size 127 | else: 128 | batch_size = None 129 | 130 | # Dequeue batched input. 131 | inputs, computation_id = batcher[0].get_inputs(input_dtypes) 132 | nest.map_structure( 133 | lambda i, a: i.set_shape([batch_size] + a.shape.as_list()[1:]), 134 | inputs, flat_args) 135 | 136 | # Compute result. 137 | result = f(*nest.pack_sequence_as(args, inputs)) 138 | batched_output[0] = result 139 | flat_result = nest.flatten(result) 140 | 141 | # Insert results back into batcher. 142 | set_op = batcher[0].set_outputs(flat_result, computation_id) 143 | 144 | tf.train.add_queue_runner(tf.train.QueueRunner(batcher[0], [set_op])) 145 | 146 | # Insert inputs into input queue. 147 | flat_result = batcher[0].compute( 148 | flat_args, 149 | [t.dtype for t in nest.flatten(batched_output[0])]) 150 | 151 | # Restore structure and shapes. 152 | result = nest.pack_sequence_as(batched_output[0], flat_result) 153 | static_batch_size = nest.flatten(args)[0].shape[0] 154 | 155 | nest.map_structure( 156 | lambda t, b: t.set_shape([static_batch_size] + b.shape[1:].as_list()), 157 | result, batched_output[0]) 158 | return result 159 | 160 | return wrapper 161 | 162 | return decorator 163 | -------------------------------------------------------------------------------- /dynamic_batching_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
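To complement the `batch_fn_with_options` docstring above, here is a minimal end-to-end sketch (not part of the repository) of driving the decorator inside a session. It assumes `batcher.so` has already been built as in the Dockerfile, since importing `dynamic_batching` loads it; the batch sizes and timeout are arbitrary.

```python
from multiprocessing import pool

import tensorflow as tf

import dynamic_batching


@dynamic_batching.batch_fn_with_options(minimum_batch_size=2, timeout_ms=500)
def add(a, b):
  # Inputs arrive stacked along the leading (batch) dimension.
  return a + b

output0 = add(tf.constant([1]), tf.constant([2]))
output1 = add(tf.constant([3]), tf.constant([4]))

with tf.Session() as session:
  # The QueueRunner registered by the decorator repeatedly runs
  # BatcherGetInputs -> add -> BatcherSetOutputs in a background thread.
  coord = tf.train.Coordinator()
  tf.train.start_queue_runners(coord=coord)

  # Issue the two requests concurrently so they can share a batch; each call
  # gets back only its own slice of the batched result.
  tp = pool.ThreadPool(2)
  r0 = tp.apply_async(session.run, [output0])
  r1 = tp.apply_async(session.run, [output1])
  print(r0.get(), r1.get())  # array([3]) and array([7])

  coord.request_stop()  # Closes the batcher, cancelling pending ops.
  coord.join()
```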
14 | 15 | """Tests dynamic_batching.py.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import datetime 22 | from multiprocessing import pool 23 | import time 24 | 25 | import dynamic_batching 26 | 27 | import tensorflow as tf 28 | 29 | from six.moves import range 30 | 31 | 32 | _SLEEP_TIME = 1.0 33 | 34 | 35 | class DynamicBatchingTest(tf.test.TestCase): 36 | 37 | def test_one(self): 38 | with self.test_session() as session: 39 | @dynamic_batching.batch_fn 40 | def f(a, b): 41 | batch_size = tf.shape(a)[0] 42 | return a + b, tf.tile([batch_size], [batch_size]) 43 | 44 | output = f(tf.constant([[1, 3]]), tf.constant([2])) 45 | 46 | tf.train.start_queue_runners() 47 | 48 | result, batch_size = session.run(output) 49 | 50 | self.assertAllEqual([[3, 5]], result) 51 | self.assertAllEqual([1], batch_size) 52 | 53 | def test_two(self): 54 | with self.test_session() as session: 55 | @dynamic_batching.batch_fn 56 | def f(a, b): 57 | batch_size = tf.shape(a)[0] 58 | return a + b, tf.tile([batch_size], [batch_size]) 59 | 60 | output0 = f(tf.constant([1]), tf.constant([2])) 61 | output1 = f(tf.constant([2]), tf.constant([3])) 62 | 63 | tp = pool.ThreadPool(2) 64 | f0 = tp.apply_async(session.run, [output0]) 65 | f1 = tp.apply_async(session.run, [output1]) 66 | 67 | # Make sure both inputs are in the batcher before starting it. 68 | time.sleep(_SLEEP_TIME) 69 | 70 | tf.train.start_queue_runners() 71 | 72 | result0, batch_size0 = f0.get() 73 | result1, batch_size1 = f1.get() 74 | 75 | self.assertAllEqual([3], result0) 76 | self.assertAllEqual([2], batch_size0) 77 | self.assertAllEqual([5], result1) 78 | self.assertAllEqual([2], batch_size1) 79 | 80 | def test_many_small(self): 81 | with self.test_session() as session: 82 | @dynamic_batching.batch_fn 83 | def f(a, b): 84 | return a + b 85 | 86 | outputs = [] 87 | for i in range(200): 88 | outputs.append(f(tf.fill([1, 5], i), tf.fill([1, 5], i))) 89 | 90 | tf.train.start_queue_runners() 91 | 92 | tp = pool.ThreadPool(10) 93 | futures = [] 94 | for output in outputs: 95 | futures.append(tp.apply_async(session.run, [output])) 96 | 97 | for i, future in enumerate(futures): 98 | result = future.get() 99 | self.assertAllEqual([[i * 2] * 5], result) 100 | 101 | def test_input_batch_size_should_be_one(self): 102 | with self.test_session() as session: 103 | @dynamic_batching.batch_fn 104 | def f(a): 105 | return a 106 | 107 | output = f(tf.constant([1, 2])) 108 | 109 | coord = tf.train.Coordinator() 110 | tf.train.start_queue_runners(coord=coord) 111 | 112 | with self.assertRaises(tf.errors.CancelledError): 113 | session.run(output) 114 | 115 | with self.assertRaisesRegexp(tf.errors.InvalidArgumentError, 116 | 'requires batch size 1'): 117 | coord.join() 118 | 119 | def test_run_after_error_should_be_cancelled(self): 120 | with self.test_session() as session: 121 | 122 | @dynamic_batching.batch_fn 123 | def f(a): 124 | return a 125 | 126 | output = f(tf.constant([1, 2])) 127 | 128 | coord = tf.train.Coordinator() 129 | tf.train.start_queue_runners(coord=coord) 130 | 131 | with self.assertRaises(tf.errors.CancelledError): 132 | session.run(output) 133 | 134 | with self.assertRaises(tf.errors.CancelledError): 135 | session.run(output) 136 | 137 | def test_input_shapes_should_be_equal(self): 138 | with self.test_session() as session: 139 | 140 | @dynamic_batching.batch_fn 141 | def f(a, b): 142 | return a + b 143 | 144 | output0 = f(tf.constant([1]), tf.constant([2])) 145 
| output1 = f(tf.constant([[2]]), tf.constant([3])) 146 | 147 | tp = pool.ThreadPool(2) 148 | f0 = tp.apply_async(session.run, [output0]) 149 | f1 = tp.apply_async(session.run, [output1]) 150 | 151 | time.sleep(_SLEEP_TIME) 152 | 153 | coord = tf.train.Coordinator() 154 | tf.train.start_queue_runners(coord=coord) 155 | 156 | with self.assertRaises(tf.errors.CancelledError): 157 | f0.get() 158 | f1.get() 159 | 160 | with self.assertRaisesRegexp(tf.errors.InvalidArgumentError, 161 | 'Shapes of inputs much be equal'): 162 | coord.join() 163 | 164 | def test_output_must_have_batch_dimension(self): 165 | with self.test_session() as session: 166 | @dynamic_batching.batch_fn 167 | def f(_): 168 | return tf.constant(1) 169 | 170 | output = f(tf.constant([1])) 171 | 172 | coord = tf.train.Coordinator() 173 | tf.train.start_queue_runners(coord=coord) 174 | 175 | with self.assertRaises(tf.errors.CancelledError): 176 | session.run(output) 177 | 178 | with self.assertRaisesRegexp(tf.errors.InvalidArgumentError, 179 | 'Output shape must have a batch dimension'): 180 | coord.join() 181 | 182 | def test_output_must_have_same_batch_dimension_size_as_input(self): 183 | with self.test_session() as session: 184 | @dynamic_batching.batch_fn 185 | def f(_): 186 | return tf.constant([1, 2, 3, 4]) 187 | 188 | output = f(tf.constant([1])) 189 | 190 | coord = tf.train.Coordinator() 191 | tf.train.start_queue_runners(coord=coord) 192 | 193 | with self.assertRaises(tf.errors.CancelledError): 194 | session.run(output) 195 | 196 | with self.assertRaisesRegexp( 197 | tf.errors.InvalidArgumentError, 198 | 'Output shape must have the same batch dimension as the input batch ' 199 | 'size. Expected: 1 Observed: 4'): 200 | coord.join() 201 | 202 | def test_get_inputs_cancelled(self): 203 | with tf.Graph().as_default(): 204 | 205 | @dynamic_batching.batch_fn 206 | def f(a): 207 | return a 208 | 209 | f(tf.constant([1])) 210 | 211 | # Intentionally using tf.Session() instead of self.test_session() to have 212 | # control over closing the session. test_session() is a cached session. 213 | with tf.Session(): 214 | coord = tf.train.Coordinator() 215 | tf.train.start_queue_runners(coord=coord) 216 | # Sleep to make sure the queue runner has started the first run call. 217 | time.sleep(_SLEEP_TIME) 218 | 219 | # Session closed. 220 | with self.assertRaisesRegexp(tf.errors.CancelledError, 221 | 'GetInputs operation was cancelled'): 222 | coord.join() 223 | 224 | def test_batcher_closed(self): 225 | with tf.Graph().as_default(): 226 | @dynamic_batching.batch_fn 227 | def f(a): 228 | return a 229 | 230 | f(tf.constant([1])) 231 | 232 | # Intentionally using tf.Session() instead of self.test_session() to have 233 | # control over closing the session. test_session() is a cached session. 234 | with tf.Session(): 235 | coord = tf.train.Coordinator() 236 | tf.train.start_queue_runners(coord=coord) 237 | time.sleep(_SLEEP_TIME) 238 | coord.request_stop() # Calls close operation. 239 | coord.join() 240 | # Session closed. 
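  # The tests below exercise `batch_fn_with_options`. As a rough usage sketch
  # (illustrative only; `model` is a placeholder, not part of this module):
  #
  #   @dynamic_batching.batch_fn_with_options(minimum_batch_size=2,
  #                                           maximum_batch_size=8,
  #                                           timeout_ms=1000)
  #   def inference(observations):
  #     return model(observations)
  #
  # Inputs from separate calls are queued and processed together by the
  # batcher's queue runner once `minimum_batch_size` inputs are waiting (or the
  # timeout expires), and a single batch never exceeds `maximum_batch_size`.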
241 | 242 | def test_minimum_batch_size(self): 243 | with self.test_session() as session: 244 | @dynamic_batching.batch_fn_with_options( 245 | minimum_batch_size=2, timeout_ms=1000) 246 | def f(a, b): 247 | batch_size = tf.shape(a)[0] 248 | return a + b, tf.tile([batch_size], [batch_size]) 249 | 250 | output = f(tf.constant([[1, 3]]), tf.constant([2])) 251 | 252 | tf.train.start_queue_runners() 253 | 254 | start = datetime.datetime.now() 255 | session.run(output) 256 | duration = datetime.datetime.now() - start 257 | 258 | # There should have been a timeout here because only one sample was added 259 | # and the minimum batch size is 2. 260 | self.assertLessEqual(.9, duration.total_seconds()) 261 | self.assertGreaterEqual(1.5, duration.total_seconds()) 262 | 263 | outputs = [ 264 | f(tf.constant([[1, 3]]), tf.constant([2])), 265 | f(tf.constant([[1, 3]]), tf.constant([2])) 266 | ] 267 | 268 | start = datetime.datetime.now() 269 | (_, batch_size), _ = session.run(outputs) 270 | duration = datetime.datetime.now() - start 271 | 272 | # The outputs should be executed immediately because two samples are 273 | # added. 274 | self.assertGreaterEqual(.5, duration.total_seconds()) 275 | self.assertEqual(2, batch_size) 276 | 277 | def test_maximum_batch_size(self): 278 | with self.test_session() as session: 279 | @dynamic_batching.batch_fn_with_options(maximum_batch_size=2) 280 | def f(a, b): 281 | batch_size = tf.shape(a)[0] 282 | return a + b, tf.tile([batch_size], [batch_size]) 283 | 284 | outputs = [ 285 | f(tf.constant([1]), tf.constant([2])), 286 | f(tf.constant([1]), tf.constant([2])), 287 | f(tf.constant([1]), tf.constant([2])), 288 | f(tf.constant([1]), tf.constant([2])), 289 | f(tf.constant([1]), tf.constant([2])), 290 | ] 291 | 292 | tf.train.start_queue_runners() 293 | 294 | results = session.run(outputs) 295 | 296 | for value, batch_size in results: 297 | self.assertEqual(3, value) 298 | self.assertGreaterEqual(2, batch_size) 299 | 300 | def test_static_shape(self): 301 | assertions_triggered = [0] 302 | 303 | @dynamic_batching.batch_fn_with_options(minimum_batch_size=1, 304 | maximum_batch_size=2) 305 | def f0(a): 306 | self.assertEqual(None, a.shape[0].value) 307 | assertions_triggered[0] += 1 308 | return a 309 | 310 | @dynamic_batching.batch_fn_with_options(minimum_batch_size=2, 311 | maximum_batch_size=2) 312 | def f1(a): 313 | # Even though minimum_batch_size and maximum_batch_size are equal, the 314 | # timeout can cause a batch with less than minimum_batch_size. 315 | self.assertEqual(None, a.shape[0].value) 316 | assertions_triggered[0] += 1 317 | return a 318 | 319 | @dynamic_batching.batch_fn_with_options(minimum_batch_size=2, 320 | maximum_batch_size=2, 321 | timeout_ms=None) 322 | def f2(a): 323 | # When timeout is disabled and minimum/maximum batch size are equal, the 324 | # shape is statically known.
325 | self.assertEqual(2, a.shape[0].value) 326 | assertions_triggered[0] += 1 327 | return a 328 | 329 | f0(tf.constant([1])) 330 | f1(tf.constant([1])) 331 | f2(tf.constant([1])) 332 | self.assertEqual(3, assertions_triggered[0]) 333 | 334 | def test_out_of_order_execution1(self): 335 | with self.test_session() as session: 336 | batcher = dynamic_batching._Batcher(minimum_batch_size=1, 337 | maximum_batch_size=1, 338 | timeout_ms=None) 339 | 340 | tp = pool.ThreadPool(10) 341 | r0 = tp.apply_async(session.run, batcher.compute([[1]], [tf.int32])) 342 | (input0,), computation_id0 = session.run(batcher.get_inputs([tf.int32])) 343 | r1 = tp.apply_async(session.run, batcher.compute([[2]], [tf.int32])) 344 | (input1,), computation_id1 = session.run(batcher.get_inputs([tf.int32])) 345 | 346 | self.assertAllEqual([1], input0) 347 | self.assertAllEqual([2], input1) 348 | 349 | session.run(batcher.set_outputs([input0 + 42], computation_id0)) 350 | session.run(batcher.set_outputs([input1 + 42], computation_id1)) 351 | 352 | self.assertAllEqual([43], r0.get()) 353 | self.assertAllEqual([44], r1.get()) 354 | 355 | def test_out_of_order_execution2(self): 356 | with self.test_session() as session: 357 | batcher = dynamic_batching._Batcher(minimum_batch_size=1, 358 | maximum_batch_size=1, 359 | timeout_ms=None) 360 | 361 | tp = pool.ThreadPool(10) 362 | r0 = tp.apply_async(session.run, batcher.compute([[1]], [tf.int32])) 363 | (input0,), computation_id0 = session.run(batcher.get_inputs([tf.int32])) 364 | r1 = tp.apply_async(session.run, batcher.compute([[2]], [tf.int32])) 365 | (input1,), computation_id1 = session.run(batcher.get_inputs([tf.int32])) 366 | 367 | self.assertAllEqual([1], input0) 368 | self.assertAllEqual([2], input1) 369 | 370 | # These two runs are switched from testOutOfOrderExecution1. 
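      # Completing the computations in reverse order checks that results are
      # routed back to the original callers by computation id rather than by
      # completion order.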
371 | session.run(batcher.set_outputs([input1 + 42], computation_id1)) 372 | session.run(batcher.set_outputs([input0 + 42], computation_id0)) 373 | 374 | self.assertAllEqual([43], r0.get()) 375 | self.assertAllEqual([44], r1.get()) 376 | 377 | def test_invalid_computation_id(self): 378 | with self.test_session() as session: 379 | batcher = dynamic_batching._Batcher(minimum_batch_size=1, 380 | maximum_batch_size=1, 381 | timeout_ms=None) 382 | 383 | tp = pool.ThreadPool(10) 384 | tp.apply_async(session.run, batcher.compute([[1]], [tf.int32])) 385 | (input0,), _ = session.run(batcher.get_inputs([tf.int32])) 386 | 387 | self.assertAllEqual([1], input0) 388 | 389 | with self.assertRaisesRegexp(tf.errors.InvalidArgumentError, 390 | 'Invalid computation id'): 391 | session.run(batcher.set_outputs([input0], 42)) 392 | 393 | def test_op_shape(self): 394 | with self.test_session(): 395 | batcher = dynamic_batching._Batcher(minimum_batch_size=1, 396 | maximum_batch_size=1, 397 | timeout_ms=None) 398 | 399 | _, computation_id = batcher.get_inputs([tf.int32]) 400 | 401 | self.assertEqual([], computation_id.shape) 402 | 403 | 404 | class DynamicBatchingBenchmarks(tf.test.Benchmark): 405 | 406 | def benchmark_batching_small(self): 407 | with tf.Session() as session: 408 | @dynamic_batching.batch_fn 409 | def f(a, b): 410 | return a + b 411 | 412 | outputs = [] 413 | for _ in range(1000): 414 | outputs.append(f(tf.ones([1, 10]), tf.ones([1, 10]))) 415 | op_to_benchmark = tf.group(*outputs) 416 | 417 | tf.train.start_queue_runners() 418 | 419 | self.run_op_benchmark( 420 | name='batching_many_small', 421 | sess=session, 422 | op_or_tensor=op_to_benchmark, 423 | burn_iters=10, 424 | min_iters=50) 425 | 426 | def benchmark_batching_large(self): 427 | with tf.Session() as session: 428 | @dynamic_batching.batch_fn 429 | def f(a, b): 430 | return a + b 431 | 432 | outputs = [] 433 | for _ in range(1000): 434 | outputs.append(f(tf.ones([1, 100000]), tf.ones([1, 100000]))) 435 | op_to_benchmark = tf.group(*outputs) 436 | 437 | tf.train.start_queue_runners() 438 | 439 | self.run_op_benchmark( 440 | name='batching_many_large', 441 | sess=session, 442 | op_or_tensor=op_to_benchmark, 443 | burn_iters=10, 444 | min_iters=50) 445 | 446 | 447 | if __name__ == '__main__': 448 | tf.test.main() 449 | -------------------------------------------------------------------------------- /environments.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Environments and environment helper classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import os.path 23 | 24 | import numpy as np 25 | import tensorflow as tf 26 | 27 | import deepmind_lab 28 | 29 | 30 | nest = tf.contrib.framework.nest 31 | 32 | 33 | class LocalLevelCache(object): 34 | """Local level cache.""" 35 | 36 | def __init__(self, cache_dir='/tmp/level_cache'): 37 | self._cache_dir = cache_dir 38 | tf.gfile.MakeDirs(cache_dir) 39 | 40 | def fetch(self, key, pk3_path): 41 | path = os.path.join(self._cache_dir, key) 42 | if tf.gfile.Exists(path): 43 | tf.gfile.Copy(path, pk3_path, overwrite=True) 44 | return True 45 | return False 46 | 47 | def write(self, key, pk3_path): 48 | path = os.path.join(self._cache_dir, key) 49 | if not tf.gfile.Exists(path): 50 | tf.gfile.Copy(pk3_path, path) 51 | 52 | 53 | DEFAULT_ACTION_SET = ( 54 | (0, 0, 0, 1, 0, 0, 0), # Forward 55 | (0, 0, 0, -1, 0, 0, 0), # Backward 56 | (0, 0, -1, 0, 0, 0, 0), # Strafe Left 57 | (0, 0, 1, 0, 0, 0, 0), # Strafe Right 58 | (-20, 0, 0, 0, 0, 0, 0), # Look Left 59 | (20, 0, 0, 0, 0, 0, 0), # Look Right 60 | (-20, 0, 0, 1, 0, 0, 0), # Look Left + Forward 61 | (20, 0, 0, 1, 0, 0, 0), # Look Right + Forward 62 | (0, 0, 0, 0, 1, 0, 0), # Fire. 63 | ) 64 | 65 | 66 | class PyProcessDmLab(object): 67 | """DeepMind Lab wrapper for PyProcess.""" 68 | 69 | def __init__(self, level, config, num_action_repeats, seed, 70 | runfiles_path=None, level_cache=None): 71 | self._num_action_repeats = num_action_repeats 72 | self._random_state = np.random.RandomState(seed=seed) 73 | if runfiles_path: 74 | deepmind_lab.set_runfiles_path(runfiles_path) 75 | config = {k: str(v) for k, v in config.iteritems()} 76 | self._observation_spec = ['RGB_INTERLEAVED', 'INSTR'] 77 | self._env = deepmind_lab.Lab( 78 | level=level, 79 | observations=self._observation_spec, 80 | config=config, 81 | level_cache=level_cache, 82 | ) 83 | 84 | def _reset(self): 85 | self._env.reset(seed=self._random_state.randint(0, 2 ** 31 - 1)) 86 | 87 | def _observation(self): 88 | d = self._env.observations() 89 | return [d[k] for k in self._observation_spec] 90 | 91 | def initial(self): 92 | self._reset() 93 | return self._observation() 94 | 95 | def step(self, action): 96 | reward = self._env.step(action, num_steps=self._num_action_repeats) 97 | done = np.array(not self._env.is_running()) 98 | if done: 99 | self._reset() 100 | observation = self._observation() 101 | reward = np.array(reward, dtype=np.float32) 102 | return reward, done, observation 103 | 104 | def close(self): 105 | self._env.close() 106 | 107 | @staticmethod 108 | def _tensor_specs(method_name, unused_kwargs, constructor_kwargs): 109 | """Returns a nest of `TensorSpec` with the method's output specification.""" 110 | width = constructor_kwargs['config'].get('width', 320) 111 | height = constructor_kwargs['config'].get('height', 240) 112 | 113 | observation_spec = [ 114 | tf.contrib.framework.TensorSpec([height, width, 3], tf.uint8), 115 | tf.contrib.framework.TensorSpec([], tf.string), 116 | ] 117 | 118 | if method_name == 'initial': 119 | return observation_spec 120 | elif method_name == 'step': 121 | return ( 122 | tf.contrib.framework.TensorSpec([], tf.float32), 123 | tf.contrib.framework.TensorSpec([], tf.bool), 124 | observation_spec, 125 | ) 126 | 127 | 128 | StepOutputInfo = collections.namedtuple('StepOutputInfo', 129 | 'episode_return episode_step') 130 | 
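# `StepOutputInfo` tracks the running episode return and step count;
# `FlowEnvironment` below resets the copy carried in its state whenever an
# episode ends.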
StepOutput = collections.namedtuple('StepOutput', 131 | 'reward info done observation') 132 | 133 | 134 | class FlowEnvironment(object): 135 | """An environment that returns a new state for every modifying method. 136 | 137 | The environment returns a new environment state for every modifying action and 138 | forces previous actions to be completed first. Similar to `flow` for 139 | `TensorArray`. 140 | """ 141 | 142 | def __init__(self, env): 143 | """Initializes the environment. 144 | 145 | Args: 146 | env: An environment with `initial()` and `step(action)` methods where 147 | `initial` returns the initial observations and `step` takes an action 148 | and returns a tuple of (reward, done, observation). `observation` 149 | should be the observation after the step is taken. If `done` is 150 | True, the observation should be the first observation in the next 151 | episode. 152 | """ 153 | self._env = env 154 | 155 | def initial(self): 156 | """Returns the initial output and initial state. 157 | 158 | Returns: 159 | A tuple of (`StepOutput`, environment state). The environment state should 160 | be passed in to the next invocation of `step` and should not be used in 161 | any other way. The reward and transition type in the `StepOutput` is the 162 | reward/transition type that led to the observation in `StepOutput`. 163 | """ 164 | with tf.name_scope('flow_environment_initial'): 165 | initial_reward = tf.constant(0.) 166 | initial_info = StepOutputInfo(tf.constant(0.), tf.constant(0)) 167 | initial_done = tf.constant(True) 168 | initial_observation = self._env.initial() 169 | 170 | initial_output = StepOutput( 171 | initial_reward, 172 | initial_info, 173 | initial_done, 174 | initial_observation) 175 | 176 | # Control dependency to make sure the next step can't be taken before the 177 | # initial output has been read from the environment. 178 | with tf.control_dependencies(nest.flatten(initial_output)): 179 | initial_flow = tf.constant(0, dtype=tf.int64) 180 | initial_state = (initial_flow, initial_info) 181 | return initial_output, initial_state 182 | 183 | def step(self, action, state): 184 | """Takes a step in the environment. 185 | 186 | Args: 187 | action: An action tensor suitable for the underlying environment. 188 | state: The environment state from the last step or initial state. 189 | 190 | Returns: 191 | A tuple of (`StepOutput`, environment state). The environment state should 192 | be passed in to the next invocation of `step` and should not be used in 193 | any other way. On episode end (i.e. `done` is True), the returned reward 194 | should be included in the sum of rewards for the ending episode and not 195 | part of the next episode. 196 | """ 197 | with tf.name_scope('flow_environment_step'): 198 | flow, info = nest.map_structure(tf.convert_to_tensor, state) 199 | 200 | # Make sure the previous step has been executed before running the next 201 | # step. 202 | with tf.control_dependencies([flow]): 203 | reward, done, observation = self._env.step(action) 204 | 205 | with tf.control_dependencies(nest.flatten(observation)): 206 | new_flow = tf.add(flow, 1) 207 | 208 | # When done, include the reward in the output info but not in the 209 | # state for the next step.
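      # Concretely: `new_info` (computed below) goes into the returned
      # `StepOutput`, while `tf.where(done, ...)` zeroes the copy that is
      # carried forward in the state when the episode has ended.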
210 | new_info = StepOutputInfo(info.episode_return + reward, 211 | info.episode_step + 1) 212 | new_state = new_flow, nest.map_structure( 213 | lambda a, b: tf.where(done, a, b), 214 | StepOutputInfo(tf.constant(0.), tf.constant(0)), 215 | new_info) 216 | 217 | output = StepOutput(reward, new_info, done, observation) 218 | return output, new_state 219 | -------------------------------------------------------------------------------- /experiment.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Importance Weighted Actor-Learner Architectures.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import contextlib 23 | import functools 24 | import os 25 | import sys 26 | 27 | import dmlab30 28 | import environments 29 | import numpy as np 30 | import py_process 31 | import sonnet as snt 32 | import tensorflow as tf 33 | import vtrace 34 | 35 | try: 36 | import dynamic_batching 37 | except tf.errors.NotFoundError: 38 | tf.logging.warning('Running without dynamic batching.') 39 | 40 | from six.moves import range 41 | 42 | 43 | nest = tf.contrib.framework.nest 44 | 45 | flags = tf.app.flags 46 | FLAGS = tf.app.flags.FLAGS 47 | 48 | flags.DEFINE_string('logdir', '/tmp/agent', 'TensorFlow log directory.') 49 | flags.DEFINE_enum('mode', 'train', ['train', 'test'], 'Training or test mode.') 50 | 51 | # Flags used for testing. 52 | flags.DEFINE_integer('test_num_episodes', 10, 'Number of episodes per level.') 53 | 54 | # Flags used for distributed training. 55 | flags.DEFINE_integer('task', -1, 'Task id. Use -1 for local training.') 56 | flags.DEFINE_enum('job_name', 'learner', ['learner', 'actor'], 57 | 'Job name. Ignored when task is set to -1.') 58 | 59 | # Training. 60 | flags.DEFINE_integer('total_environment_frames', int(1e9), 61 | 'Total environment frames to train for.') 62 | flags.DEFINE_integer('num_actors', 4, 'Number of actors.') 63 | flags.DEFINE_integer('batch_size', 2, 'Batch size for training.') 64 | flags.DEFINE_integer('unroll_length', 100, 'Unroll length in agent steps.') 65 | flags.DEFINE_integer('num_action_repeats', 4, 'Number of action repeats.') 66 | flags.DEFINE_integer('seed', 1, 'Random seed.') 67 | 68 | # Loss settings. 69 | flags.DEFINE_float('entropy_cost', 0.00025, 'Entropy cost/multiplier.') 70 | flags.DEFINE_float('baseline_cost', .5, 'Baseline cost/multiplier.') 71 | flags.DEFINE_float('discounting', .99, 'Discounting factor.') 72 | flags.DEFINE_enum('reward_clipping', 'abs_one', ['abs_one', 'soft_asymmetric'], 73 | 'Reward clipping.') 74 | 75 | # Environment settings. 
76 | flags.DEFINE_string( 77 | 'dataset_path', '', 78 | 'Path to dataset needed for psychlab_*, see ' 79 | 'https://github.com/deepmind/lab/tree/master/data/brady_konkle_oliva2008') 80 | flags.DEFINE_string('level_name', 'explore_goal_locations_small', 81 | '''Level name or \'dmlab30\' for the full DmLab-30 suite ''' 82 | '''with levels assigned round robin to the actors.''') 83 | flags.DEFINE_integer('width', 96, 'Width of observation.') 84 | flags.DEFINE_integer('height', 72, 'Height of observation.') 85 | 86 | # Optimizer settings. 87 | flags.DEFINE_float('learning_rate', 0.00048, 'Learning rate.') 88 | flags.DEFINE_float('decay', .99, 'RMSProp optimizer decay.') 89 | flags.DEFINE_float('momentum', 0., 'RMSProp momentum.') 90 | flags.DEFINE_float('epsilon', .1, 'RMSProp epsilon.') 91 | 92 | 93 | # Structure to be sent from actors to learner. 94 | ActorOutput = collections.namedtuple( 95 | 'ActorOutput', 'level_name agent_state env_outputs agent_outputs') 96 | AgentOutput = collections.namedtuple('AgentOutput', 97 | 'action policy_logits baseline') 98 | 99 | 100 | def is_single_machine(): 101 | return FLAGS.task == -1 102 | 103 | 104 | class Agent(snt.RNNCore): 105 | """Agent with ResNet.""" 106 | 107 | def __init__(self, num_actions): 108 | super(Agent, self).__init__(name='agent') 109 | 110 | self._num_actions = num_actions 111 | 112 | with self._enter_variable_scope(): 113 | self._core = tf.contrib.rnn.LSTMBlockCell(256) 114 | 115 | def initial_state(self, batch_size): 116 | return self._core.zero_state(batch_size, tf.float32) 117 | 118 | def _instruction(self, instruction): 119 | # Split string. 120 | splitted = tf.string_split(instruction) 121 | dense = tf.sparse_tensor_to_dense(splitted, default_value='') 122 | length = tf.reduce_sum(tf.to_int32(tf.not_equal(dense, '')), axis=1) 123 | 124 | # To int64 hash buckets. Small risk of having collisions. Alternatively, a 125 | # vocabulary can be used. 126 | num_hash_buckets = 1000 127 | buckets = tf.string_to_hash_bucket_fast(dense, num_hash_buckets) 128 | 129 | # Embed the instruction. Embedding size 20 seems to be enough. 130 | embedding_size = 20 131 | embedding = snt.Embed(num_hash_buckets, embedding_size)(buckets) 132 | 133 | # Pad to make sure there is at least one output. 134 | padding = tf.to_int32(tf.equal(tf.shape(embedding)[1], 0)) 135 | embedding = tf.pad(embedding, [[0, 0], [0, padding], [0, 0]]) 136 | 137 | core = tf.contrib.rnn.LSTMBlockCell(64, name='language_lstm') 138 | output, _ = tf.nn.dynamic_rnn(core, embedding, length, dtype=tf.float32) 139 | 140 | # Return last output. 141 | return tf.reverse_sequence(output, length, seq_axis=1)[:, 0] 142 | 143 | def _torso(self, input_): 144 | last_action, env_output = input_ 145 | reward, _, _, (frame, instruction) = env_output 146 | 147 | # Convert to floats. 148 | frame = tf.to_float(frame) 149 | 150 | frame /= 255 151 | with tf.variable_scope('convnet'): 152 | conv_out = frame 153 | for i, (num_ch, num_blocks) in enumerate([(16, 2), (32, 2), (32, 2)]): 154 | # Downscale. 155 | conv_out = snt.Conv2D(num_ch, 3, stride=1, padding='SAME')(conv_out) 156 | conv_out = tf.nn.pool( 157 | conv_out, 158 | window_shape=[3, 3], 159 | pooling_type='MAX', 160 | padding='SAME', 161 | strides=[2, 2]) 162 | 163 | # Residual block(s). 
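        # Each block below applies relu -> 3x3 conv -> relu -> 3x3 conv and
        # adds the result back to the block input (a standard residual unit).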
164 | for j in range(num_blocks): 165 | with tf.variable_scope('residual_%d_%d' % (i, j)): 166 | block_input = conv_out 167 | conv_out = tf.nn.relu(conv_out) 168 | conv_out = snt.Conv2D(num_ch, 3, stride=1, padding='SAME')(conv_out) 169 | conv_out = tf.nn.relu(conv_out) 170 | conv_out = snt.Conv2D(num_ch, 3, stride=1, padding='SAME')(conv_out) 171 | conv_out += block_input 172 | 173 | conv_out = tf.nn.relu(conv_out) 174 | conv_out = snt.BatchFlatten()(conv_out) 175 | 176 | conv_out = snt.Linear(256)(conv_out) 177 | conv_out = tf.nn.relu(conv_out) 178 | 179 | instruction_out = self._instruction(instruction) 180 | 181 | # Append clipped last reward and one hot last action. 182 | clipped_reward = tf.expand_dims(tf.clip_by_value(reward, -1, 1), -1) 183 | one_hot_last_action = tf.one_hot(last_action, self._num_actions) 184 | return tf.concat( 185 | [conv_out, clipped_reward, one_hot_last_action, instruction_out], 186 | axis=1) 187 | 188 | def _head(self, core_output): 189 | policy_logits = snt.Linear(self._num_actions, name='policy_logits')( 190 | core_output) 191 | baseline = tf.squeeze(snt.Linear(1, name='baseline')(core_output), axis=-1) 192 | 193 | # Sample an action from the policy. 194 | new_action = tf.multinomial(policy_logits, num_samples=1, 195 | output_dtype=tf.int32) 196 | new_action = tf.squeeze(new_action, 1, name='new_action') 197 | 198 | return AgentOutput(new_action, policy_logits, baseline) 199 | 200 | def _build(self, input_, core_state): 201 | action, env_output = input_ 202 | actions, env_outputs = nest.map_structure(lambda t: tf.expand_dims(t, 0), 203 | (action, env_output)) 204 | outputs, core_state = self.unroll(actions, env_outputs, core_state) 205 | return nest.map_structure(lambda t: tf.squeeze(t, 0), outputs), core_state 206 | 207 | @snt.reuse_variables 208 | def unroll(self, actions, env_outputs, core_state): 209 | _, _, done, _ = env_outputs 210 | 211 | torso_outputs = snt.BatchApply(self._torso)((actions, env_outputs)) 212 | 213 | # Note, in this implementation we can't use CuDNN RNN to speed things up due 214 | # to the state reset. This can be XLA-compiled (LSTMBlockCell needs to be 215 | # changed to implement snt.LSTMCell). 216 | initial_core_state = self._core.zero_state(tf.shape(actions)[1], tf.float32) 217 | core_output_list = [] 218 | for input_, d in zip(tf.unstack(torso_outputs), tf.unstack(done)): 219 | # If the episode ended, the core state should be reset before the next. 220 | core_state = nest.map_structure(functools.partial(tf.where, d), 221 | initial_core_state, core_state) 222 | core_output, core_state = self._core(input_, core_state) 223 | core_output_list.append(core_output) 224 | 225 | return snt.BatchApply(self._head)(tf.stack(core_output_list)), core_state 226 | 227 | 228 | def build_actor(agent, env, level_name, action_set): 229 | """Builds the actor loop.""" 230 | # Initial values. 231 | initial_env_output, initial_env_state = env.initial() 232 | initial_agent_state = agent.initial_state(1) 233 | initial_action = tf.zeros([1], dtype=tf.int32) 234 | dummy_agent_output, _ = agent( 235 | (initial_action, 236 | nest.map_structure(lambda t: tf.expand_dims(t, 0), initial_env_output)), 237 | initial_agent_state) 238 | initial_agent_output = nest.map_structure( 239 | lambda t: tf.zeros(t.shape, t.dtype), dummy_agent_output) 240 | 241 | # All state that needs to persist across training iterations. This includes 242 | # the last environment output, agent state and last agent output. These 243 | # variables should never go on the parameter servers. 
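  # `create_state` below therefore wraps each tensor in a local (non-trainable)
  # resource variable, so every actor keeps its own copy across session runs.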
244 | def create_state(t): 245 | # Creates a unique variable scope to ensure the variable name is unique. 246 | with tf.variable_scope(None, default_name='state'): 247 | return tf.get_local_variable(t.op.name, initializer=t, use_resource=True) 248 | 249 | persistent_state = nest.map_structure( 250 | create_state, (initial_env_state, initial_env_output, initial_agent_state, 251 | initial_agent_output)) 252 | 253 | def step(input_, unused_i): 254 | """Steps through the agent and the environment.""" 255 | env_state, env_output, agent_state, agent_output = input_ 256 | 257 | # Run agent. 258 | action = agent_output[0] 259 | batched_env_output = nest.map_structure(lambda t: tf.expand_dims(t, 0), 260 | env_output) 261 | agent_output, agent_state = agent((action, batched_env_output), agent_state) 262 | 263 | # Convert action index to the native action. 264 | action = agent_output[0][0] 265 | raw_action = tf.gather(action_set, action) 266 | 267 | env_output, env_state = env.step(raw_action, env_state) 268 | 269 | return env_state, env_output, agent_state, agent_output 270 | 271 | # Run the unroll. `read_value()` is needed to make sure later usage will 272 | # return the first values and not a new snapshot of the variables. 273 | first_values = nest.map_structure(lambda v: v.read_value(), persistent_state) 274 | _, first_env_output, first_agent_state, first_agent_output = first_values 275 | 276 | # Use scan to apply `step` multiple times, therefore unrolling the agent 277 | # and environment interaction for `FLAGS.unroll_length`. `tf.scan` forwards 278 | # the output of each call of `step` as input of the subsequent call of `step`. 279 | # The unroll sequence is initialized with the agent and environment states 280 | # and outputs as stored at the end of the previous unroll. 281 | # `output` stores lists of all states and outputs stacked along the entire 282 | # unroll. Note that the initial states and outputs (fed through `initializer`) 283 | # are not in `output` and will need to be added manually later. 284 | output = tf.scan(step, tf.range(FLAGS.unroll_length), first_values) 285 | _, env_outputs, _, agent_outputs = output 286 | 287 | # Update persistent state with the last output from the loop. 288 | assign_ops = nest.map_structure(lambda v, t: v.assign(t[-1]), 289 | persistent_state, output) 290 | 291 | # The control dependency ensures that the final agent and environment states 292 | # and outputs are stored in `persistent_state` (to initialize next unroll). 293 | with tf.control_dependencies(nest.flatten(assign_ops)): 294 | # Remove the batch dimension from the agent state/output. 295 | first_agent_state = nest.map_structure(lambda t: t[0], first_agent_state) 296 | first_agent_output = nest.map_structure(lambda t: t[0], first_agent_output) 297 | agent_outputs = nest.map_structure(lambda t: t[:, 0], agent_outputs) 298 | 299 | # Concatenate first output and the unroll along the time dimension. 300 | full_agent_outputs, full_env_outputs = nest.map_structure( 301 | lambda first, rest: tf.concat([[first], rest], 0), 302 | (first_agent_output, first_env_output), (agent_outputs, env_outputs)) 303 | 304 | output = ActorOutput( 305 | level_name=level_name, agent_state=first_agent_state, 306 | env_outputs=full_env_outputs, agent_outputs=full_agent_outputs) 307 | 308 | # No backpropagation should be done here. 309 | return nest.map_structure(tf.stop_gradient, output) 310 | 311 | 312 | def compute_baseline_loss(advantages): 313 | # Loss for the baseline, summed over the time dimension. 
314 | # Multiply by 0.5 to match the standard update rule: 315 | # d(loss) / d(baseline) = advantage 316 | return .5 * tf.reduce_sum(tf.square(advantages)) 317 | 318 | 319 | def compute_entropy_loss(logits): 320 | policy = tf.nn.softmax(logits) 321 | log_policy = tf.nn.log_softmax(logits) 322 | entropy_per_timestep = tf.reduce_sum(-policy * log_policy, axis=-1) 323 | return -tf.reduce_sum(entropy_per_timestep) 324 | 325 | 326 | def compute_policy_gradient_loss(logits, actions, advantages): 327 | cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits( 328 | labels=actions, logits=logits) 329 | advantages = tf.stop_gradient(advantages) 330 | policy_gradient_loss_per_timestep = cross_entropy * advantages 331 | return tf.reduce_sum(policy_gradient_loss_per_timestep) 332 | 333 | 334 | def build_learner(agent, agent_state, env_outputs, agent_outputs): 335 | """Builds the learner loop. 336 | 337 | Args: 338 | agent: A snt.RNNCore module outputting `AgentOutput` named tuples, with an 339 | `unroll` call for computing the outputs for a whole trajectory. 340 | agent_state: The initial agent state for each sequence in the batch. 341 | env_outputs: A `StepOutput` namedtuple where each field is of shape 342 | [T+1, ...]. 343 | agent_outputs: An `AgentOutput` namedtuple where each field is of shape 344 | [T+1, ...]. 345 | 346 | Returns: 347 | A tuple of (done, infos, and environment frames) where 348 | the environment frames tensor causes an update. 349 | """ 350 | learner_outputs, _ = agent.unroll(agent_outputs.action, env_outputs, 351 | agent_state) 352 | 353 | # Use last baseline value (from the value function) to bootstrap. 354 | bootstrap_value = learner_outputs.baseline[-1] 355 | 356 | # At this point, the environment outputs at time step `t` are the inputs that 357 | # lead to the learner_outputs at time step `t`. After the following shifting, 358 | # the actions in agent_outputs and learner_outputs at time step `t` is what 359 | # leads to the environment outputs at time step `t`. 360 | agent_outputs = nest.map_structure(lambda t: t[1:], agent_outputs) 361 | rewards, infos, done, _ = nest.map_structure( 362 | lambda t: t[1:], env_outputs) 363 | learner_outputs = nest.map_structure(lambda t: t[:-1], learner_outputs) 364 | 365 | if FLAGS.reward_clipping == 'abs_one': 366 | clipped_rewards = tf.clip_by_value(rewards, -1, 1) 367 | elif FLAGS.reward_clipping == 'soft_asymmetric': 368 | squeezed = tf.tanh(rewards / 5.0) 369 | # Negative rewards are given less weight than positive rewards. 370 | clipped_rewards = tf.where(rewards < 0, .3 * squeezed, squeezed) * 5. 371 | 372 | discounts = tf.to_float(~done) * FLAGS.discounting 373 | 374 | # Compute V-trace returns and weights. 375 | # Note, this is put on the CPU because it's faster than on GPU. It can be 376 | # improved further with XLA-compilation or with a custom TensorFlow operation. 377 | with tf.device('/cpu'): 378 | vtrace_returns = vtrace.from_logits( 379 | behaviour_policy_logits=agent_outputs.policy_logits, 380 | target_policy_logits=learner_outputs.policy_logits, 381 | actions=agent_outputs.action, 382 | discounts=discounts, 383 | rewards=clipped_rewards, 384 | values=learner_outputs.baseline, 385 | bootstrap_value=bootstrap_value) 386 | 387 | # Compute loss as a weighted sum of the baseline loss, the policy gradient 388 | # loss and an entropy regularization term. 
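  # i.e. total_loss = policy_gradient_loss
  #                   + baseline_cost * baseline_loss
  #                   + entropy_cost * entropy_loss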
389 | total_loss = compute_policy_gradient_loss( 390 | learner_outputs.policy_logits, agent_outputs.action, 391 | vtrace_returns.pg_advantages) 392 | total_loss += FLAGS.baseline_cost * compute_baseline_loss( 393 | vtrace_returns.vs - learner_outputs.baseline) 394 | total_loss += FLAGS.entropy_cost * compute_entropy_loss( 395 | learner_outputs.policy_logits) 396 | 397 | # Optimization 398 | num_env_frames = tf.train.get_global_step() 399 | learning_rate = tf.train.polynomial_decay(FLAGS.learning_rate, num_env_frames, 400 | FLAGS.total_environment_frames, 0) 401 | optimizer = tf.train.RMSPropOptimizer(learning_rate, FLAGS.decay, 402 | FLAGS.momentum, FLAGS.epsilon) 403 | train_op = optimizer.minimize(total_loss) 404 | 405 | # Merge updating the network and environment frames into a single tensor. 406 | with tf.control_dependencies([train_op]): 407 | num_env_frames_and_train = num_env_frames.assign_add( 408 | FLAGS.batch_size * FLAGS.unroll_length * FLAGS.num_action_repeats) 409 | 410 | # Adding a few summaries. 411 | tf.summary.scalar('learning_rate', learning_rate) 412 | tf.summary.scalar('total_loss', total_loss) 413 | tf.summary.histogram('action', agent_outputs.action) 414 | 415 | return done, infos, num_env_frames_and_train 416 | 417 | 418 | def create_environment(level_name, seed, is_test=False): 419 | """Creates an environment wrapped in a `FlowEnvironment`.""" 420 | if level_name in dmlab30.ALL_LEVELS: 421 | level_name = 'contributed/dmlab30/' + level_name 422 | 423 | # Note, you may want to use a level cache to speed up compilation of 424 | # environment maps. See the documentation for the Python interface of DeepMind 425 | # Lab. 426 | config = { 427 | 'width': FLAGS.width, 428 | 'height': FLAGS.height, 429 | 'datasetPath': FLAGS.dataset_path, 430 | 'logLevel': 'WARN', 431 | } 432 | if is_test: 433 | config['allowHoldOutLevels'] = 'true' 434 | # Mixer seed for evaluation, see 435 | # https://github.com/deepmind/lab/blob/master/docs/users/python_api.md 436 | config['mixerSeed'] = 0x600D5EED 437 | p = py_process.PyProcess(environments.PyProcessDmLab, level_name, config, 438 | FLAGS.num_action_repeats, seed) 439 | return environments.FlowEnvironment(p.proxy) 440 | 441 | 442 | @contextlib.contextmanager 443 | def pin_global_variables(device): 444 | """Pins global variables to the specified device.""" 445 | def getter(getter, *args, **kwargs): 446 | var_collections = kwargs.get('collections', None) 447 | if var_collections is None: 448 | var_collections = [tf.GraphKeys.GLOBAL_VARIABLES] 449 | if tf.GraphKeys.GLOBAL_VARIABLES in var_collections: 450 | with tf.device(device): 451 | return getter(*args, **kwargs) 452 | else: 453 | return getter(*args, **kwargs) 454 | 455 | with tf.variable_scope('', custom_getter=getter) as vs: 456 | yield vs 457 | 458 | 459 | def train(action_set, level_names): 460 | """Train.""" 461 | 462 | if is_single_machine(): 463 | local_job_device = '' 464 | shared_job_device = '' 465 | is_actor_fn = lambda i: True 466 | is_learner = True 467 | global_variable_device = '/gpu' 468 | server = tf.train.Server.create_local_server() 469 | filters = [] 470 | else: 471 | local_job_device = '/job:%s/task:%d' % (FLAGS.job_name, FLAGS.task) 472 | shared_job_device = '/job:learner/task:0' 473 | is_actor_fn = lambda i: FLAGS.job_name == 'actor' and i == FLAGS.task 474 | is_learner = FLAGS.job_name == 'learner' 475 | 476 | # Placing the variables on the CPU makes it cheaper to send them to all the 477 | # actors. Continually copying the variables from the GPU is slow.
478 | global_variable_device = shared_job_device + '/cpu' 479 | cluster = tf.train.ClusterSpec({ 480 | 'actor': ['localhost:%d' % (8001 + i) for i in range(FLAGS.num_actors)], 481 | 'learner': ['localhost:8000'] 482 | }) 483 | server = tf.train.Server(cluster, job_name=FLAGS.job_name, 484 | task_index=FLAGS.task) 485 | filters = [shared_job_device, local_job_device] 486 | 487 | # Only used to find the actor output structure. 488 | with tf.Graph().as_default(): 489 | agent = Agent(len(action_set)) 490 | env = create_environment(level_names[0], seed=1) 491 | structure = build_actor(agent, env, level_names[0], action_set) 492 | flattened_structure = nest.flatten(structure) 493 | dtypes = [t.dtype for t in flattened_structure] 494 | shapes = [t.shape.as_list() for t in flattened_structure] 495 | 496 | with tf.Graph().as_default(), \ 497 | tf.device(local_job_device + '/cpu'), \ 498 | pin_global_variables(global_variable_device): 499 | tf.set_random_seed(FLAGS.seed) # Makes initialization deterministic. 500 | 501 | # Create Queue and Agent on the learner. 502 | with tf.device(shared_job_device): 503 | queue = tf.FIFOQueue(1, dtypes, shapes, shared_name='buffer') 504 | agent = Agent(len(action_set)) 505 | 506 | if is_single_machine() and 'dynamic_batching' in sys.modules: 507 | # For single machine training, we use dynamic batching for improved GPU 508 | # utilization. The semantics of single machine training are slightly 509 | # different from the distributed setting because within a single unroll 510 | # of an environment, the actions may be computed using different weights 511 | # if an update happens within the unroll. 512 | old_build = agent._build 513 | @dynamic_batching.batch_fn 514 | def build(*args): 515 | with tf.device('/gpu'): 516 | return old_build(*args) 517 | tf.logging.info('Using dynamic batching.') 518 | agent._build = build 519 | 520 | # Build actors and ops to enqueue their output. 521 | enqueue_ops = [] 522 | for i in range(FLAGS.num_actors): 523 | if is_actor_fn(i): 524 | level_name = level_names[i % len(level_names)] 525 | tf.logging.info('Creating actor %d with level %s', i, level_name) 526 | env = create_environment(level_name, seed=i + 1) 527 | actor_output = build_actor(agent, env, level_name, action_set) 528 | with tf.device(shared_job_device): 529 | enqueue_ops.append(queue.enqueue(nest.flatten(actor_output))) 530 | 531 | # If running in a single machine setup, run actors with QueueRunners 532 | # (separate threads). 533 | if is_learner and enqueue_ops: 534 | tf.train.add_queue_runner(tf.train.QueueRunner(queue, enqueue_ops)) 535 | 536 | # Build learner. 537 | if is_learner: 538 | # Create global step, which is the number of environment frames processed. 539 | tf.get_variable( 540 | 'num_environment_frames', 541 | initializer=tf.zeros_initializer(), 542 | shape=[], 543 | dtype=tf.int64, 544 | trainable=False, 545 | collections=[tf.GraphKeys.GLOBAL_STEP, tf.GraphKeys.GLOBAL_VARIABLES]) 546 | 547 | # Create batch (time major) and recreate structure. 
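    # `dequeue_many` yields batch-major tensors; `make_time_major` below
    # transposes the env/agent outputs to [T+1, B, ...], the layout expected by
    # `build_learner`.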
548 | dequeued = queue.dequeue_many(FLAGS.batch_size) 549 | dequeued = nest.pack_sequence_as(structure, dequeued) 550 | 551 | def make_time_major(s): 552 | return nest.map_structure( 553 | lambda t: tf.transpose(t, [1, 0] + list(range(t.shape.ndims))[2:]), s) 554 | 555 | dequeued = dequeued._replace( 556 | env_outputs=make_time_major(dequeued.env_outputs), 557 | agent_outputs=make_time_major(dequeued.agent_outputs)) 558 | 559 | with tf.device('/gpu'): 560 | # Using StagingArea allows us to prepare the next batch and send it to 561 | # the GPU while we're performing a training step. This adds up to 1 step 562 | # policy lag. 563 | flattened_output = nest.flatten(dequeued) 564 | area = tf.contrib.staging.StagingArea( 565 | [t.dtype for t in flattened_output], 566 | [t.shape for t in flattened_output]) 567 | stage_op = area.put(flattened_output) 568 | 569 | data_from_actors = nest.pack_sequence_as(structure, area.get()) 570 | 571 | # Unroll agent on sequence, create losses and update ops. 572 | output = build_learner(agent, data_from_actors.agent_state, 573 | data_from_actors.env_outputs, 574 | data_from_actors.agent_outputs) 575 | 576 | # Create MonitoredSession (to run the graph, checkpoint and log). 577 | tf.logging.info('Creating MonitoredSession, is_chief %s', is_learner) 578 | config = tf.ConfigProto(allow_soft_placement=True, device_filters=filters) 579 | with tf.train.MonitoredTrainingSession( 580 | server.target, 581 | is_chief=is_learner, 582 | checkpoint_dir=FLAGS.logdir, 583 | save_checkpoint_secs=600, 584 | save_summaries_secs=30, 585 | log_step_count_steps=50000, 586 | config=config, 587 | hooks=[py_process.PyProcessHook()]) as session: 588 | 589 | if is_learner: 590 | # Logging. 591 | level_returns = {level_name: [] for level_name in level_names} 592 | summary_writer = tf.summary.FileWriterCache.get(FLAGS.logdir) 593 | 594 | # Prepare data for first run. 595 | session.run_step_fn( 596 | lambda step_context: step_context.session.run(stage_op)) 597 | 598 | # Execute learning and track performance. 
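      # Each loop iteration runs one training step and, via `stage_op`, stages
      # the next batch so the GPU is kept busy while the current step executes.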
599 | num_env_frames_v = 0 600 | while num_env_frames_v < FLAGS.total_environment_frames: 601 | level_names_v, done_v, infos_v, num_env_frames_v, _ = session.run( 602 | (data_from_actors.level_name,) + output + (stage_op,)) 603 | level_names_v = np.repeat([level_names_v], done_v.shape[0], 0) 604 | 605 | for level_name, episode_return, episode_step in zip( 606 | level_names_v[done_v], 607 | infos_v.episode_return[done_v], 608 | infos_v.episode_step[done_v]): 609 | episode_frames = episode_step * FLAGS.num_action_repeats 610 | 611 | tf.logging.info('Level: %s Episode return: %f', 612 | level_name, episode_return) 613 | 614 | summary = tf.summary.Summary() 615 | summary.value.add(tag=level_name + '/episode_return', 616 | simple_value=episode_return) 617 | summary.value.add(tag=level_name + '/episode_frames', 618 | simple_value=episode_frames) 619 | summary_writer.add_summary(summary, num_env_frames_v) 620 | 621 | if FLAGS.level_name == 'dmlab30': 622 | level_returns[level_name].append(episode_return) 623 | 624 | if (FLAGS.level_name == 'dmlab30' and 625 | min(map(len, level_returns.values())) >= 1): 626 | no_cap = dmlab30.compute_human_normalized_score(level_returns, 627 | per_level_cap=None) 628 | cap_100 = dmlab30.compute_human_normalized_score(level_returns, 629 | per_level_cap=100) 630 | summary = tf.summary.Summary() 631 | summary.value.add( 632 | tag='dmlab30/training_no_cap', simple_value=no_cap) 633 | summary.value.add( 634 | tag='dmlab30/training_cap_100', simple_value=cap_100) 635 | summary_writer.add_summary(summary, num_env_frames_v) 636 | 637 | # Clear level scores. 638 | level_returns = {level_name: [] for level_name in level_names} 639 | 640 | else: 641 | # Execute actors (they just need to enqueue their output). 642 | while True: 643 | session.run(enqueue_ops) 644 | 645 | 646 | def test(action_set, level_names): 647 | """Test.""" 648 | 649 | level_returns = {level_name: [] for level_name in level_names} 650 | with tf.Graph().as_default(): 651 | agent = Agent(len(action_set)) 652 | outputs = {} 653 | for level_name in level_names: 654 | env = create_environment(level_name, seed=1, is_test=True) 655 | outputs[level_name] = build_actor(agent, env, level_name, action_set) 656 | 657 | with tf.train.SingularMonitoredSession( 658 | checkpoint_dir=FLAGS.logdir, 659 | hooks=[py_process.PyProcessHook()]) as session: 660 | for level_name in level_names: 661 | tf.logging.info('Testing level: %s', level_name) 662 | while True: 663 | done_v, infos_v = session.run(( 664 | outputs[level_name].env_outputs.done, 665 | outputs[level_name].env_outputs.info 666 | )) 667 | returns = level_returns[level_name] 668 | returns.extend(infos_v.episode_return[1:][done_v[1:]]) 669 | 670 | if len(returns) >= FLAGS.test_num_episodes: 671 | tf.logging.info('Mean episode return: %f', np.mean(returns)) 672 | break 673 | 674 | if FLAGS.level_name == 'dmlab30': 675 | no_cap = dmlab30.compute_human_normalized_score(level_returns, 676 | per_level_cap=None) 677 | cap_100 = dmlab30.compute_human_normalized_score(level_returns, 678 | per_level_cap=100) 679 | tf.logging.info('No cap.: %f Cap 100: %f', no_cap, cap_100) 680 | 681 | 682 | def main(_): 683 | tf.logging.set_verbosity(tf.logging.INFO) 684 | 685 | action_set = environments.DEFAULT_ACTION_SET 686 | if FLAGS.level_name == 'dmlab30' and FLAGS.mode == 'train': 687 | level_names = dmlab30.LEVEL_MAPPING.keys() 688 | elif FLAGS.level_name == 'dmlab30' and FLAGS.mode == 'test': 689 | level_names = dmlab30.LEVEL_MAPPING.values() 690 | else: 691 | level_names = 
[FLAGS.level_name] 692 | 693 | if FLAGS.mode == 'train': 694 | train(action_set, level_names) 695 | else: 696 | test(action_set, level_names) 697 | 698 | 699 | if __name__ == '__main__': 700 | tf.app.run() 701 | -------------------------------------------------------------------------------- /py_process.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """PyProcess. 16 | 17 | This file includes utilities for running code in separate Python processes as 18 | part of a TensorFlow graph. It is similar to tf.py_func, but the code is run in 19 | separate processes to avoid the GIL. 20 | 21 | Example: 22 | 23 | class Zeros(object): 24 | 25 | def __init__(self, dim0): 26 | self._dim0 = dim0 27 | 28 | def compute(self, dim1): 29 | return np.zeros([self._dim0, dim1], dtype=np.int32) 30 | 31 | @staticmethod 32 | def _tensor_specs(method_name, kwargs, constructor_kwargs): 33 | dim0 = constructor_kwargs['dim0'] 34 | dim1 = kwargs['dim1'] 35 | if method_name == 'compute': 36 | return tf.contrib.framework.TensorSpec([dim0, dim1], tf.int32) 37 | 38 | with tf.Graph().as_default(): 39 | p = py_process.PyProcess(Zeros, 1) 40 | result = p.proxy.compute(2) 41 | 42 | with tf.train.SingularMonitoredSession( 43 | hooks=[py_process.PyProcessHook()]) as session: 44 | print(session.run(result)) # Prints [[0, 0]]. 45 | """ 46 | 47 | from __future__ import absolute_import 48 | from __future__ import division 49 | from __future__ import print_function 50 | 51 | import multiprocessing 52 | 53 | import tensorflow as tf 54 | 55 | from tensorflow.python.util import function_utils 56 | 57 | 58 | nest = tf.contrib.framework.nest 59 | 60 | 61 | class _TFProxy(object): 62 | """A proxy that creates TensorFlow operations for each method call to a 63 | separate process.""" 64 | 65 | def __init__(self, type_, constructor_kwargs): 66 | self._type = type_ 67 | self._constructor_kwargs = constructor_kwargs 68 | 69 | def __getattr__(self, name): 70 | def call(*args): 71 | kwargs = dict( 72 | zip(function_utils.fn_args(getattr(self._type, name))[1:], args)) 73 | specs = self._type._tensor_specs(name, kwargs, self._constructor_kwargs) 74 | 75 | if specs is None: 76 | raise ValueError( 77 | 'No tensor specifications were provided for: %s' % name) 78 | 79 | flat_dtypes = nest.flatten(nest.map_structure(lambda s: s.dtype, specs)) 80 | flat_shapes = nest.flatten(nest.map_structure(lambda s: s.shape, specs)) 81 | 82 | def py_call(*args): 83 | try: 84 | self._out.send(args) 85 | result = self._out.recv() 86 | if isinstance(result, Exception): 87 | raise result 88 | if result is not None: 89 | return result 90 | except Exception as e: 91 | if isinstance(e, IOError): 92 | raise StopIteration() # Clean exit. 
93 | else: 94 | raise 95 | 96 | result = tf.py_func(py_call, (name,) + tuple(args), flat_dtypes, 97 | name=name) 98 | 99 | if isinstance(result, tf.Operation): 100 | return result 101 | 102 | for t, shape in zip(result, flat_shapes): 103 | t.set_shape(shape) 104 | return nest.pack_sequence_as(specs, result) 105 | return call 106 | 107 | def _start(self): 108 | self._out, in_ = multiprocessing.Pipe() 109 | self._process = multiprocessing.Process( 110 | target=self._worker_fn, 111 | args=(self._type, self._constructor_kwargs, in_)) 112 | self._process.start() 113 | result = self._out.recv() 114 | 115 | if isinstance(result, Exception): 116 | raise result 117 | 118 | def _close(self, session): 119 | try: 120 | self._out.send(None) 121 | self._out.close() 122 | except IOError: 123 | pass 124 | self._process.join() 125 | 126 | def _worker_fn(self, type_, constructor_kwargs, in_): 127 | try: 128 | o = type_(**constructor_kwargs) 129 | 130 | in_.send(None) # Ready. 131 | 132 | while True: 133 | # Receive request. 134 | serialized = in_.recv() 135 | 136 | if serialized is None: 137 | if hasattr(o, 'close'): 138 | o.close() 139 | in_.close() 140 | return 141 | 142 | method_name = str(serialized[0]) 143 | inputs = serialized[1:] 144 | 145 | # Compute result. 146 | results = getattr(o, method_name)(*inputs) 147 | if results is not None: 148 | results = nest.flatten(results) 149 | 150 | # Respond. 151 | in_.send(results) 152 | except Exception as e: 153 | if 'o' in locals() and hasattr(o, 'close'): 154 | try: 155 | o.close() 156 | except: 157 | pass 158 | in_.send(e) 159 | 160 | 161 | class PyProcess(object): 162 | COLLECTION = 'py_process_processes' 163 | 164 | def __init__(self, type_, *constructor_args, **constructor_kwargs): 165 | self._type = type_ 166 | self._constructor_kwargs = dict( 167 | zip(function_utils.fn_args(type_.__init__)[1:], constructor_args)) 168 | self._constructor_kwargs.update(constructor_kwargs) 169 | 170 | tf.add_to_collection(PyProcess.COLLECTION, self) 171 | 172 | self._proxy = _TFProxy(type_, self._constructor_kwargs) 173 | 174 | @property 175 | def proxy(self): 176 | """A proxy that creates TensorFlow operations for each method call.""" 177 | return self._proxy 178 | 179 | def close(self, session): 180 | self._proxy._close(session) 181 | 182 | def start(self): 183 | self._proxy._start() 184 | 185 | 186 | class PyProcessHook(tf.train.SessionRunHook): 187 | """A MonitoredSession hook that starts and stops PyProcess instances.""" 188 | 189 | def begin(self): 190 | tf.logging.info('Starting all processes.') 191 | tp = multiprocessing.pool.ThreadPool() 192 | tp.map(lambda p: p.start(), tf.get_collection(PyProcess.COLLECTION)) 193 | tp.close() 194 | tp.join() 195 | tf.logging.info('All processes started.') 196 | 197 | def end(self, session): 198 | tf.logging.info('Closing all processes.') 199 | tp = multiprocessing.pool.ThreadPool() 200 | tp.map(lambda p: p.close(session), tf.get_collection(PyProcess.COLLECTION)) 201 | tp.close() 202 | tp.join() 203 | tf.logging.info('All processes closed.') 204 | -------------------------------------------------------------------------------- /py_process_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests py_process.py.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tempfile 22 | import time 23 | 24 | import numpy as np 25 | import py_process 26 | import tensorflow as tf 27 | 28 | from six.moves import range 29 | 30 | 31 | class PyProcessTest(tf.test.TestCase): 32 | 33 | def test_small(self): 34 | 35 | class Example(object): 36 | 37 | def __init__(self, a): 38 | self._a = a 39 | 40 | def inc(self): 41 | self._a += 1 42 | 43 | def compute(self, b): 44 | return np.array(self._a + b, dtype=np.int32) 45 | 46 | @staticmethod 47 | def _tensor_specs(method_name, unused_args, unused_constructor_kwargs): 48 | if method_name == 'compute': 49 | return tf.contrib.framework.TensorSpec([], tf.int32) 50 | elif method_name == 'inc': 51 | return () 52 | 53 | with tf.Graph().as_default(): 54 | p = py_process.PyProcess(Example, 1) 55 | inc = p.proxy.inc() 56 | compute = p.proxy.compute(2) 57 | 58 | with tf.train.SingularMonitoredSession( 59 | hooks=[py_process.PyProcessHook()]) as session: 60 | self.assertTrue(isinstance(inc, tf.Operation)) 61 | session.run(inc) 62 | 63 | self.assertEqual([], compute.shape) 64 | self.assertEqual(4, session.run(compute)) 65 | 66 | def test_threading(self): 67 | 68 | class Example(object): 69 | 70 | def __init__(self): 71 | pass 72 | 73 | def wait(self): 74 | time.sleep(.2) 75 | return None 76 | 77 | @staticmethod 78 | def _tensor_specs(method_name, unused_args, unused_constructor_kwargs): 79 | if method_name == 'wait': 80 | return tf.contrib.framework.TensorSpec([], tf.int32) 81 | 82 | with tf.Graph().as_default(): 83 | p = py_process.PyProcess(Example) 84 | wait = p.proxy.wait() 85 | 86 | hook = py_process.PyProcessHook() 87 | with tf.train.SingularMonitoredSession(hooks=[hook]) as session: 88 | 89 | def run(): 90 | with self.assertRaises(tf.errors.OutOfRangeError): 91 | session.run(wait) 92 | 93 | t = self.checkedThread(target=run) 94 | t.start() 95 | time.sleep(.1) 96 | t.join() 97 | 98 | def test_args(self): 99 | 100 | class Example(object): 101 | 102 | def __init__(self, dim0): 103 | self._dim0 = dim0 104 | 105 | def compute(self, dim1): 106 | return np.zeros([self._dim0, dim1], dtype=np.int32) 107 | 108 | @staticmethod 109 | def _tensor_specs(method_name, kwargs, constructor_kwargs): 110 | dim0 = constructor_kwargs['dim0'] 111 | dim1 = kwargs['dim1'] 112 | if method_name == 'compute': 113 | return tf.contrib.framework.TensorSpec([dim0, dim1], tf.int32) 114 | 115 | with tf.Graph().as_default(): 116 | p = py_process.PyProcess(Example, 1) 117 | result = p.proxy.compute(2) 118 | 119 | with tf.train.SingularMonitoredSession( 120 | hooks=[py_process.PyProcessHook()]) as session: 121 | self.assertEqual([1, 2], result.shape) 122 | self.assertAllEqual([[0, 0]], session.run(result)) 123 | 124 | def test_error_handling_constructor(self): 125 | 126 | class Example(object): 127 | 128 | def __init__(self): 129 | raise ValueError('foo') 130 | 131 | def something(self): 132 | pass 133 | 134 | @staticmethod 135 | def 
_tensor_specs(method_name, unused_kwargs, unused_constructor_kwargs): 136 | if method_name == 'something': 137 | return () 138 | 139 | with tf.Graph().as_default(): 140 | py_process.PyProcess(Example, 1) 141 | 142 | with self.assertRaisesRegexp(Exception, 'foo'): 143 | with tf.train.SingularMonitoredSession( 144 | hooks=[py_process.PyProcessHook()]): 145 | pass 146 | 147 | def test_error_handling_method(self): 148 | 149 | class Example(object): 150 | 151 | def __init__(self): 152 | pass 153 | 154 | def something(self): 155 | raise ValueError('foo') 156 | 157 | @staticmethod 158 | def _tensor_specs(method_name, unused_kwargs, unused_constructor_kwargs): 159 | if method_name == 'something': 160 | return () 161 | 162 | with tf.Graph().as_default(): 163 | p = py_process.PyProcess(Example, 1) 164 | result = p.proxy.something() 165 | 166 | with tf.train.SingularMonitoredSession( 167 | hooks=[py_process.PyProcessHook()]) as session: 168 | with self.assertRaisesRegexp(Exception, 'foo'): 169 | session.run(result) 170 | 171 | def test_close(self): 172 | with tempfile.NamedTemporaryFile() as tmp: 173 | class Example(object): 174 | 175 | def __init__(self, filename): 176 | self._filename = filename 177 | 178 | def close(self): 179 | with tf.gfile.Open(self._filename, 'w') as f: 180 | f.write('was_closed') 181 | 182 | with tf.Graph().as_default(): 183 | py_process.PyProcess(Example, tmp.name) 184 | 185 | with tf.train.SingularMonitoredSession( 186 | hooks=[py_process.PyProcessHook()]): 187 | pass 188 | 189 | self.assertEqual('was_closed', tmp.read()) 190 | 191 | def test_close_on_error(self): 192 | with tempfile.NamedTemporaryFile() as tmp: 193 | 194 | class Example(object): 195 | 196 | def __init__(self, filename): 197 | self._filename = filename 198 | 199 | def something(self): 200 | raise ValueError('foo') 201 | 202 | def close(self): 203 | with tf.gfile.Open(self._filename, 'w') as f: 204 | f.write('was_closed') 205 | 206 | @staticmethod 207 | def _tensor_specs(method_name, unused_kwargs, 208 | unused_constructor_kwargs): 209 | if method_name == 'something': 210 | return () 211 | 212 | with tf.Graph().as_default(): 213 | p = py_process.PyProcess(Example, tmp.name) 214 | result = p.proxy.something() 215 | 216 | with tf.train.SingularMonitoredSession( 217 | hooks=[py_process.PyProcessHook()]) as session: 218 | with self.assertRaisesRegexp(Exception, 'foo'): 219 | session.run(result) 220 | 221 | self.assertEqual('was_closed', tmp.read()) 222 | 223 | 224 | class PyProcessBenchmarks(tf.test.Benchmark): 225 | 226 | class Example(object): 227 | 228 | def __init__(self): 229 | self._result = np.random.randint(0, 256, (72, 96, 3), np.uint8) 230 | 231 | def compute(self, unused_a): 232 | return self._result 233 | 234 | @staticmethod 235 | def _tensor_specs(method_name, unused_args, unused_constructor_kwargs): 236 | if method_name == 'compute': 237 | return tf.contrib.framework.TensorSpec([72, 96, 3], tf.uint8) 238 | 239 | def benchmark_one(self): 240 | with tf.Graph().as_default(): 241 | p = py_process.PyProcess(PyProcessBenchmarks.Example) 242 | compute = p.proxy.compute(2) 243 | 244 | with tf.train.SingularMonitoredSession( 245 | hooks=[py_process.PyProcessHook()]) as session: 246 | 247 | self.run_op_benchmark( 248 | name='process_one', 249 | sess=session, 250 | op_or_tensor=compute, 251 | burn_iters=10, 252 | min_iters=5000) 253 | 254 | def benchmark_many(self): 255 | with tf.Graph().as_default(): 256 | ps = [ 257 | py_process.PyProcess(PyProcessBenchmarks.Example) for _ in range(200) 258 | ] 259 | 
compute_ops = [p.proxy.compute(2) for p in ps] 260 | compute = tf.group(*compute_ops) 261 | 262 | with tf.train.SingularMonitoredSession( 263 | hooks=[py_process.PyProcessHook()]) as session: 264 | 265 | self.run_op_benchmark( 266 | name='process_many', 267 | sess=session, 268 | op_or_tensor=compute, 269 | burn_iters=10, 270 | min_iters=500) 271 | 272 | 273 | if __name__ == '__main__': 274 | tf.test.main() 275 | -------------------------------------------------------------------------------- /vtrace.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Functions to compute V-trace off-policy actor critic targets. 16 | 17 | For details and theory see: 18 | 19 | "IMPALA: Scalable Distributed Deep-RL with 20 | Importance Weighted Actor-Learner Architectures" 21 | by Espeholt, Soyer, Munos et al. 22 | 23 | See https://arxiv.org/abs/1802.01561 for the full paper. 24 | """ 25 | 26 | from __future__ import absolute_import 27 | from __future__ import division 28 | from __future__ import print_function 29 | 30 | import collections 31 | 32 | import tensorflow as tf 33 | 34 | nest = tf.contrib.framework.nest 35 | 36 | 37 | VTraceFromLogitsReturns = collections.namedtuple( 38 | 'VTraceFromLogitsReturns', 39 | ['vs', 'pg_advantages', 'log_rhos', 40 | 'behaviour_action_log_probs', 'target_action_log_probs']) 41 | 42 | VTraceReturns = collections.namedtuple('VTraceReturns', 'vs pg_advantages') 43 | 44 | 45 | def log_probs_from_logits_and_actions(policy_logits, actions): 46 | """Computes action log-probs from policy logits and actions. 47 | 48 | In the notation used throughout documentation and comments, T refers to the 49 | time dimension ranging from 0 to T-1. B refers to the batch size and 50 | NUM_ACTIONS refers to the number of actions. 51 | 52 | Args: 53 | policy_logits: A float32 tensor of shape [T, B, NUM_ACTIONS] with 54 | un-normalized log-probabilities parameterizing a softmax policy. 55 | actions: An int32 tensor of shape [T, B] with actions. 56 | 57 | Returns: 58 | A float32 tensor of shape [T, B] corresponding to the sampling log 59 | probability of the chosen action w.r.t. the policy. 60 | """ 61 | policy_logits = tf.convert_to_tensor(policy_logits, dtype=tf.float32) 62 | actions = tf.convert_to_tensor(actions, dtype=tf.int32) 63 | 64 | policy_logits.shape.assert_has_rank(3) 65 | actions.shape.assert_has_rank(2) 66 | 67 | return -tf.nn.sparse_softmax_cross_entropy_with_logits( 68 | logits=policy_logits, labels=actions) 69 | 70 | 71 | def from_logits( 72 | behaviour_policy_logits, target_policy_logits, actions, 73 | discounts, rewards, values, bootstrap_value, 74 | clip_rho_threshold=1.0, clip_pg_rho_threshold=1.0, 75 | name='vtrace_from_logits'): 76 | r"""V-trace for softmax policies. 
77 |
78 |   Calculates V-trace actor critic targets for softmax policies as described in
79 |
80 |   "IMPALA: Scalable Distributed Deep-RL with
81 |   Importance Weighted Actor-Learner Architectures"
82 |   by Espeholt, Soyer, Munos et al.
83 |
84 |   Target policy refers to the policy we are interested in improving and
85 |   behaviour policy refers to the policy that generated the given
86 |   rewards and actions.
87 |
88 |   In the notation used throughout documentation and comments, T refers to the
89 |   time dimension ranging from 0 to T-1. B refers to the batch size and
90 |   NUM_ACTIONS refers to the number of actions.
91 |
92 |   Args:
93 |     behaviour_policy_logits: A float32 tensor of shape [T, B, NUM_ACTIONS] with
94 |       un-normalized log-probabilities parametrizing the softmax behaviour
95 |       policy.
96 |     target_policy_logits: A float32 tensor of shape [T, B, NUM_ACTIONS] with
97 |       un-normalized log-probabilities parametrizing the softmax target policy.
98 |     actions: An int32 tensor of shape [T, B] of actions sampled from the
99 |       behaviour policy.
100 |     discounts: A float32 tensor of shape [T, B] with the discount encountered
101 |       when following the behaviour policy.
102 |     rewards: A float32 tensor of shape [T, B] with the rewards generated by
103 |       following the behaviour policy.
104 |     values: A float32 tensor of shape [T, B] with the value function estimates
105 |       wrt. the target policy.
106 |     bootstrap_value: A float32 of shape [B] with the value function estimate at
107 |       time T.
108 |     clip_rho_threshold: A scalar float32 tensor with the clipping threshold for
109 |       importance weights (rho) when calculating the baseline targets (vs).
110 |       rho^bar in the paper.
111 |     clip_pg_rho_threshold: A scalar float32 tensor with the clipping threshold
112 |       on rho_s in \rho_s \delta log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)).
113 |     name: The name scope that all V-trace operations will be created in.
114 |
115 |   Returns:
116 |     A `VTraceFromLogitsReturns` namedtuple with the following fields:
117 |       vs: A float32 tensor of shape [T, B]. Can be used as target to train a
118 |         baseline (V(x_t) - vs_t)^2.
119 |       pg_advantages: A float32 tensor of shape [T, B]. Can be used as an
120 |         estimate of the advantage in the calculation of policy gradients.
121 |       log_rhos: A float32 tensor of shape [T, B] containing the log importance
122 |         sampling weights (log rhos).
123 |       behaviour_action_log_probs: A float32 tensor of shape [T, B] containing
124 |         behaviour policy action log probabilities (log \mu(a_t)).
125 |       target_action_log_probs: A float32 tensor of shape [T, B] containing
126 |         target policy action log probabilities (log \pi(a_t)).
127 |   """
128 |   behaviour_policy_logits = tf.convert_to_tensor(
129 |       behaviour_policy_logits, dtype=tf.float32)
130 |   target_policy_logits = tf.convert_to_tensor(
131 |       target_policy_logits, dtype=tf.float32)
132 |   actions = tf.convert_to_tensor(actions, dtype=tf.int32)
133 |
134 |   # Make sure tensor ranks are as expected.
135 |   # The rest will be checked by from_importance_weights.
136 |   behaviour_policy_logits.shape.assert_has_rank(3)
137 |   target_policy_logits.shape.assert_has_rank(3)
138 |   actions.shape.assert_has_rank(2)
139 |
140 |   with tf.name_scope(name, values=[
141 |       behaviour_policy_logits, target_policy_logits, actions,
142 |       discounts, rewards, values, bootstrap_value]):
143 |     target_action_log_probs = log_probs_from_logits_and_actions(
144 |         target_policy_logits, actions)
145 |     behaviour_action_log_probs = log_probs_from_logits_and_actions(
146 |         behaviour_policy_logits, actions)
147 |     log_rhos = target_action_log_probs - behaviour_action_log_probs
148 |     vtrace_returns = from_importance_weights(
149 |         log_rhos=log_rhos,
150 |         discounts=discounts,
151 |         rewards=rewards,
152 |         values=values,
153 |         bootstrap_value=bootstrap_value,
154 |         clip_rho_threshold=clip_rho_threshold,
155 |         clip_pg_rho_threshold=clip_pg_rho_threshold)
156 |     return VTraceFromLogitsReturns(
157 |         log_rhos=log_rhos,
158 |         behaviour_action_log_probs=behaviour_action_log_probs,
159 |         target_action_log_probs=target_action_log_probs,
160 |         **vtrace_returns._asdict()
161 |     )
162 |
163 |
164 | def from_importance_weights(
165 |     log_rhos, discounts, rewards, values, bootstrap_value,
166 |     clip_rho_threshold=1.0, clip_pg_rho_threshold=1.0,
167 |     name='vtrace_from_importance_weights'):
168 |   r"""V-trace from log importance weights.
169 |
170 |   Calculates V-trace actor critic targets as described in
171 |
172 |   "IMPALA: Scalable Distributed Deep-RL with
173 |   Importance Weighted Actor-Learner Architectures"
174 |   by Espeholt, Soyer, Munos et al.
175 |
176 |   In the notation used throughout documentation and comments, T refers to the
177 |   time dimension ranging from 0 to T-1. B refers to the batch size and
178 |   NUM_ACTIONS refers to the number of actions. This code also supports the
179 |   case where all tensors have the same number of additional dimensions, e.g.,
180 |   `rewards` is [T, B, C], `values` is [T, B, C], `bootstrap_value` is [B, C].
181 |
182 |   Args:
183 |     log_rhos: A float32 tensor of shape [T, B] representing the log
184 |       importance sampling weights, i.e.
185 |       log(target_policy(a) / behaviour_policy(a)). V-trace performs operations
186 |       on rhos in log-space for numerical stability.
187 |     discounts: A float32 tensor of shape [T, B] with discounts encountered when
188 |       following the behaviour policy.
189 |     rewards: A float32 tensor of shape [T, B] containing rewards generated by
190 |       following the behaviour policy.
191 |     values: A float32 tensor of shape [T, B] with the value function estimates
192 |       wrt. the target policy.
193 |     bootstrap_value: A float32 of shape [B] with the value function estimate at
194 |       time T.
195 |     clip_rho_threshold: A scalar float32 tensor with the clipping threshold for
196 |       importance weights (rho) when calculating the baseline targets (vs).
197 |       rho^bar in the paper. If None, no clipping is applied.
198 |     clip_pg_rho_threshold: A scalar float32 tensor with the clipping threshold
199 |       on rho_s in \rho_s \delta log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)). If
200 |       None, no clipping is applied.
201 |     name: The name scope that all V-trace operations will be created in.
202 |
203 |   Returns:
204 |     A VTraceReturns namedtuple (vs, pg_advantages) where:
205 |       vs: A float32 tensor of shape [T, B]. Can be used as target to
206 |         train a baseline (V(x_t) - vs_t)^2.
207 |       pg_advantages: A float32 tensor of shape [T, B]. Can be used as the
208 |         advantage in the calculation of policy gradients.
209 | """ 210 | log_rhos = tf.convert_to_tensor(log_rhos, dtype=tf.float32) 211 | discounts = tf.convert_to_tensor(discounts, dtype=tf.float32) 212 | rewards = tf.convert_to_tensor(rewards, dtype=tf.float32) 213 | values = tf.convert_to_tensor(values, dtype=tf.float32) 214 | bootstrap_value = tf.convert_to_tensor(bootstrap_value, dtype=tf.float32) 215 | if clip_rho_threshold is not None: 216 | clip_rho_threshold = tf.convert_to_tensor(clip_rho_threshold, 217 | dtype=tf.float32) 218 | if clip_pg_rho_threshold is not None: 219 | clip_pg_rho_threshold = tf.convert_to_tensor(clip_pg_rho_threshold, 220 | dtype=tf.float32) 221 | 222 | # Make sure tensor ranks are consistent. 223 | rho_rank = log_rhos.shape.ndims # Usually 2. 224 | values.shape.assert_has_rank(rho_rank) 225 | bootstrap_value.shape.assert_has_rank(rho_rank - 1) 226 | discounts.shape.assert_has_rank(rho_rank) 227 | rewards.shape.assert_has_rank(rho_rank) 228 | if clip_rho_threshold is not None: 229 | clip_rho_threshold.shape.assert_has_rank(0) 230 | if clip_pg_rho_threshold is not None: 231 | clip_pg_rho_threshold.shape.assert_has_rank(0) 232 | 233 | with tf.name_scope(name, values=[ 234 | log_rhos, discounts, rewards, values, bootstrap_value]): 235 | rhos = tf.exp(log_rhos) 236 | if clip_rho_threshold is not None: 237 | clipped_rhos = tf.minimum(clip_rho_threshold, rhos, name='clipped_rhos') 238 | else: 239 | clipped_rhos = rhos 240 | 241 | cs = tf.minimum(1.0, rhos, name='cs') 242 | # Append bootstrapped value to get [v1, ..., v_t+1] 243 | values_t_plus_1 = tf.concat( 244 | [values[1:], tf.expand_dims(bootstrap_value, 0)], axis=0) 245 | deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values) 246 | 247 | sequences = (discounts, cs, deltas) 248 | # V-trace vs are calculated through a scan from the back to the beginning 249 | # of the given trajectory. 250 | def scanfunc(acc, sequence_item): 251 | discount_t, c_t, delta_t = sequence_item 252 | return delta_t + discount_t * c_t * acc 253 | 254 | initial_values = tf.zeros_like(bootstrap_value) 255 | vs_minus_v_xs = tf.scan( 256 | fn=scanfunc, 257 | elems=sequences, 258 | initializer=initial_values, 259 | parallel_iterations=1, 260 | back_prop=False, 261 | reverse=True, # Computation starts from the back. 262 | name='scan') 263 | 264 | # Add V(x_s) to get v_s. 265 | vs = tf.add(vs_minus_v_xs, values, name='vs') 266 | 267 | # Advantage for policy gradient. 268 | vs_t_plus_1 = tf.concat([ 269 | vs[1:], tf.expand_dims(bootstrap_value, 0)], axis=0) 270 | if clip_pg_rho_threshold is not None: 271 | clipped_pg_rhos = tf.minimum(clip_pg_rho_threshold, rhos, 272 | name='clipped_pg_rhos') 273 | else: 274 | clipped_pg_rhos = rhos 275 | pg_advantages = ( 276 | clipped_pg_rhos * (rewards + discounts * vs_t_plus_1 - values)) 277 | 278 | # Make sure no gradients backpropagated through the returned values. 279 | return VTraceReturns(vs=tf.stop_gradient(vs), 280 | pg_advantages=tf.stop_gradient(pg_advantages)) 281 | -------------------------------------------------------------------------------- /vtrace_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for V-trace. 16 | 17 | For details and theory see: 18 | 19 | "IMPALA: Scalable Distributed Deep-RL with 20 | Importance Weighted Actor-Learner Architectures" 21 | by Espeholt, Soyer, Munos et al. 22 | """ 23 | 24 | from __future__ import absolute_import 25 | from __future__ import division 26 | from __future__ import print_function 27 | 28 | from absl.testing import parameterized 29 | import numpy as np 30 | import tensorflow as tf 31 | import vtrace 32 | 33 | 34 | def _shaped_arange(*shape): 35 | """Runs np.arange, converts to float and reshapes.""" 36 | return np.arange(np.prod(shape), dtype=np.float32).reshape(*shape) 37 | 38 | 39 | def _softmax(logits): 40 | """Applies softmax non-linearity on inputs.""" 41 | return np.exp(logits) / np.sum(np.exp(logits), axis=-1, keepdims=True) 42 | 43 | 44 | def _ground_truth_calculation(discounts, log_rhos, rewards, values, 45 | bootstrap_value, clip_rho_threshold, 46 | clip_pg_rho_threshold): 47 | """Calculates the ground truth for V-trace in Python/Numpy.""" 48 | vs = [] 49 | seq_len = len(discounts) 50 | rhos = np.exp(log_rhos) 51 | cs = np.minimum(rhos, 1.0) 52 | clipped_rhos = rhos 53 | if clip_rho_threshold: 54 | clipped_rhos = np.minimum(rhos, clip_rho_threshold) 55 | clipped_pg_rhos = rhos 56 | if clip_pg_rho_threshold: 57 | clipped_pg_rhos = np.minimum(rhos, clip_pg_rho_threshold) 58 | 59 | # This is a very inefficient way to calculate the V-trace ground truth. 60 | # We calculate it this way because it is close to the mathematical notation of 61 | # V-trace. 62 | # v_s = V(x_s) 63 | # + \sum^{T-1}_{t=s} \gamma^{t-s} 64 | # * \prod_{i=s}^{t-1} c_i 65 | # * \rho_t (r_t + \gamma V(x_{t+1}) - V(x_t)) 66 | # Note that when we take the product over c_i, we write `s:t` as the notation 67 | # of the paper is inclusive of the `t-1`, but Python is exclusive. 68 | # Also note that np.prod([]) == 1. 69 | values_t_plus_1 = np.concatenate([values, bootstrap_value[None, :]], axis=0) 70 | for s in range(seq_len): 71 | v_s = np.copy(values[s]) # Very important copy. 
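    # For example, at s = seq_len - 1 the loop below contributes a single term
    # (np.prod([]) == 1 for the empty products), so
    #   v_s = V(x_s) + clipped_rho_s * (r_s + gamma_s * bootstrap_value - V(x_s)).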
72 | for t in range(s, seq_len): 73 | v_s += ( 74 | np.prod(discounts[s:t], axis=0) * np.prod(cs[s:t], 75 | axis=0) * clipped_rhos[t] * 76 | (rewards[t] + discounts[t] * values_t_plus_1[t + 1] - values[t])) 77 | vs.append(v_s) 78 | vs = np.stack(vs, axis=0) 79 | pg_advantages = ( 80 | clipped_pg_rhos * (rewards + discounts * np.concatenate( 81 | [vs[1:], bootstrap_value[None, :]], axis=0) - values)) 82 | 83 | return vtrace.VTraceReturns(vs=vs, pg_advantages=pg_advantages) 84 | 85 | 86 | class LogProbsFromLogitsAndActionsTest(tf.test.TestCase, 87 | parameterized.TestCase): 88 | 89 | @parameterized.named_parameters(('Batch1', 1), ('Batch2', 2)) 90 | def test_log_probs_from_logits_and_actions(self, batch_size): 91 | """Tests log_probs_from_logits_and_actions.""" 92 | seq_len = 7 93 | num_actions = 3 94 | 95 | policy_logits = _shaped_arange(seq_len, batch_size, num_actions) + 10 96 | actions = np.random.randint( 97 | 0, num_actions, size=(seq_len, batch_size), dtype=np.int32) 98 | 99 | action_log_probs_tensor = vtrace.log_probs_from_logits_and_actions( 100 | policy_logits, actions) 101 | 102 | # Ground Truth 103 | # Using broadcasting to create a mask that indexes action logits 104 | action_index_mask = actions[..., None] == np.arange(num_actions) 105 | 106 | def index_with_mask(array, mask): 107 | return array[mask].reshape(*array.shape[:-1]) 108 | 109 | # Note: Normally log(softmax) is not a good idea because it's not 110 | # numerically stable. However, in this test we have well-behaved values. 111 | ground_truth_v = index_with_mask( 112 | np.log(_softmax(policy_logits)), action_index_mask) 113 | 114 | with self.test_session() as session: 115 | self.assertAllClose(ground_truth_v, session.run(action_log_probs_tensor)) 116 | 117 | 118 | class VtraceTest(tf.test.TestCase, parameterized.TestCase): 119 | 120 | @parameterized.named_parameters(('Batch1', 1), ('Batch5', 5)) 121 | def test_vtrace(self, batch_size): 122 | """Tests V-trace against ground truth data calculated in python.""" 123 | seq_len = 5 124 | 125 | # Create log_rhos such that rho will span from near-zero to above the 126 | # clipping thresholds. In particular, calculate log_rhos in [-2.5, 2.5), 127 | # so that rho is in approx [0.08, 12.2). 128 | log_rhos = _shaped_arange(seq_len, batch_size) / (batch_size * seq_len) 129 | log_rhos = 5 * (log_rhos - 0.5) # [0.0, 1.0) -> [-2.5, 2.5). 130 | values = { 131 | 'log_rhos': log_rhos, 132 | # T, B where B_i: [0.9 / (i+1)] * T 133 | 'discounts': 134 | np.array([[0.9 / (b + 1) 135 | for b in range(batch_size)] 136 | for _ in range(seq_len)]), 137 | 'rewards': 138 | _shaped_arange(seq_len, batch_size), 139 | 'values': 140 | _shaped_arange(seq_len, batch_size) / batch_size, 141 | 'bootstrap_value': 142 | _shaped_arange(batch_size) + 1.0, 143 | 'clip_rho_threshold': 144 | 3.7, 145 | 'clip_pg_rho_threshold': 146 | 2.2, 147 | } 148 | 149 | output = vtrace.from_importance_weights(**values) 150 | 151 | with self.test_session() as session: 152 | output_v = session.run(output) 153 | 154 | ground_truth_v = _ground_truth_calculation(**values) 155 | for a, b in zip(ground_truth_v, output_v): 156 | self.assertAllClose(a, b) 157 | 158 | @parameterized.named_parameters(('Batch1', 1), ('Batch2', 2)) 159 | def test_vtrace_from_logits(self, batch_size): 160 | """Tests V-trace calculated from logits.""" 161 | seq_len = 5 162 | num_actions = 3 163 | clip_rho_threshold = None # No clipping. 164 | clip_pg_rho_threshold = None # No clipping. 
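    # (These exercise the un-clipped code path; clipping of rhos is covered by
    # test_vtrace above, which uses finite clip thresholds.)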
165 | 166 | # Intentionally leaving shapes unspecified to test if V-trace can 167 | # deal with that. 168 | placeholders = { 169 | # T, B, NUM_ACTIONS 170 | 'behaviour_policy_logits': 171 | tf.placeholder(dtype=tf.float32, shape=[None, None, None]), 172 | # T, B, NUM_ACTIONS 173 | 'target_policy_logits': 174 | tf.placeholder(dtype=tf.float32, shape=[None, None, None]), 175 | 'actions': 176 | tf.placeholder(dtype=tf.int32, shape=[None, None]), 177 | 'discounts': 178 | tf.placeholder(dtype=tf.float32, shape=[None, None]), 179 | 'rewards': 180 | tf.placeholder(dtype=tf.float32, shape=[None, None]), 181 | 'values': 182 | tf.placeholder(dtype=tf.float32, shape=[None, None]), 183 | 'bootstrap_value': 184 | tf.placeholder(dtype=tf.float32, shape=[None]), 185 | } 186 | 187 | from_logits_output = vtrace.from_logits( 188 | clip_rho_threshold=clip_rho_threshold, 189 | clip_pg_rho_threshold=clip_pg_rho_threshold, 190 | **placeholders) 191 | 192 | target_log_probs = vtrace.log_probs_from_logits_and_actions( 193 | placeholders['target_policy_logits'], placeholders['actions']) 194 | behaviour_log_probs = vtrace.log_probs_from_logits_and_actions( 195 | placeholders['behaviour_policy_logits'], placeholders['actions']) 196 | log_rhos = target_log_probs - behaviour_log_probs 197 | ground_truth = (log_rhos, behaviour_log_probs, target_log_probs) 198 | 199 | values = { 200 | 'behaviour_policy_logits': 201 | _shaped_arange(seq_len, batch_size, num_actions), 202 | 'target_policy_logits': 203 | _shaped_arange(seq_len, batch_size, num_actions), 204 | 'actions': 205 | np.random.randint(0, num_actions - 1, size=(seq_len, batch_size)), 206 | 'discounts': 207 | np.array( # T, B where B_i: [0.9 / (i+1)] * T 208 | [[0.9 / (b + 1) 209 | for b in range(batch_size)] 210 | for _ in range(seq_len)]), 211 | 'rewards': 212 | _shaped_arange(seq_len, batch_size), 213 | 'values': 214 | _shaped_arange(seq_len, batch_size) / batch_size, 215 | 'bootstrap_value': 216 | _shaped_arange(batch_size) + 1.0, # B 217 | } 218 | 219 | feed_dict = {placeholders[k]: v for k, v in values.items()} 220 | with self.test_session() as session: 221 | from_logits_output_v = session.run( 222 | from_logits_output, feed_dict=feed_dict) 223 | (ground_truth_log_rhos, ground_truth_behaviour_action_log_probs, 224 | ground_truth_target_action_log_probs) = session.run( 225 | ground_truth, feed_dict=feed_dict) 226 | 227 | # Calculate V-trace using the ground truth logits. 
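    # That is, feed the same log importance weights into from_importance_weights
    # directly and check below that the result matches the from_logits output.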
228 | from_iw = vtrace.from_importance_weights( 229 | log_rhos=ground_truth_log_rhos, 230 | discounts=values['discounts'], 231 | rewards=values['rewards'], 232 | values=values['values'], 233 | bootstrap_value=values['bootstrap_value'], 234 | clip_rho_threshold=clip_rho_threshold, 235 | clip_pg_rho_threshold=clip_pg_rho_threshold) 236 | 237 | with self.test_session() as session: 238 | from_iw_v = session.run(from_iw) 239 | 240 | self.assertAllClose(from_iw_v.vs, from_logits_output_v.vs) 241 | self.assertAllClose(from_iw_v.pg_advantages, 242 | from_logits_output_v.pg_advantages) 243 | self.assertAllClose(ground_truth_behaviour_action_log_probs, 244 | from_logits_output_v.behaviour_action_log_probs) 245 | self.assertAllClose(ground_truth_target_action_log_probs, 246 | from_logits_output_v.target_action_log_probs) 247 | self.assertAllClose(ground_truth_log_rhos, from_logits_output_v.log_rhos) 248 | 249 | def test_higher_rank_inputs_for_importance_weights(self): 250 | """Checks support for additional dimensions in inputs.""" 251 | placeholders = { 252 | 'log_rhos': tf.placeholder(dtype=tf.float32, shape=[None, None, 1]), 253 | 'discounts': tf.placeholder(dtype=tf.float32, shape=[None, None, 1]), 254 | 'rewards': tf.placeholder(dtype=tf.float32, shape=[None, None, 42]), 255 | 'values': tf.placeholder(dtype=tf.float32, shape=[None, None, 42]), 256 | 'bootstrap_value': tf.placeholder(dtype=tf.float32, shape=[None, 42]) 257 | } 258 | output = vtrace.from_importance_weights(**placeholders) 259 | self.assertEqual(output.vs.shape.as_list()[-1], 42) 260 | 261 | def test_inconsistent_rank_inputs_for_importance_weights(self): 262 | """Test one of many possible errors in shape of inputs.""" 263 | placeholders = { 264 | 'log_rhos': tf.placeholder(dtype=tf.float32, shape=[None, None, 1]), 265 | 'discounts': tf.placeholder(dtype=tf.float32, shape=[None, None, 1]), 266 | 'rewards': tf.placeholder(dtype=tf.float32, shape=[None, None, 42]), 267 | 'values': tf.placeholder(dtype=tf.float32, shape=[None, None, 42]), 268 | # Should be [None, 42]. 269 | 'bootstrap_value': tf.placeholder(dtype=tf.float32, shape=[None]) 270 | } 271 | with self.assertRaisesRegexp(ValueError, 'must have rank 2'): 272 | vtrace.from_importance_weights(**placeholders) 273 | 274 | 275 | if __name__ == '__main__': 276 | tf.test.main() 277 | --------------------------------------------------------------------------------
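As a quick orientation for readers of the modules above, here is a minimal usage sketch of the V-trace API. It is not part of the repository: the shapes (T=3, B=2, NUM_ACTIONS=4), the random inputs, the clip thresholds and the plain tf.Session are illustrative assumptions chosen only to show how vtrace.from_logits is called and what it returns.

# Illustrative sketch only; values and shapes below are assumptions.
import numpy as np
import tensorflow as tf
import vtrace

T, B, NUM_ACTIONS = 3, 2, 4  # time steps, batch size, action count (assumed).

# Random stand-ins for data produced by actors following the behaviour policy.
behaviour_logits = np.random.randn(T, B, NUM_ACTIONS).astype(np.float32)
target_logits = np.random.randn(T, B, NUM_ACTIONS).astype(np.float32)
actions = np.random.randint(0, NUM_ACTIONS, size=(T, B)).astype(np.int32)
discounts = np.full((T, B), 0.99, dtype=np.float32)
rewards = np.random.randn(T, B).astype(np.float32)
values = np.random.randn(T, B).astype(np.float32)
bootstrap_value = np.random.randn(B).astype(np.float32)

# Builds the V-trace graph; inputs are converted to tensors internally.
returns = vtrace.from_logits(
    behaviour_policy_logits=behaviour_logits,
    target_policy_logits=target_logits,
    actions=actions,
    discounts=discounts,
    rewards=rewards,
    values=values,
    bootstrap_value=bootstrap_value,
    clip_rho_threshold=1.0,
    clip_pg_rho_threshold=1.0)

with tf.Session() as session:
  vs, pg_advantages = session.run([returns.vs, returns.pg_advantages])
  # vs are the [T, B] baseline targets; pg_advantages feed the policy gradient.
  print(vs.shape, pg_advantages.shape)  # (3, 2) (3, 2)

The same inputs, minus the logits, can be passed to vtrace.from_importance_weights with precomputed log_rhos instead, which is exactly what the tests above do when checking the two entry points against each other.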