├── .github
│   └── workflows
│       └── publish-to-pypi.yml
├── .gitignore
├── .gitmodules
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── RELEASE_NOTES.md
├── dm_env_rpc
│   ├── __init__.py
│   ├── _version.py
│   └── v1
│       ├── __init__.py
│       ├── async_connection.py
│       ├── async_connection_test.py
│       ├── compliance
│       │   ├── __init__.py
│       │   ├── create_destroy_world.py
│       │   ├── join_leave_world.py
│       │   ├── reset.py
│       │   ├── reset_world.py
│       │   └── step.py
│       ├── connection.py
│       ├── connection_test.py
│       ├── dm_env_adaptor.py
│       ├── dm_env_adaptor_test.py
│       ├── dm_env_flatten_utils.py
│       ├── dm_env_flatten_utils_test.py
│       ├── dm_env_rpc.proto
│       ├── dm_env_rpc_test.py
│       ├── dm_env_utils.py
│       ├── dm_env_utils_test.py
│       ├── error.py
│       ├── error_test.py
│       ├── extensions
│       │   ├── __init__.py
│       │   ├── properties.proto
│       │   ├── properties.py
│       │   └── properties_test.py
│       ├── message_utils.py
│       ├── message_utils_test.py
│       ├── spec_manager.py
│       ├── spec_manager_test.py
│       ├── tensor_spec_utils.py
│       ├── tensor_spec_utils_test.py
│       ├── tensor_utils.py
│       ├── tensor_utils_benchmark.py
│       └── tensor_utils_test.py
├── docs
│   └── v1
│       ├── 2x2.png
│       ├── 2x3.png
│       ├── appendix.md
│       ├── extensions
│       │   ├── index.md
│       │   └── properties.md
│       ├── glossary.md
│       ├── index.md
│       ├── overview.md
│       ├── reference.md
│       ├── single_agent_connect_and_step.png
│       ├── single_agent_sequence_transitions.png
│       ├── single_agent_world_destruction.png
│       ├── state_transitions.graphviz
│       └── state_transitions.png
├── examples
│   ├── catch_environment.py
│   ├── catch_human_agent.py
│   └── catch_test.py
├── requirements.txt
└── setup.py

/.github/workflows/publish-to-pypi.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created.
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 | 
4 | name: Upload Python Package
5 | 
6 | on:
7 |   release:
8 |     types: [created]
9 |   workflow_dispatch:
10 | 
11 | jobs:
12 |   deploy:
13 | 
14 |     runs-on: ubuntu-latest
15 | 
16 |     steps:
17 |     - uses: actions/checkout@v2
18 |     - name: Set up Python
19 |       uses: actions/setup-python@v2
20 |       with:
21 |         python-version: '3.7'
22 |     - name: Add repository sub-modules
23 |       run: |
24 |         git submodule init
25 |         git submodule update
26 |     - name: Install dependencies
27 |       run: |
28 |         python -m pip install --upgrade pip
29 |         pip install setuptools wheel twine
30 |     - name: Build and publish
31 |       env:
32 |         TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
33 |         TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
34 |       run: |
35 |         python setup.py sdist bdist_wheel
36 |         twine upload dist/*
37 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore byte-compiled Python code
2 | *.py[cod]
3 | 
4 | # Ignore protobuf bindings.
5 | **/*pb2*.py
6 | 
7 | # Ignore directories created during the build/installation process
8 | *.egg-info/
9 | .eggs/
10 | build/
11 | dist/
12 | 
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "third_party/api-common-protos"]
2 |     path = third_party/api-common-protos
3 |     url = https://github.com/googleapis/api-common-protos
4 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 | 
3 | We'd love to accept your patches and contributions to this project; however, this
4 | library is widely used in our research code, so we are unlikely to be able to
5 | accept breaking changes to the interface.
6 | 
7 | There are just a few small guidelines you need to follow.
8 | 
9 | ## Contributor License Agreement
10 | 
11 | Contributions to this project must be accompanied by a Contributor License
12 | Agreement. You (or your employer) retain the copyright to your contribution;
13 | this simply gives us permission to use and redistribute your contributions as
14 | part of the project. Head over to <https://cla.developers.google.com/> to see
15 | your current agreements on file or to sign a new one.
16 | 
17 | You generally only need to submit a CLA once, so if you've already submitted one
18 | (even if it was for a different project), you probably don't need to do it
19 | again.
20 | 
21 | ## Code reviews
22 | 
23 | All submissions, including submissions by project members, require review. We
24 | use GitHub pull requests for this purpose. Consult
25 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
26 | information on using pull requests.
27 | 
28 | ## Community Guidelines
29 | 
30 | This project follows
31 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
32 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | 
2 |                                  Apache License
3 |                            Version 2.0, January 2004
4 |                         http://www.apache.org/licenses/
5 | 
6 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 | 
8 |    1. Definitions.
9 | 
10 |       "License" shall mean the terms and conditions for use, reproduction,
11 |       and distribution as defined by Sections 1 through 9 of this document.
12 | 
13 |       "Licensor" shall mean the copyright owner or entity authorized by
14 |       the copyright owner that is granting the License.
15 | 
16 |       "Legal Entity" shall mean the union of the acting entity and all
17 |       other entities that control, are controlled by, or are under common
18 |       control with that entity. For the purposes of this definition,
19 |       "control" means (i) the power, direct or indirect, to cause the
20 |       direction or management of such entity, whether by contract or
21 |       otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 |       outstanding shares, or (iii) beneficial ownership of such entity.
23 | 
24 |       "You" (or "Your") shall mean an individual or Legal Entity
25 |       exercising permissions granted by this License.
26 | 
27 |       "Source" form shall mean the preferred form for making modifications,
28 |       including but not limited to software source code, documentation
29 |       source, and configuration files.
30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # `dm_env_rpc`: A networking protocol for agent-environment communication. 2 | 3 | ![PyPI Python version](https://img.shields.io/pypi/pyversions/dm-env-rpc) 4 | ![PyPI version](https://badge.fury.io/py/dm-env-rpc.svg) 5 | 6 | `dm_env_rpc` is a remote procedure call (RPC) protocol for communicating between 7 | machine learning agents and environments. It uses [gRPC](http://www.grpc.io) as 8 | the underlying communication framework, specifically its 9 | [bidirectional streaming](http://grpc.io/docs/guides/concepts/#bidirectional-streaming-rpc) 10 | RPC variant. 11 | 12 | This package also contains an implementation of 13 | [`dm_env`](http://www.github.com/deepmind/dm_env), a Python interface for 14 | interacting with such environments. 15 | 16 | Please see the documentation for more detailed information on the semantics of 17 | the protocol and how to use it. The examples sub-directory also provides 18 | examples of RL environments implemented using the `dm_env_rpc` protocol. 19 | 20 | ## Intended audience 21 | 22 | Games can make for interesting AI research platforms, for example as 23 | reinforcement learning (RL) environments. However, exposing a game as an RL 24 | environment can be a subtle, fraught process. We aim to provide a protocol that 25 | allows agents and environments to communicate in a standardized way, without 26 | specialized knowledge about how the other side works. Game developers can expose 27 | their games as environments with minimal domain knowledge and researchers can 28 | test their agents on a large library of different games. 29 | 30 | This protocol also removes the need for agents and environments to run in the 31 | same process or even on the same machine, allowing agents and environments to 32 | have very different technology stacks and requirements. 33 | 34 | ## Documentation 35 | 36 | * [Protocol overview](docs/v1/overview.md) 37 | * [Protocol reference](docs/v1/reference.md) 38 | * [Extensions](docs/v1/extensions/index.md) 39 | * [Appendix](docs/v1/appendix.md) 40 | * [Glossary](docs/v1/glossary.md) 41 | 42 | ## Installation 43 | 44 | Note: You may optionally wish to create a 45 | [Python Virtual Environment](https://docs.python.org/3/tutorial/venv.html) to 46 | prevent conflicts with your system's Python environment. 
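For example, a minimal way to create and activate one (assuming Python 3 and a POSIX shell; the environment name `dm_env_rpc_env` is only illustrative):

```bash
# Create an isolated environment and activate it for the current shell session.
$ python3 -m venv dm_env_rpc_env
$ source dm_env_rpc_env/bin/activate
```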
47 | 48 | `dm_env_rpc` can be installed from [PyPi](https://pypi.org/project/dm-env-rpc/) 49 | using `pip`: 50 | 51 | ```bash 52 | $ pip install dm-env-rpc 53 | ``` 54 | 55 | To also install the dependencies for the `examples/`, install with: 56 | 57 | ```bash 58 | $ pip install dm-env-rpc[examples] 59 | ``` 60 | 61 | Alternatively, you can install `dm_env_rpc` by cloning a local copy of our 62 | GitHub repository: 63 | 64 | ```bash 65 | $ git clone --recursive https://github.com/deepmind/dm_env_rpc.git 66 | $ pip install ./dm_env_rpc 67 | ``` 68 | 69 | ## Citing `dm_env_rpc` 70 | 71 | To cite this repository: 72 | 73 | ```bibtex 74 | @misc{dm_env_rpc2019, 75 | author = {Tom Ward and Jay Lemmon}, 76 | title = {dm\_env\_rpc: A networking protocol for agent-environment communication}, 77 | url = {http://github.com/deepmind/dm_env_rpc}, 78 | year = {2019}, 79 | } 80 | ``` 81 | 82 | ## Developer Notes 83 | 84 | In order to run the included unit tests, developers must also install additional 85 | dependencies defined in `requirements.txt`, at which point the tests should be 86 | runnable via `pytest`. For example: 87 | 88 | ```bash 89 | $ git clone --recursive https://github.com/deepmind/dm_env_rpc.git 90 | $ pip install -U pip wheel 91 | $ pip install -e ./dm_env_rpc 92 | $ pip install -r ./dm_env_rpc/requirements.txt 93 | $ pytest ./dm_env_rpc 94 | ``` 95 | 96 | ## Notice 97 | 98 | This is not an officially supported Google product 99 | -------------------------------------------------------------------------------- /RELEASE_NOTES.md: -------------------------------------------------------------------------------- 1 | # Release Notes 2 | 3 | ## [1.1.6] 4 | 5 | * New `AsyncConnection` class that allows users to make asynchronous requests 6 | via `asyncio`. 7 | * Add optional `metadata` argument to `create_secure_channel_and_connect`. 8 | * Expose error status details in `DmEnvRpcError` exception. 9 | * Improvements and fixes related to type hints, in particular hints relating 10 | to NumPy. 11 | * Add utility functions for packing/unpacking custom request/responses. 12 | * Remove support for Python 3.7, now that it has reached EOL. 13 | 14 | ## [1.1.5] 15 | 16 | * Add optional `strict` argument to dictionary flattening. 17 | * Relax flattening of create/join world settings, so that users can pass 18 | already flattened settings. 19 | 20 | ## [1.1.4] 21 | 22 | * Support for nested create/join world settings in `dm_env_adaptor` helper 23 | functions. 24 | * Support scalar min/max bounds for array actions in compliance tests. 25 | * Allow `DmEnvAdaptor` to be closable multiple times. 26 | * Cleaned up various type hints, specifically removing deprecated NumPy type 27 | aliases. 28 | * Bug fixes. 29 | 30 | ## [1.1.3] 31 | 32 | * Pass all additional keyword arguments through to `DmEnvAdaptor` when using 33 | create/join helper functions. 34 | * Bug fixes and cleanup. 35 | 36 | ## [1.1.2] 37 | 38 | * Fixed Catch human agent example, raised in this 39 | [GitHub issue](https://github.com/deepmind/dm_env_rpc/issues/2). 40 | * Fixed bug when attempting to pack an empty array as a particular dtype. 41 | * Minor cleanup. 42 | 43 | ## [1.1.1] 44 | 45 | * Removed support for Python 3.6 46 | * Updated compliance tests to support wider range of environments. 47 | * Fixed bug with packing large `np.uint64` scalars. 48 | * Various PyType fixes. 49 | 50 | ## [1.1.0] 51 | 52 | WARNING: This release removes support for some previously deprecated fields. 
53 | This may mean that scalar bounds for older environments are no longer readable. 54 | Users are advised to either revert to an older version, or re-build their 55 | environments to use the newer, multi-dimensional `TensorSpec.Value` fields. 56 | 57 | * Removed scalar `TensorSpec.Value` fields, which were marked as deprecated in 58 | [v1.0.1](#101). These have been superseded by array variants, which can be 59 | used for scalar bounds by creating a single element array. 60 | * Removed deprecated Property request/responses. These are now provided 61 | through the optional Property extension. 62 | * Refactored `Connection` to expose message packing utilities. 63 | 64 | ## [1.0.6] 65 | 66 | * `tensor_spec.bounds()` no longer broadcasts scalar bounds. 67 | * Fixed bug where `reward` and `discount` were inadvertently included in the 68 | observations when using `dm_env_adaptor`, without explicitly requesting 69 | these as observations. 70 | 71 | ## [1.0.5] 72 | 73 | * Better support for string specs in `dm_env_adaptor`. 74 | * Improved Python type annotations. 75 | * Check that the server has returned the correct response in 76 | `dm_env_rpc.Connection` class. 77 | * Added `create_world` helper function to `dm_env_adaptor`. 78 | 79 | ## [1.0.4] 80 | 81 | * Better support for variable sized tensors. 82 | * Support for packing/unpacking tensors that use `Any` protobuf messages. 83 | * Bug fixes. 84 | 85 | ## [1.0.3] 86 | 87 | ### Added 88 | 89 | * Support for property descriptions. 90 | * New utility functions for creating a Connection instance from a server 91 | address. 92 | * DmEnvAdaptor helper functions for creating and joining worlds. 93 | * Additional compliance tests for resetting. 94 | * Support for optional DmEnvAdaptor extensions. 95 | 96 | ### Changed 97 | 98 | * Removed portpicker dependency, instead relying on gRPC port picking 99 | functionality. 100 | * Changed property extension API to be more amenable to being used as an 101 | extension object for DmEnvAdaptor. 102 | 103 | ## [1.0.2] 104 | 105 | * Explicitly support nested tensors by the use of a period character in the 106 | `TensorSpec` name to indicate a level of nesting. Updated `dm_env` adaptor 107 | to flatten/unflattten actions and observations. 108 | * Increased minimum Python version to 3.6. 109 | * Moved property request/responses to its own extension. This supersedes the 110 | previous property requests, which have been marked as deprecated. **These 111 | requests will be removed in a future version of dm_env_rpc**. 112 | * Speed improvements for packing and unpacking byte arrays in Python. 113 | 114 | ## [1.0.1] 115 | 116 | ### Added 117 | 118 | * Support for per-element min/max values. This supersedes the existing scalar 119 | fields, which have been marked as deprecated. **These fields will be be 120 | removed in a future version of dm_env_rpc.** 121 | * Initial set of compliance tests that environment authors can use to better 122 | ensure their implementations adhere to the protocol specification. 123 | * Support for `dm_env` DiscreteArray specs. 124 | 125 | ### Changed 126 | 127 | * `dm_env_rpc` `EnvironmentResponse` errors in Python are now raised as a 128 | custom, `DmEnvRpcError` exception. 129 | 130 | ## [1.0.0] 131 | 132 | * Initial release. 133 | 134 | ## [1.0.0b2] 135 | 136 | * Updated minimum requirements for Python and protobuf. 
137 | 138 | ## [1.0.0b1] 139 | 140 | * Initial beta release 141 | -------------------------------------------------------------------------------- /dm_env_rpc/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """A networking protocol for agent-environment communication.""" 16 | from dm_env_rpc._version import __version__ 17 | -------------------------------------------------------------------------------- /dm_env_rpc/_version.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Package version for dm_env_rpc. 16 | 17 | Kept in separate file so it can be used during installation. 18 | """ 19 | 20 | __version__ = '1.1.6' # https://www.python.org/dev/peps/pep-0440/ 21 | -------------------------------------------------------------------------------- /dm_env_rpc/v1/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Initial version of dm_env_rpc networking protocol.""" 16 | -------------------------------------------------------------------------------- /dm_env_rpc/v1/async_connection.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """A helper class to manage a connection to a dm_env_rpc server asynchronously.
16 | 
17 | This helper class allows sending the Request message types and receiving the
18 | Response message as a future. The class automatically wraps and unwraps from
19 | EnvironmentRequest and EnvironmentResponse, respectively. It also turns error
20 | messages into exceptions.
21 | 
22 | For most calls (such as create, join, etc.):
23 | 
24 |   with async_connection.AsyncConnection(grpc_channel) as async_channel:
25 |     create_response = await async_channel.send(
26 |         dm_env_rpc_pb2.CreateWorldRequest(settings={
27 |             'players': 5
28 |         }))
29 | 
30 | For the `extension` message type, you must send an Any proto and you'll get back
31 | an Any proto. It is up to you to wrap and unwrap these to concrete proto types
32 | that you know how to handle.
33 | 
34 |   with async_connection.AsyncConnection(grpc_channel) as async_channel:
35 |     request = struct_pb2.Struct()
36 |     ...
37 |     request_any = any_pb2.Any()
38 |     request_any.Pack(request)
39 |     response_any = await async_channel.send(request_any)
40 |     response = my_type_pb2.MyType()
41 |     response_any.Unpack(response)
42 | 
43 | 
44 | Any errors encountered in the EnvironmentResponse are turned into Python
45 | exceptions, so explicit error handling code isn't needed per call.
46 | """
47 | 
48 | import asyncio
49 | from typing import Optional, Sequence, Tuple
50 | 
51 | import grpc
52 | 
53 | from dm_env_rpc.v1 import dm_env_rpc_pb2_grpc
54 | from dm_env_rpc.v1 import message_utils
55 | 
56 | 
57 | Metadata = Sequence[Tuple[str, str]]
58 | 
59 | 
60 | class AsyncConnection:
61 |   """A helper class for interacting with dm_env_rpc servers asynchronously."""
62 | 
63 |   def __init__(
64 |       self, channel: grpc.aio.Channel, metadata: Optional[Metadata] = None
65 |   ):
66 |     """Manages an async connection to a dm_env_rpc server.
67 | 
68 |     Args:
69 |       channel: An async grpc channel to connect to the dm_env_rpc server over.
70 |       metadata: Optional sequence of 2-tuples, sent to the gRPC server as
71 |         metadata.
72 |     """
73 |     self._stream = dm_env_rpc_pb2_grpc.EnvironmentStub(channel).Process(
74 |         metadata=metadata
75 |     )
76 | 
77 |   async def send(
78 |       self,
79 |       request: message_utils.DmEnvRpcRequest) -> message_utils.DmEnvRpcResponse:
80 |     """Sends `request` to the dm_env_rpc server and returns a response future.
81 | 
82 |     The request should be an instance of one of the dm_env_rpc Request messages,
83 |     such as CreateWorldRequest. Based on the type the correct payload for the
84 |     EnvironmentRequest will be constructed and sent to the dm_env_rpc server.
85 | 
86 |     Returns an awaitable future to retrieve the response.
87 | 
88 |     Args:
89 |       request: An instance of a dm_env_rpc Request type, such as
90 |         CreateWorldRequest.
91 | 92 | Returns: 93 | An asyncio Future which can be awaited to retrieve the response from the 94 | dm_env_rpc server returned for the given RPC call, unwrapped from the 95 | EnvironmentStream message. For instance if `request` had type 96 | `CreateWorldRequest` this returns a message of type `CreateWorldResponse`. 97 | 98 | Raises: 99 | DmEnvRpcError: The dm_env_rpc server responded to the request with an 100 | error. 101 | ValueError: The dm_env_rpc server responded to the request with an 102 | unexpected response message. 103 | """ 104 | environment_request, field_name = ( 105 | message_utils.pack_environment_request(request)) 106 | if self._stream is None: 107 | raise ValueError('Cannot send request after stream is closed.') 108 | await self._stream.write(environment_request) 109 | return message_utils.unpack_environment_response(await self._stream.read(), 110 | field_name) 111 | 112 | def close(self): 113 | """Closes the connection. Call when the connection is no longer needed.""" 114 | if self._stream: 115 | self._stream = None 116 | 117 | def __exit__(self, *args, **kwargs): 118 | self.close() 119 | 120 | def __enter__(self): 121 | return self 122 | 123 | 124 | async def create_secure_async_channel_and_connect( 125 | server_address: str, 126 | credentials: grpc.ChannelCredentials = grpc.local_channel_credentials(), 127 | metadata: Optional[Metadata] = None, 128 | ) -> AsyncConnection: 129 | """Creates a secure async channel from address and credentials and connects. 130 | 131 | We allow the created channel to have un-bounded message lengths, to support 132 | large observations. 133 | 134 | Args: 135 | server_address: URI server address to connect to. 136 | credentials: gRPC credentials necessary to connect to the server. 137 | metadata: Optional sequence of 2-tuples, sent to the gRPC server as 138 | metadata. 139 | 140 | Returns: 141 | An instance of dm_env_rpc.AsyncConnection, where the async channel is closed 142 | upon the connection being closed. 143 | """ 144 | options = [('grpc.max_send_message_length', -1), 145 | ('grpc.max_receive_message_length', -1)] 146 | channel = grpc.aio.secure_channel(server_address, credentials, 147 | options=options) 148 | await channel.channel_ready() 149 | 150 | class _ConnectionWrapper(AsyncConnection): 151 | """Utility to ensure channel is closed when the connection is closed.""" 152 | 153 | def __init__(self, channel, metadata): 154 | super().__init__(channel=channel, metadata=metadata) 155 | self._channel = channel 156 | 157 | def __del__(self): 158 | self.close() 159 | 160 | def close(self): 161 | super().close() 162 | try: 163 | loop = asyncio.get_running_loop() 164 | except RuntimeError: 165 | loop = None 166 | if loop and loop.is_running(): 167 | return asyncio.ensure_future(self._channel.close()) 168 | else: 169 | return asyncio.run(self._channel.close()) 170 | 171 | return _ConnectionWrapper(channel=channel, metadata=metadata) 172 | -------------------------------------------------------------------------------- /dm_env_rpc/v1/async_connection_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Tests for AsyncConnection.""" 16 | 17 | import asyncio 18 | import contextlib 19 | import queue 20 | import unittest 21 | from unittest import mock 22 | 23 | from absl.testing import absltest 24 | import grpc 25 | 26 | from google.protobuf import any_pb2 27 | from google.protobuf import struct_pb2 28 | from google.rpc import status_pb2 29 | from dm_env_rpc.v1 import async_connection 30 | from dm_env_rpc.v1 import dm_env_rpc_pb2 31 | from dm_env_rpc.v1 import dm_env_rpc_pb2_grpc 32 | from dm_env_rpc.v1 import error 33 | from dm_env_rpc.v1 import tensor_utils 34 | 35 | _CREATE_REQUEST = dm_env_rpc_pb2.CreateWorldRequest( 36 | settings={'foo': tensor_utils.pack_tensor('bar')}) 37 | _CREATE_RESPONSE = dm_env_rpc_pb2.CreateWorldResponse() 38 | 39 | _BAD_CREATE_REQUEST = dm_env_rpc_pb2.CreateWorldRequest() 40 | _TEST_ERROR = dm_env_rpc_pb2.EnvironmentResponse( 41 | error=status_pb2.Status(message='A test error.')) 42 | 43 | _INCORRECT_RESPONSE_TEST_MSG = dm_env_rpc_pb2.DestroyWorldRequest( 44 | world_name='foo') 45 | _INCORRECT_RESPONSE = dm_env_rpc_pb2.EnvironmentResponse( 46 | leave_world=dm_env_rpc_pb2.LeaveWorldResponse()) 47 | 48 | _EXTENSION_REQUEST = struct_pb2.Value(string_value='extension request') 49 | _EXTENSION_RESPONSE = struct_pb2.Value(number_value=555) 50 | 51 | 52 | def _wrap_in_any(proto): 53 | any_proto = any_pb2.Any() 54 | any_proto.Pack(proto) 55 | return any_proto 56 | 57 | 58 | _REQUEST_RESPONSE_PAIRS = { 59 | dm_env_rpc_pb2.EnvironmentRequest( 60 | create_world=_CREATE_REQUEST).SerializeToString(): 61 | dm_env_rpc_pb2.EnvironmentResponse(create_world=_CREATE_RESPONSE), 62 | dm_env_rpc_pb2.EnvironmentRequest( 63 | create_world=_BAD_CREATE_REQUEST).SerializeToString(): 64 | _TEST_ERROR, 65 | dm_env_rpc_pb2.EnvironmentRequest( 66 | extension=_wrap_in_any(_EXTENSION_REQUEST)).SerializeToString(): 67 | dm_env_rpc_pb2.EnvironmentResponse( 68 | extension=_wrap_in_any(_EXTENSION_RESPONSE)), 69 | dm_env_rpc_pb2.EnvironmentRequest( 70 | destroy_world=_INCORRECT_RESPONSE_TEST_MSG).SerializeToString(): 71 | _INCORRECT_RESPONSE, 72 | } 73 | 74 | 75 | def _process(metadata: async_connection.Metadata) -> grpc.aio.StreamStreamCall: 76 | requests = queue.Queue() 77 | 78 | async def _write(request): 79 | requests.put(request) 80 | 81 | async def _read(): 82 | request = requests.get() 83 | return _REQUEST_RESPONSE_PAIRS.get(request.SerializeToString(), _TEST_ERROR) 84 | 85 | mock_stream = mock.create_autospec(grpc.aio.StreamStreamCall) 86 | mock_stream.write = _write 87 | mock_stream.read = _read 88 | mock_stream.metadata = metadata 89 | return mock_stream 90 | 91 | 92 | @contextlib.contextmanager 93 | def _create_mock_async_channel(): 94 | """Mocks out gRPC and returns a channel to be passed to Connection.""" 95 | with mock.patch.object(async_connection, 'dm_env_rpc_pb2_grpc') as mock_grpc: 96 | mock_stub_class = mock.create_autospec(dm_env_rpc_pb2_grpc.EnvironmentStub) 97 | mock_stub_class.Process = _process 98 | mock_grpc.EnvironmentStub.return_value = 
mock_stub_class 99 | yield mock.MagicMock() 100 | 101 | 102 | class AsyncConnectionAsyncTests(unittest.IsolatedAsyncioTestCase): 103 | 104 | async def test_create(self): 105 | with _create_mock_async_channel() as mock_channel: 106 | with async_connection.AsyncConnection(mock_channel) as connection: 107 | response = await connection.send(_CREATE_REQUEST) 108 | self.assertEqual(_CREATE_RESPONSE, response) 109 | 110 | async def test_send_error_after_close(self): 111 | with _create_mock_async_channel() as mock_channel: 112 | with async_connection.AsyncConnection(mock_channel) as connection: 113 | connection.close() 114 | with self.assertRaisesRegex(ValueError, 'stream is closed'): 115 | await connection.send(_CREATE_REQUEST) 116 | 117 | async def test_error(self): 118 | with _create_mock_async_channel() as mock_channel: 119 | with async_connection.AsyncConnection(mock_channel) as connection: 120 | with self.assertRaisesRegex(error.DmEnvRpcError, 'test error'): 121 | await connection.send(_BAD_CREATE_REQUEST) 122 | 123 | async def test_extension(self): 124 | with _create_mock_async_channel() as mock_channel: 125 | with async_connection.AsyncConnection(mock_channel) as connection: 126 | request = any_pb2.Any() 127 | request.Pack(_EXTENSION_REQUEST) 128 | response = await connection.send(request) 129 | expected_response = any_pb2.Any() 130 | expected_response.Pack(_EXTENSION_RESPONSE) 131 | self.assertEqual(expected_response, response) 132 | 133 | async def test_incorrect_response(self): 134 | with _create_mock_async_channel() as mock_channel: 135 | with async_connection.AsyncConnection(mock_channel) as connection: 136 | with self.assertRaisesRegex(ValueError, 'Unexpected response message'): 137 | await connection.send(_INCORRECT_RESPONSE_TEST_MSG) 138 | 139 | def test_with_metadata(self): 140 | expected_metadata = (('key', 'value'),) 141 | with mock.patch.object(async_connection, 142 | 'dm_env_rpc_pb2_grpc') as mock_grpc: 143 | mock_stub_class = mock.MagicMock() 144 | mock_grpc.EnvironmentStub.return_value = mock_stub_class 145 | _ = async_connection.AsyncConnection( 146 | mock.MagicMock(), metadata=expected_metadata) 147 | mock_stub_class.Process.assert_called_with( 148 | metadata=expected_metadata) 149 | 150 | @mock.patch.object(grpc.aio, 'secure_channel') 151 | async def test_create_secure_channel_and_connect_context( 152 | self, mock_secure_channel): 153 | 154 | mock_async_channel = mock.MagicMock() 155 | mock_async_channel.channel_ready = absltest.mock.AsyncMock() 156 | mock_async_channel.close = absltest.mock.AsyncMock() 157 | mock_secure_channel.return_value = mock_async_channel 158 | 159 | with await async_connection.create_secure_async_channel_and_connect( 160 | 'valid_address') as connection: 161 | self.assertIsNotNone(connection) 162 | 163 | await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()}) 164 | 165 | mock_async_channel.close.assert_called_once() 166 | mock_async_channel.channel_ready.assert_called_once() 167 | mock_secure_channel.assert_called_once() 168 | 169 | 170 | class AsyncConnectionSyncTests(absltest.TestCase): 171 | 172 | @absltest.mock.patch.object(grpc.aio, 'secure_channel') 173 | def test_create_secure_channel_and_connect_context(self, mock_secure_channel): 174 | 175 | mock_async_channel = absltest.mock.MagicMock() 176 | mock_async_channel.channel_ready = absltest.mock.AsyncMock() 177 | mock_async_channel.close = absltest.mock.AsyncMock() 178 | mock_secure_channel.return_value = mock_async_channel 179 | 180 | 
asyncio.set_event_loop(asyncio.new_event_loop()) 181 | loop = asyncio.get_event_loop() 182 | 183 | connection_task = asyncio.ensure_future( 184 | async_connection.create_secure_async_channel_and_connect( 185 | 'valid_address')) 186 | connection = loop.run_until_complete(connection_task) 187 | 188 | loop.stop() 189 | asyncio.set_event_loop(None) 190 | 191 | connection.close() 192 | 193 | mock_async_channel.close.assert_called_once() 194 | mock_async_channel.channel_ready.assert_called_once() 195 | mock_secure_channel.assert_called_once() 196 | 197 | 198 | if __name__ == '__main__': 199 | absltest.main() 200 | -------------------------------------------------------------------------------- /dm_env_rpc/v1/compliance/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Compliance test base classes for dm_env_rpc.""" 16 | from dm_env_rpc.v1.compliance import create_destroy_world 17 | from dm_env_rpc.v1.compliance import join_leave_world 18 | from dm_env_rpc.v1.compliance import reset 19 | from dm_env_rpc.v1.compliance import reset_world 20 | from dm_env_rpc.v1.compliance import step 21 | 22 | CreateDestroyWorld = create_destroy_world.CreateDestroyWorld 23 | JoinLeaveWorld = join_leave_world.JoinLeaveWorld 24 | Reset = reset.Reset 25 | ResetWorld = reset_world.ResetWorld 26 | Step = step.Step 27 | -------------------------------------------------------------------------------- /dm_env_rpc/v1/compliance/create_destroy_world.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """A base class for CreateWorld and DestroyWorld tests for a server.""" 16 | import abc 17 | 18 | from absl.testing import absltest 19 | 20 | from dm_env_rpc.v1 import dm_env_rpc_pb2 21 | from dm_env_rpc.v1 import error 22 | 23 | 24 | class CreateDestroyWorld(absltest.TestCase, metaclass=abc.ABCMeta): 25 | """A base class for `CreateWorld` and `DestroyWorld` compliance tests.""" 26 | 27 | @abc.abstractproperty 28 | def required_world_settings(self): 29 | """A string to Tensor mapping of the minimum set of required settings.""" 30 | pass 31 | 32 | @abc.abstractproperty 33 | def invalid_world_settings(self): 34 | """World creation settings which are invalid in some way.""" 35 | pass 36 | 37 | @abc.abstractproperty 38 | def has_multiple_world_support(self): 39 | """Does the server support creating more than one world?""" 40 | pass 41 | 42 | @abc.abstractproperty 43 | def connection(self): 44 | """An instance of dm_env_rpc's Connection.""" 45 | pass 46 | 47 | def create_world(self, settings): 48 | """Returns the world name of the world created with the given settings.""" 49 | response = self.connection.send( 50 | dm_env_rpc_pb2.CreateWorldRequest(settings=settings)) 51 | return response.world_name 52 | 53 | def destroy_world(self, world_name): 54 | """Destroys the world named `world_name`.""" 55 | if world_name is not None: 56 | self.connection.send( 57 | dm_env_rpc_pb2.DestroyWorldRequest(world_name=world_name)) 58 | 59 | # pylint: disable=missing-docstring 60 | def test_can_create_and_destroy_world(self): 61 | # If this doesn't raise an exception the test passes. 62 | world_name = self.create_world(self.required_world_settings) 63 | self.destroy_world(world_name) 64 | 65 | def test_cannot_create_world_with_less_than_required_settings(self): 66 | settings = self.required_world_settings 67 | 68 | for name, _ in settings.items(): 69 | sans_setting = dict(settings) 70 | del sans_setting[name] 71 | message = f'world was created without required setting "{name}"' 72 | with self.assertRaises(error.DmEnvRpcError, msg=message): 73 | self.create_world(sans_setting) 74 | 75 | def test_cannot_create_world_with_invalid_settings(self): 76 | settings = self.required_world_settings 77 | invalid_settings = self.invalid_world_settings 78 | for name, tensor in invalid_settings.items(): 79 | message = f'world was created with invalid setting "{name}"' 80 | with self.assertRaises(error.DmEnvRpcError, msg=message): 81 | self.create_world({name: tensor, **settings}) 82 | 83 | def test_world_name_is_unique(self): 84 | if not self.has_multiple_world_support: 85 | return 86 | world1_name = None 87 | world2_name = None 88 | try: 89 | world1_name = self.create_world(self.required_world_settings) 90 | world2_name = self.create_world(self.required_world_settings) 91 | self.assertIsNotNone(world1_name) 92 | self.assertIsNotNone(world2_name) 93 | self.assertNotEqual(world1_name, world2_name) 94 | finally: 95 | self.destroy_world(world1_name) 96 | self.destroy_world(world2_name) 97 | 98 | def test_cannot_destroy_uncreated_world(self): 99 | with self.assertRaises(error.DmEnvRpcError): 100 | self.destroy_world('foo') 101 | # pylint: enable=missing-docstring 102 | -------------------------------------------------------------------------------- /dm_env_rpc/v1/compliance/join_leave_world.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 
All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """A base class for JoinWorld and LeaveWord tests for a server.""" 16 | import abc 17 | 18 | from absl.testing import absltest 19 | import numpy as np 20 | 21 | from dm_env_rpc.v1 import dm_env_rpc_pb2 22 | from dm_env_rpc.v1 import error 23 | from dm_env_rpc.v1 import tensor_spec_utils 24 | 25 | 26 | def _find_duplicates(iterable): 27 | """Returns a list of duplicate entries found in `iterable`.""" 28 | duplicates = [] 29 | seen = set() 30 | for item in iterable: 31 | if item in seen: 32 | duplicates.append(item) 33 | else: 34 | seen.add(item) 35 | return duplicates 36 | 37 | 38 | def _check_tensor_spec(tensor_spec): 39 | """Raises an error if the given `tensor_spec` is internally inconsistent.""" 40 | if np.sum(np.asarray(tensor_spec.shape) < 0) > 1: 41 | raise ValueError( 42 | f'"{tensor_spec.name}" has shape {tensor_spec.shape} which has more ' 43 | 'than one negative element.') 44 | min_type = tensor_spec.min and tensor_spec.min.WhichOneof('payload') 45 | max_type = tensor_spec.max and tensor_spec.max.WhichOneof('payload') 46 | if min_type or max_type: 47 | _ = tensor_spec_utils.bounds(tensor_spec) 48 | 49 | 50 | class JoinLeaveWorld(absltest.TestCase, metaclass=abc.ABCMeta): 51 | """A base class for `JoinWorld` and `LeaveWorld` compliance tests.""" 52 | 53 | @property 54 | def required_join_settings(self): 55 | """A dict of required settings for a Join World call.""" 56 | return {} 57 | 58 | @property 59 | def invalid_join_settings(self): 60 | """A list of dicts of Join World settings which are invalid in some way.""" 61 | return {} 62 | 63 | @abc.abstractproperty 64 | def world_name(self): 65 | """A string of the world name of an already created world.""" 66 | pass 67 | 68 | @property 69 | def invalid_world_name(self): 70 | """A string which doesn't correspond to any valid world_name.""" 71 | return 'invalid_world_name' 72 | 73 | @property 74 | @abc.abstractmethod 75 | def connection(self): 76 | """An instance of dm_env_rpc's Connection.""" 77 | pass 78 | 79 | def tearDown(self): 80 | super().tearDown() 81 | try: 82 | self.leave_world() 83 | finally: 84 | pass 85 | 86 | def join_world(self, **kwargs): 87 | """Joins the world and returns the spec.""" 88 | response = self.connection.send(dm_env_rpc_pb2.JoinWorldRequest(**kwargs)) 89 | return response.specs 90 | 91 | def leave_world(self): 92 | """Leaves currently joined world, if any.""" 93 | self.connection.send(dm_env_rpc_pb2.LeaveWorldRequest()) 94 | 95 | # pylint: disable=missing-docstring 96 | def test_can_join(self): 97 | self.join_world( 98 | world_name=self.world_name, settings=self.required_join_settings) 99 | # Success if there's no error raised. 
100 | 101 | def test_cannot_join_with_wrong_world_name(self): 102 | with self.assertRaises(error.DmEnvRpcError): 103 | self.join_world(world_name=self.invalid_world_name) 104 | 105 | def test_cannot_join_world_with_invalid_settings(self): 106 | settings = self.required_join_settings 107 | for name, tensor in self.invalid_join_settings.items(): 108 | with self.assertRaises(error.DmEnvRpcError): 109 | self.join_world( 110 | world_name=self.world_name, settings={ 111 | name: tensor, 112 | **settings 113 | }) 114 | 115 | def test_cannot_join_world_twice(self): 116 | self.join_world( 117 | world_name=self.world_name, settings=self.required_join_settings) 118 | with self.assertRaises(error.DmEnvRpcError): 119 | self.join_world( 120 | world_name=self.world_name, settings=self.required_join_settings) 121 | 122 | def test_action_specs_have_unique_names(self): 123 | specs = self.join_world( 124 | world_name=self.world_name, settings=self.required_join_settings) 125 | self.assertEmpty(_find_duplicates( 126 | spec.name for spec in specs.actions.values())) 127 | 128 | def test_action_specs_for_consistency(self): 129 | specs = self.join_world( 130 | world_name=self.world_name, settings=self.required_join_settings) 131 | for action_spec in specs.actions.values(): 132 | _check_tensor_spec(action_spec) 133 | 134 | def test_observation_specs_have_unique_names(self): 135 | specs = self.join_world( 136 | world_name=self.world_name, settings=self.required_join_settings) 137 | self.assertEmpty(_find_duplicates( 138 | spec.name for spec in specs.observations.values())) 139 | 140 | def test_observation_specs_for_consistency(self): 141 | specs = self.join_world( 142 | world_name=self.world_name, settings=self.required_join_settings) 143 | for observation_spec in specs.observations.values(): 144 | _check_tensor_spec(observation_spec) 145 | 146 | def test_can_leave_world_if_not_joined(self): 147 | self.leave_world() 148 | # Success if there's no error raised. 149 | 150 | def test_can_leave_world_after_joining(self): 151 | self.join_world( 152 | world_name=self.world_name, settings=self.required_join_settings) 153 | self.leave_world() 154 | # Success if there's no error raised. 155 | 156 | def test_can_rejoin_world_after_leaving(self): 157 | self.join_world( 158 | world_name=self.world_name, settings=self.required_join_settings) 159 | self.leave_world() 160 | self.join_world( 161 | world_name=self.world_name, settings=self.required_join_settings) 162 | # Success if there's no error raised. 163 | 164 | # pylint: enable=missing-docstring 165 | -------------------------------------------------------------------------------- /dm_env_rpc/v1/compliance/reset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """A base class for Reset tests for a server.""" 16 | 17 | import abc 18 | 19 | from absl.testing import absltest 20 | 21 | from dm_env_rpc.v1 import dm_env_rpc_pb2 22 | from dm_env_rpc.v1 import error 23 | 24 | 25 | class Reset(absltest.TestCase, metaclass=abc.ABCMeta): 26 | """A base class for dm_env_rpc `Reset` compliance tests.""" 27 | 28 | @property 29 | @abc.abstractmethod 30 | def connection(self): 31 | """An instance of dm_env_rpc's Connection already joined to a world.""" 32 | pass 33 | 34 | @property 35 | def required_reset_settings(self): 36 | return {} 37 | 38 | @abc.abstractmethod 39 | def join_world(self): 40 | """Joins a world, returning the specs.""" 41 | pass 42 | 43 | def reset(self): 44 | """Resets the environment, returning the specs.""" 45 | return self.connection.send(dm_env_rpc_pb2.ResetRequest( 46 | settings=self.required_reset_settings)).specs 47 | 48 | # pylint: disable=missing-docstring 49 | def test_reset_resends_the_specs(self): 50 | join_specs = self.join_world() 51 | specs = self.reset() 52 | self.assertEqual(join_specs, specs) 53 | 54 | def test_cannot_reset_if_not_joined_to_world(self): 55 | with self.assertRaises(error.DmEnvRpcError): 56 | self.reset() 57 | 58 | def test_can_reset_multiple_times(self): 59 | join_specs = self.join_world() 60 | self.reset() 61 | specs = self.reset() 62 | self.assertEqual(join_specs, specs) 63 | # pylint: enable=missing-docstring 64 | -------------------------------------------------------------------------------- /dm_env_rpc/v1/compliance/reset_world.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """A base class for ResetWorld tests for a server.""" 16 | 17 | import abc 18 | 19 | from absl.testing import absltest 20 | 21 | from dm_env_rpc.v1 import dm_env_rpc_pb2 22 | from dm_env_rpc.v1 import error 23 | 24 | 25 | class ResetWorld(absltest.TestCase, metaclass=abc.ABCMeta): 26 | """A base class for dm_env_rpc `ResetWorld` compliance tests.""" 27 | 28 | @property 29 | @abc.abstractmethod 30 | def connection(self): 31 | """An instance of dm_env_rpc's Connection already joined to a world.""" 32 | pass 33 | 34 | @property 35 | def required_reset_world_settings(self): 36 | """Settings necessary to pass to ResetWorld.""" 37 | return {} 38 | 39 | @property 40 | def required_join_world_settings(self): 41 | """Settings necessary to pass to JoinWorld.""" 42 | return {} 43 | 44 | @property 45 | def invalid_world_name(self): 46 | """The name of a world which doesn't exist.""" 47 | return 'invalid_world_name' 48 | 49 | @property 50 | @abc.abstractmethod 51 | def world_name(self): 52 | """The name of the world to attempt to call ResetWorld on.""" 53 | return '' 54 | 55 | def join_world(self): 56 | """Joins the world to call ResetWorld on.""" 57 | self.connection.send(dm_env_rpc_pb2.JoinWorldRequest( 58 | world_name=self.world_name, settings=self.required_join_world_settings)) 59 | 60 | def reset_world(self, world_name): 61 | """Resets the world.""" 62 | self.connection.send(dm_env_rpc_pb2.ResetWorldRequest( 63 | world_name=world_name, settings=self.required_reset_world_settings)) 64 | 65 | def leave_world(self): 66 | """Leaves the world.""" 67 | self.connection.send(dm_env_rpc_pb2.LeaveWorldRequest()) 68 | 69 | # pylint: disable=missing-docstring 70 | def test_cannot_reset_invalid_world(self): 71 | with self.assertRaises(error.DmEnvRpcError): 72 | self.reset_world(self.invalid_world_name) 73 | 74 | def test_can_reset_world_not_joined_to(self): 75 | self.reset_world(self.world_name) 76 | # If there are no errors the test passes. 77 | 78 | def test_can_reset_world_when_joined_to_it(self): 79 | try: 80 | self.join_world() 81 | self.reset_world(self.world_name) 82 | # If there are no errors the test passes. 83 | finally: 84 | self.leave_world() 85 | # pylint: enable=missing-docstring 86 | -------------------------------------------------------------------------------- /dm_env_rpc/v1/connection.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """A helper class to manage a connection to a dm_env_rpc server. 16 | 17 | This helper class allows sending the Request message types and receiving the 18 | Response message types without wrapping in an EnvironmentRequest or unwrapping 19 | from an EnvironmentResponse. 
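The `ResetWorld` compliance tests above rely on the sequence-interruption behaviour spelled out in dm_env_rpc.proto: after another agent resets the world, a joined agent sees INTERRUPTED from its next step. Below is a small agent-side sketch of honouring those state transitions; the `channel` and `requested_observations` arguments are assumptions, while the messages and enum come from the proto.

from dm_env_rpc.v1 import dm_env_rpc_pb2


def run_sequence(channel, requested_observations):
  """Steps until the current sequence ends, returning the final state."""
  while True:
    response = channel.send(dm_env_rpc_pb2.StepRequest(
        requested_observations=requested_observations))
    if response.state != dm_env_rpc_pb2.EnvironmentStateType.RUNNING:
      # TERMINATED means the sequence ended naturally; INTERRUPTED typically
      # means a ResetRequest or ResetWorldRequest cut it short. Either way,
      # the next StepRequest starts a new sequence and its actions are ignored.
      return response.state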
It also turns error messages in to exceptions. 20 | 21 | For most calls (such as create, join, etc.): 22 | 23 | with connection.Connection(grpc_channel) as channel: 24 | create_response = channel.send(dm_env_rpc_pb2.CreateRequest(settings={ 25 | 'players': 5 26 | }) 27 | 28 | For the `extension` message type, you must send an Any proto and you'll get back 29 | an Any proto. It is up to you to wrap and unwrap these to concrete proto types 30 | that you know how to handle. 31 | 32 | with connection.Connection(grpc_channel) as channel: 33 | request = struct_pb2.Struct() 34 | ... 35 | request_any = any_pb2.Any() 36 | request_any.Pack(request) 37 | response_any = channel.send(request_any) 38 | response = my_type_pb2.MyType() 39 | response_any.Unpack(response) 40 | 41 | 42 | Any errors encountered in the EnvironmentResponse are turned into Python 43 | exceptions, so explicit error handling code isn't needed per call. 44 | """ 45 | 46 | import queue 47 | from typing import Optional, Protocol, Sequence, Tuple 48 | import grpc 49 | 50 | from dm_env_rpc.v1 import dm_env_rpc_pb2 51 | from dm_env_rpc.v1 import dm_env_rpc_pb2_grpc 52 | from dm_env_rpc.v1 import message_utils 53 | 54 | Metadata = Sequence[Tuple[str, str]] 55 | 56 | 57 | class ConnectionType(Protocol): 58 | """Connection protocol definition for interacting with dm_env_rpc servers.""" 59 | 60 | def send( 61 | self, 62 | request: message_utils.DmEnvRpcRequest) -> message_utils.DmEnvRpcResponse: 63 | """Blocking call to send and receive a message from a dm_env_rpc server.""" 64 | 65 | def close(self): 66 | """Closes the connection. Call when the connection is no longer needed.""" 67 | 68 | 69 | class StreamReaderWriter(object): 70 | """Helper class for reading/writing gRPC streams.""" 71 | 72 | def __init__(self, 73 | stub: dm_env_rpc_pb2_grpc.EnvironmentStub, 74 | metadata: Optional[Metadata] = None): 75 | self._requests = queue.Queue() 76 | self._stream = stub.Process( 77 | iter(self._requests.get, None), metadata=metadata) 78 | 79 | def write(self, request: dm_env_rpc_pb2.EnvironmentRequest): 80 | """Asynchronously sends `request` to the stream.""" 81 | self._requests.put(request) 82 | 83 | def read(self) -> dm_env_rpc_pb2.EnvironmentResponse: 84 | """Returns the response from stream. Blocking.""" 85 | return next(self._stream) 86 | 87 | 88 | class Connection(object): 89 | """A helper class for interacting with dm_env_rpc servers.""" 90 | 91 | def __init__(self, 92 | channel: grpc.Channel, 93 | metadata: Optional[Metadata] = None): 94 | """Manages a connection to a dm_env_rpc server. 95 | 96 | Args: 97 | channel: A grpc channel to connect to the dm_env_rpc server over. 98 | metadata: Optional sequence of 2-tuples, sent to the gRPC server as 99 | metadata. 100 | """ 101 | self._stream = StreamReaderWriter( 102 | dm_env_rpc_pb2_grpc.EnvironmentStub(channel), metadata) 103 | 104 | def send( 105 | self, 106 | request: message_utils.DmEnvRpcRequest) -> message_utils.DmEnvRpcResponse: 107 | """Sends the given request to the dm_env_rpc server and returns the response. 108 | 109 | The request should be an instance of one of the dm_env_rpc Request messages, 110 | such as CreateWorldRequest. Based on the type the correct payload for the 111 | EnvironmentRequest will be constructed and sent to the dm_env_rpc server. 112 | 113 | Blocks until the server sends back its response. 114 | 115 | Args: 116 | request: An instance of a dm_env_rpc Request type, such as 117 | CreateWorldRequest. 
118 | 119 | Returns: 120 | The response the dm_env_rpc server returned for the given RPC call, 121 | unwrapped from the EnvironmentStream message. For instance if `request` 122 | had type `CreateWorldRequest` this returns a message of type 123 | `CreateWorldResponse`. 124 | 125 | Raises: 126 | DmEnvRpcError: The dm_env_rpc server responded to the request with an 127 | error. 128 | ValueError: The dm_env_rpc server responded to the request with an 129 | unexpected response message. 130 | """ 131 | environment_request, field_name = ( 132 | message_utils.pack_environment_request(request)) 133 | if self._stream is None: 134 | raise ValueError('Cannot send request after stream is closed.') 135 | self._stream.write(environment_request) 136 | return message_utils.unpack_environment_response(self._stream.read(), 137 | field_name) 138 | 139 | def close(self): 140 | """Closes the connection. Call when the connection is no longer needed.""" 141 | if self._stream: 142 | self._stream = None 143 | 144 | def __exit__(self, *args, **kwargs): 145 | self.close() 146 | 147 | def __enter__(self): 148 | return self 149 | 150 | 151 | def create_secure_channel_and_connect( 152 | server_address: str, 153 | credentials: grpc.ChannelCredentials = grpc.local_channel_credentials(), 154 | timeout: Optional[float] = None, 155 | metadata: Optional[Metadata] = None, 156 | ) -> Connection: 157 | """Creates a secure channel from server address and credentials and connects. 158 | 159 | We allow the created channel to have un-bounded message lengths, to support 160 | large observations. 161 | 162 | Args: 163 | server_address: URI server address to connect to. 164 | credentials: gRPC credentials necessary to connect to the server. 165 | timeout: Optional timeout in seconds to wait for channel to be ready. 166 | Default to waiting indefinitely. 167 | metadata: Optional sequence of 2-tuples, sent to the gRPC server as 168 | metadata. 169 | 170 | Returns: 171 | An instance of dm_env_rpc.Connection, where the channel is close upon the 172 | connection being closed. 173 | """ 174 | options = [('grpc.max_send_message_length', -1), 175 | ('grpc.max_receive_message_length', -1)] 176 | channel = grpc.secure_channel(server_address, credentials, options=options) 177 | grpc.channel_ready_future(channel).result(timeout) 178 | 179 | class _ConnectionWrapper(Connection): 180 | """Utility to ensure channel is closed when the connection is closed.""" 181 | 182 | def __init__(self, channel, metadata): 183 | super().__init__(channel=channel, metadata=metadata) 184 | self._channel = channel 185 | 186 | def __del__(self): 187 | self.close() 188 | 189 | def close(self): 190 | super().close() 191 | self._channel.close() 192 | 193 | return _ConnectionWrapper(channel=channel, metadata=metadata) 194 | -------------------------------------------------------------------------------- /dm_env_rpc/v1/connection_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
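Putting the `Connection` helper above together with the request messages defined in dm_env_rpc.proto, a typical client drives the whole world lifecycle over one connection. In this sketch the address, the 'seed' setting and the decision to request every declared observation are illustrative assumptions.

from dm_env_rpc.v1 import connection as dm_env_rpc_connection
from dm_env_rpc.v1 import dm_env_rpc_pb2
from dm_env_rpc.v1 import tensor_utils

with dm_env_rpc_connection.create_secure_channel_and_connect(
    'localhost:9999') as channel:
  world_name = channel.send(dm_env_rpc_pb2.CreateWorldRequest(
      settings={'seed': tensor_utils.pack_tensor(42)})).world_name
  specs = channel.send(dm_env_rpc_pb2.JoinWorldRequest(
      world_name=world_name)).specs

  # First step after joining: actions are ignored, so request observations only.
  step = channel.send(dm_env_rpc_pb2.StepRequest(
      requested_observations=specs.observations.keys()))
  print(step.state, len(step.observations))

  channel.send(dm_env_rpc_pb2.LeaveWorldRequest())
  channel.send(dm_env_rpc_pb2.DestroyWorldRequest(world_name=world_name))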
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Tests for Connection.""" 16 | 17 | import contextlib 18 | from unittest import mock 19 | 20 | from absl.testing import absltest 21 | import grpc 22 | 23 | from google.protobuf import any_pb2 24 | from google.protobuf import struct_pb2 25 | from google.rpc import status_pb2 26 | from dm_env_rpc.v1 import connection as dm_env_rpc_connection 27 | from dm_env_rpc.v1 import dm_env_rpc_pb2 28 | from dm_env_rpc.v1 import error 29 | from dm_env_rpc.v1 import tensor_utils 30 | 31 | _CREATE_REQUEST = dm_env_rpc_pb2.CreateWorldRequest( 32 | settings={'foo': tensor_utils.pack_tensor('bar')}) 33 | _CREATE_RESPONSE = dm_env_rpc_pb2.CreateWorldResponse() 34 | 35 | _BAD_CREATE_REQUEST = dm_env_rpc_pb2.CreateWorldRequest() 36 | _TEST_ERROR = dm_env_rpc_pb2.EnvironmentResponse( 37 | error=status_pb2.Status(message='A test error.')) 38 | 39 | _INCORRECT_RESPONSE_TEST_MSG = dm_env_rpc_pb2.DestroyWorldRequest( 40 | world_name='foo') 41 | _INCORRECT_RESPONSE = dm_env_rpc_pb2.EnvironmentResponse( 42 | leave_world=dm_env_rpc_pb2.LeaveWorldResponse()) 43 | 44 | _EXTENSION_REQUEST = struct_pb2.Value(string_value='extension request') 45 | _EXTENSION_RESPONSE = struct_pb2.Value(number_value=555) 46 | 47 | 48 | def _wrap_in_any(proto): 49 | any_proto = any_pb2.Any() 50 | any_proto.Pack(proto) 51 | return any_proto 52 | 53 | 54 | _REQUEST_RESPONSE_PAIRS = { 55 | dm_env_rpc_pb2.EnvironmentRequest( 56 | create_world=_CREATE_REQUEST).SerializeToString(): 57 | dm_env_rpc_pb2.EnvironmentResponse(create_world=_CREATE_RESPONSE), 58 | dm_env_rpc_pb2.EnvironmentRequest( 59 | create_world=_BAD_CREATE_REQUEST).SerializeToString(): 60 | _TEST_ERROR, 61 | dm_env_rpc_pb2.EnvironmentRequest( 62 | extension=_wrap_in_any(_EXTENSION_REQUEST)).SerializeToString(): 63 | dm_env_rpc_pb2.EnvironmentResponse( 64 | extension=_wrap_in_any(_EXTENSION_RESPONSE)), 65 | dm_env_rpc_pb2.EnvironmentRequest( 66 | destroy_world=_INCORRECT_RESPONSE_TEST_MSG).SerializeToString(): 67 | _INCORRECT_RESPONSE, 68 | } 69 | 70 | 71 | def _process(request_iterator, **kwargs): 72 | del kwargs 73 | for request in request_iterator: 74 | yield _REQUEST_RESPONSE_PAIRS.get(request.SerializeToString(), _TEST_ERROR) 75 | 76 | 77 | @contextlib.contextmanager 78 | def _create_mock_channel(): 79 | """Mocks out gRPC and returns a channel to be passed to Connection.""" 80 | with mock.patch.object(dm_env_rpc_connection, 81 | 'dm_env_rpc_pb2_grpc') as mock_grpc: 82 | mock_stub_class = mock.MagicMock() 83 | mock_stub_class.Process = _process 84 | mock_grpc.EnvironmentStub.return_value = mock_stub_class 85 | yield mock.MagicMock() 86 | 87 | 88 | class ConnectionTests(absltest.TestCase): 89 | 90 | def test_create(self): 91 | with _create_mock_channel() as mock_channel: 92 | with dm_env_rpc_connection.Connection(mock_channel) as connection: 93 | response = connection.send(_CREATE_REQUEST) 94 | self.assertEqual(_CREATE_RESPONSE, response) 95 | 96 | def test_send_error_after_close(self): 97 | with _create_mock_channel() as mock_channel: 98 | with dm_env_rpc_connection.Connection(mock_channel) as connection: 99 | connection.close() 100 | with self.assertRaisesRegex(ValueError, 'stream is closed'): 101 | connection.send(_CREATE_REQUEST) 102 | 103 | def test_error(self): 104 | with _create_mock_channel() as mock_channel: 105 | with dm_env_rpc_connection.Connection(mock_channel) as 
connection: 106 | with self.assertRaisesRegex(error.DmEnvRpcError, 'test error'): 107 | connection.send(_BAD_CREATE_REQUEST) 108 | 109 | def test_extension(self): 110 | with _create_mock_channel() as mock_channel: 111 | with dm_env_rpc_connection.Connection(mock_channel) as connection: 112 | request = any_pb2.Any() 113 | request.Pack(_EXTENSION_REQUEST) 114 | response = connection.send(request) 115 | expected_response = any_pb2.Any() 116 | expected_response.Pack(_EXTENSION_RESPONSE) 117 | self.assertEqual(expected_response, response) 118 | 119 | @mock.patch.object(grpc, 'secure_channel') 120 | @mock.patch.object(grpc, 'channel_ready_future') 121 | def test_create_secure_channel_and_connect(self, mock_channel_ready, 122 | mock_secure_channel): 123 | mock_channel = mock.MagicMock() 124 | mock_secure_channel.return_value = mock_channel 125 | 126 | self.assertIsNotNone( 127 | dm_env_rpc_connection.create_secure_channel_and_connect( 128 | 'valid_address', grpc.local_channel_credentials())) 129 | 130 | mock_channel_ready.assert_called_once_with(mock_channel) 131 | mock_secure_channel.assert_called_once() 132 | mock_channel.close.assert_called_once() 133 | 134 | @mock.patch.object(grpc, 'secure_channel') 135 | @mock.patch.object(grpc, 'channel_ready_future') 136 | def test_create_secure_channel_and_connect_context(self, mock_channel_ready, 137 | mock_secure_channel): 138 | mock_channel = mock.MagicMock() 139 | mock_secure_channel.return_value = mock_channel 140 | 141 | with dm_env_rpc_connection.create_secure_channel_and_connect( 142 | 'valid_address') as connection: 143 | self.assertIsNotNone(connection) 144 | 145 | mock_channel_ready.assert_called_once_with(mock_channel) 146 | mock_secure_channel.assert_called_once() 147 | mock_channel.close.assert_called_once() 148 | 149 | @mock.patch.object(grpc, 'secure_channel') 150 | @mock.patch.object(grpc, 'channel_ready_future') 151 | @mock.patch.object(dm_env_rpc_connection, 'StreamReaderWriter') 152 | def test_create_secure_channel_and_connect_metadata( 153 | self, mock_stream_writer, mock_channel_ready, mock_secure_channel 154 | ): 155 | mock_channel = mock.MagicMock() 156 | mock_secure_channel.return_value = mock_channel 157 | metadata = [('upstream', 'fake_address')] 158 | with dm_env_rpc_connection.create_secure_channel_and_connect( 159 | 'valid_address', metadata=metadata 160 | ) as connection: 161 | self.assertIsNotNone(connection) 162 | 163 | mock_channel_ready.assert_called_once_with(mock_channel) 164 | mock_secure_channel.assert_called_once() 165 | mock_channel.close.assert_called_once() 166 | mock_stream_writer.assert_called_once_with(mock.ANY, metadata) 167 | 168 | def test_create_secure_channel_and_connect_timeout(self): 169 | with self.assertRaises(grpc.FutureTimeoutError): 170 | dm_env_rpc_connection.create_secure_channel_and_connect( 171 | 'invalid_address', grpc.local_channel_credentials(), timeout=1.) 
172 | 173 | def test_incorrect_response(self): 174 | with _create_mock_channel() as mock_channel: 175 | with dm_env_rpc_connection.Connection(mock_channel) as connection: 176 | with self.assertRaisesRegex(ValueError, 'Unexpected response message'): 177 | connection.send(_INCORRECT_RESPONSE_TEST_MSG) 178 | 179 | def test_with_metadata(self): 180 | expected_metadata = (('key', 'value'),) 181 | with mock.patch.object(dm_env_rpc_connection, 182 | 'dm_env_rpc_pb2_grpc') as mock_grpc: 183 | mock_stub_class = mock.MagicMock() 184 | mock_grpc.EnvironmentStub.return_value = mock_stub_class 185 | _ = dm_env_rpc_connection.Connection( 186 | mock.MagicMock(), metadata=expected_metadata) 187 | mock_stub_class.Process.assert_called_with( 188 | mock.ANY, metadata=expected_metadata) 189 | 190 | 191 | if __name__ == '__main__': 192 | absltest.main() 193 | -------------------------------------------------------------------------------- /dm_env_rpc/v1/dm_env_flatten_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Python utilities for flattening and unflattening key-value mappings.""" 16 | 17 | from typing import Any, Dict, Mapping 18 | 19 | 20 | def flatten_dict( 21 | input_dict: Mapping[str, Any], 22 | separator: str, 23 | *, 24 | strict: bool = True, 25 | ) -> Dict[str, Any]: 26 | """Flattens mappings by joining sub-keys using the provided separator. 27 | 28 | Only non-empty, mapping types will be flattened. All other types are deemed 29 | leaf values. 30 | 31 | Args: 32 | input_dict: Mapping of key-value pairs to flatten. 33 | separator: Delimiter used to concatenate keys. 34 | strict: Whether to permit keys that already contain the separator. Setting 35 | this to False will cause unflattening to be ambiguous. 36 | 37 | Returns: 38 | Flattened dictionary of key-value pairs. 39 | 40 | Raises: 41 | ValueError: If the `input_dict` has a key that contains the separator 42 | string, and strict is set to False. 43 | """ 44 | result: Dict[str, Any] = {} 45 | for key, value in input_dict.items(): 46 | if strict and separator in key: 47 | raise ValueError( 48 | f"Can not safely flatten dictionary: key '{key}' already contains " 49 | f"the separator '{separator}'!" 50 | ) 51 | if isinstance(value, Mapping) and len(value): 52 | result.update({ 53 | f'{key}{separator}{sub_key}': sub_value 54 | for sub_key, sub_value in flatten_dict( 55 | value, separator, strict=strict).items() 56 | }) 57 | else: 58 | result[key] = value 59 | return result 60 | 61 | 62 | def unflatten_dict(input_dict: Mapping[str, Any], 63 | separator: str) -> Dict[str, Any]: 64 | """Unflatten dictionary using split keys to determine the structure. 
65 | 66 | For each key, split based on the provided separator and create a nested 67 | dictionary entry for each sub-key. 68 | 69 | Args: 70 | input_dict: Mapping of key-value pairs to un-flatten. 71 | separator: Delimiter used to split keys. 72 | 73 | Returns: 74 | Unflattened dictionary. 75 | 76 | Raises: 77 | ValueError: If a key, or its constituent sub-keys, already has a value. For 78 | instance, unflattening `{"foo": True, "foo.bar": "baz"}` will result in 79 | "foo" being set to both a dict and a Bool. 80 | """ 81 | result: Dict[str, Any] = {} 82 | for key, value in input_dict.items(): 83 | sub_keys = key.split(separator) 84 | sub_tree = result 85 | for sub_key in sub_keys[:-1]: 86 | sub_tree = sub_tree.setdefault(sub_key, {}) 87 | if not isinstance(sub_tree, Mapping): 88 | raise ValueError(f"Sub-tree '{sub_key}' has already been assigned a " 89 | f"leaf value {sub_tree}") 90 | 91 | if sub_keys[-1] in sub_tree: 92 | raise ValueError(f'Duplicate key {key}') 93 | sub_tree[sub_keys[-1]] = value 94 | return result 95 | -------------------------------------------------------------------------------- /dm_env_rpc/v1/dm_env_flatten_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
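A short round trip through the two flattening helpers above, for example turning a nested settings mapping into flat dotted keys and back (the example values are arbitrary):

from dm_env_rpc.v1 import dm_env_flatten_utils

nested = {'agent': {'name': 'observer', 'camera': {'width': 96}}}
flat = dm_env_flatten_utils.flatten_dict(nested, '.')
# flat == {'agent.name': 'observer', 'agent.camera.width': 96}
assert dm_env_flatten_utils.unflatten_dict(flat, '.') == nested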
14 | # ============================================================================ 15 | """Tests for dm_env_flatten_utils.""" 16 | 17 | from absl.testing import absltest 18 | from dm_env_rpc.v1 import dm_env_flatten_utils 19 | 20 | 21 | class FlattenUtilsTest(absltest.TestCase): 22 | 23 | def test_flatten(self): 24 | input_dict = { 25 | 'foo': { 26 | 'bar': 1, 27 | 'baz': False 28 | }, 29 | 'fiz': object(), 30 | } 31 | expected = { 32 | 'foo.bar': 1, 33 | 'foo.baz': False, 34 | 'fiz': object(), 35 | } 36 | self.assertSameElements(expected, 37 | dm_env_flatten_utils.flatten_dict(input_dict, '.')) 38 | 39 | def test_unflatten(self): 40 | input_dict = { 41 | 'foo.bar.baz': True, 42 | 'fiz.buz': 1, 43 | 'foo.baz': 'val', 44 | 'buz': {} 45 | } 46 | expected = { 47 | 'foo': { 48 | 'bar': { 49 | 'baz': True 50 | }, 51 | 'baz': 'val' 52 | }, 53 | 'fiz': { 54 | 'buz': 1 55 | }, 56 | 'buz': {}, 57 | } 58 | self.assertSameElements( 59 | expected, dm_env_flatten_utils.unflatten_dict(input_dict, '.')) 60 | 61 | def test_unflatten_different_separator(self): 62 | input_dict = {'foo::bar.baz': True, 'foo.bar::baz': 1} 63 | expected = {'foo': {'bar.baz': True}, 'foo.bar': {'baz': 1}} 64 | self.assertSameElements( 65 | expected, dm_env_flatten_utils.unflatten_dict(input_dict, '::')) 66 | 67 | def test_flatten_unflatten(self): 68 | input_output = { 69 | 'foo': { 70 | 'bar': 1, 71 | 'baz': False 72 | }, 73 | 'fiz': object(), 74 | } 75 | self.assertSameElements( 76 | input_output, 77 | dm_env_flatten_utils.unflatten_dict( 78 | dm_env_flatten_utils.flatten_dict(input_output, '.'), '.')) 79 | 80 | def test_flatten_with_key_containing_separator(self): 81 | input_dict = {'foo.bar': {'baz': 123}, 'bar': {'foo.baz': 456}} 82 | expected = {'foo.bar.baz': 123, 'bar.foo.baz': 456} 83 | 84 | self.assertSameElements( 85 | expected, 86 | dm_env_flatten_utils.flatten_dict(input_dict, '.', strict=False), 87 | ) 88 | 89 | def test_flatten_with_key_containing_separator_strict_raises_error(self): 90 | with self.assertRaisesRegex(ValueError, 'foo.bar'): 91 | dm_env_flatten_utils.flatten_dict({'foo.bar': True}, '.') 92 | 93 | def test_invalid_flattened_dict_raises_error(self): 94 | input_dict = dict(( 95 | ('foo.bar', True), 96 | ('foo', 'invalid_value_for_sub_key'), 97 | )) 98 | with self.assertRaisesRegex(ValueError, 'Duplicate key'): 99 | dm_env_flatten_utils.unflatten_dict(input_dict, '.') 100 | 101 | def test_sub_tree_has_value_raises_error(self): 102 | input_dict = dict(( 103 | ('branch', 'should_not_have_value'), 104 | ('branch.leaf', True), 105 | )) 106 | with self.assertRaisesRegex(ValueError, 107 | "Sub-tree 'branch' has already been assigned"): 108 | dm_env_flatten_utils.unflatten_dict(input_dict, '.') 109 | 110 | def test_empty_dict_values_flatten(self): 111 | input_dict = { 112 | 'foo': {}, 113 | 'bar': { 114 | 'baz': {} 115 | }, 116 | } 117 | expected = { 118 | 'foo': {}, 119 | 'bar.baz': {}, 120 | } 121 | self.assertSameElements(expected, 122 | dm_env_flatten_utils.flatten_dict(input_dict, '.')) 123 | 124 | 125 | if __name__ == '__main__': 126 | absltest.main() 127 | -------------------------------------------------------------------------------- /dm_env_rpc/v1/dm_env_rpc.proto: -------------------------------------------------------------------------------- 1 | // Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | // =========================================================================== 15 | syntax = "proto3"; 16 | 17 | package dm_env_rpc.v1; 18 | 19 | import "google/protobuf/any.proto"; 20 | import "google/rpc/status.proto"; 21 | 22 | // A potentially multi-dimensional array of data, laid out in row-major format. 23 | // Note, only one data channel should be used at a time. 24 | message Tensor { 25 | message Int8Array { 26 | bytes array = 1; 27 | } 28 | message Int32Array { 29 | repeated int32 array = 1; 30 | } 31 | message Int64Array { 32 | repeated int64 array = 1; 33 | } 34 | message Uint8Array { 35 | bytes array = 1; 36 | } 37 | message Uint32Array { 38 | repeated uint32 array = 1; 39 | } 40 | message Uint64Array { 41 | repeated uint64 array = 1; 42 | } 43 | message FloatArray { 44 | repeated float array = 1; 45 | } 46 | message DoubleArray { 47 | repeated double array = 1; 48 | } 49 | message BoolArray { 50 | repeated bool array = 1; 51 | } 52 | message StringArray { 53 | repeated string array = 1; 54 | } 55 | message ProtoArray { 56 | repeated google.protobuf.Any array = 1; 57 | } 58 | 59 | // The flattened tensor data. Data is laid out in row-major order. 60 | oneof payload { 61 | // LINT.IfChange(Tensor) 62 | FloatArray floats = 1; 63 | DoubleArray doubles = 2; 64 | Int8Array int8s = 3; 65 | Int32Array int32s = 4; 66 | Int64Array int64s = 5; 67 | Uint8Array uint8s = 6; 68 | Uint32Array uint32s = 7; 69 | Uint64Array uint64s = 8; 70 | BoolArray bools = 9; 71 | StringArray strings = 10; 72 | ProtoArray protos = 11; 73 | // LINT.ThenChange(:DataType) 74 | } 75 | 76 | // The dimensions of the repeated data fields. If empty, the data channel 77 | // will be treated as a scalar and expected to have exactly one element. 78 | // 79 | // If the payload has exactly one element, it will be repeated to fill the 80 | // shape. 81 | // 82 | // A negative element in a dimension indicates its size should be determined 83 | // based on the number of elements in the payload and the rest of the shape. 84 | // For instance, a shape of [-1, 5] means the shape is a matrix with 5 columns 85 | // and a variable number of rows. Only one element in the shape may be set to 86 | // a negative value. 87 | repeated int32 shape = 15; 88 | } 89 | 90 | // The data type of elements of a tensor. This must match the types in the 91 | // Tensor payload oneof. 92 | enum DataType { 93 | // This is the default value indicating no value was set. 94 | INVALID_DATA_TYPE = 0; 95 | 96 | // LINT.IfChange(DataType) 97 | FLOAT = 1; 98 | DOUBLE = 2; 99 | INT8 = 3; 100 | INT32 = 4; 101 | INT64 = 5; 102 | UINT8 = 6; 103 | UINT32 = 7; 104 | UINT64 = 8; 105 | BOOL = 9; 106 | STRING = 10; 107 | PROTO = 11; 108 | // LINT.ThenChange(:Tensor) 109 | } 110 | 111 | message TensorSpec { 112 | // A human-readable name describing this tensor. 113 | string name = 1; 114 | 115 | // The dimensionality of the tensor. See Tensor.shape for more information. 116 | repeated int32 shape = 2; 117 | 118 | // The data type of the elements in the tensor. 
119 | DataType dtype = 3; 120 | 121 | // Value sub-message that defines the inclusive min/max bounds for numerical 122 | // types. 123 | message Value { 124 | oneof payload { 125 | Tensor.FloatArray floats = 9; 126 | Tensor.DoubleArray doubles = 10; 127 | Tensor.Int8Array int8s = 11; 128 | Tensor.Int32Array int32s = 12; 129 | Tensor.Int64Array int64s = 13; 130 | Tensor.Uint8Array uint8s = 14; 131 | Tensor.Uint32Array uint32s = 15; 132 | Tensor.Uint64Array uint64s = 16; 133 | } 134 | 135 | // Deprecated scalar Value fields. Please use array fields and create a 136 | // single element array of the relevant type. 137 | reserved 1 to 8; 138 | } 139 | 140 | // The minimum value that elements in the tensor can obtain. Inclusive. 141 | Value min = 4; // Optional 142 | 143 | // The maximum value that elements in the tensor can obtain. Inclusive. 144 | Value max = 5; // Optional 145 | } 146 | 147 | message CreateWorldRequest { 148 | // Settings to create the world with. This can define the level layout, the 149 | // number of agents, the goal or game mode, or other universal settings. 150 | // Agent-specific settings, such as anything which would change the action or 151 | // observation spec, should go in the JoinWorldRequest. 152 | map<string, Tensor> settings = 1; 153 | } 154 | message CreateWorldResponse { 155 | // The unique name for the world just created. 156 | string world_name = 1; 157 | } 158 | 159 | message ActionObservationSpecs { 160 | map<uint64, TensorSpec> actions = 1; 161 | 162 | map<uint64, TensorSpec> observations = 2; 163 | } 164 | 165 | message JoinWorldRequest { 166 | // The name of the world to join. 167 | string world_name = 1; 168 | 169 | // Agent-specific settings which define how to join the world, such as agent 170 | // name and class in an RPG. 171 | map<string, Tensor> settings = 2; 172 | } 173 | message JoinWorldResponse { 174 | ActionObservationSpecs specs = 1; 175 | } 176 | 177 | enum EnvironmentStateType { 178 | // This is the default value indicating no value was set. It should never be 179 | // sent or received. 180 | INVALID_ENVIRONMENT_STATE = 0; 181 | 182 | // The environment is currently in the middle of a sequence. 183 | RUNNING = 1; 184 | 185 | // The previously running sequence reached its natural conclusion. 186 | TERMINATED = 2; 187 | 188 | // The sequence was interrupted by a reset. 189 | INTERRUPTED = 3; 190 | } 191 | 192 | message StepRequest { 193 | // The actions to perform on the environment. If the environment is currently 194 | // in a non-RUNNING state, whether because the agent has just called 195 | // JoinWorld, the state from the last StepResponse was TERMINATED or 196 | // INTERRUPTED, or a ResetRequest had previously been sent, the actions will 197 | // be ignored. 198 | map<uint64, Tensor> actions = 1; 199 | 200 | // Array of observation UIDs to return. If not set, no observations are 201 | // returned. 202 | repeated uint64 requested_observations = 2; 203 | } 204 | 205 | message StepResponse { 206 | // If state is not RUNNING, the action on the next StepRequest will be 207 | // ignored and the environment will transition to a RUNNING state. 208 | EnvironmentStateType state = 1; 209 | 210 | // The observations requested in `StepRequest`. Observations returned should 211 | // match the dimensionality and type specified in `specs.observations`. 212 | map<uint64, Tensor> observations = 2; 213 | } 214 | 215 | // The current sequence will be interrupted. The actions on the next call to 216 | // StepRequest will be ignored and a new sequence will begin.
217 | message ResetRequest { 218 | // Agent-specific settings to apply for the next sequence, such as changing 219 | // class in an RPG. 220 | map<string, Tensor> settings = 1; 221 | } 222 | message ResetResponse { 223 | ActionObservationSpecs specs = 1; 224 | } 225 | 226 | // All connected agents will have their next StepResponse return INTERRUPTED. 227 | message ResetWorldRequest { 228 | string world_name = 1; 229 | 230 | // World settings to apply for the next sequence, such as changing the map or 231 | // seed. 232 | map<string, Tensor> settings = 2; 233 | } 234 | message ResetWorldResponse {} 235 | 236 | message LeaveWorldRequest {} 237 | message LeaveWorldResponse {} 238 | 239 | message DestroyWorldRequest { 240 | string world_name = 1; 241 | } 242 | message DestroyWorldResponse {} 243 | 244 | message EnvironmentRequest { 245 | oneof payload { 246 | CreateWorldRequest create_world = 1; 247 | JoinWorldRequest join_world = 2; 248 | StepRequest step = 3; 249 | ResetRequest reset = 4; 250 | ResetWorldRequest reset_world = 5; 251 | LeaveWorldRequest leave_world = 6; 252 | DestroyWorldRequest destroy_world = 7; 253 | 254 | // If the environment supports a specialized request not covered above it 255 | // can be sent this way. 256 | // 257 | // Slot 15 is the last slot which can be encoded with one byte. See 258 | // https://developers.google.com/protocol-buffers/docs/proto3#assigning-field-numbers 259 | google.protobuf.Any extension = 15; 260 | } 261 | 262 | // Deprecated property requests. Please use properties extension for future 263 | // requests/responses. 264 | reserved 8 to 10; 265 | 266 | // Slot corresponds to `error` in the EnvironmentResponse. 267 | reserved 16; 268 | } 269 | 270 | message EnvironmentResponse { 271 | oneof payload { 272 | CreateWorldResponse create_world = 1; 273 | JoinWorldResponse join_world = 2; 274 | StepResponse step = 3; 275 | ResetResponse reset = 4; 276 | ResetWorldResponse reset_world = 5; 277 | LeaveWorldResponse leave_world = 6; 278 | DestroyWorldResponse destroy_world = 7; 279 | 280 | // If the environment supports a specialized response not covered above it 281 | // can be sent this way. 282 | // 283 | // Slot 15 is the last slot which can be encoded with one byte. See 284 | // https://developers.google.com/protocol-buffers/docs/proto3#assigning-field-numbers 285 | google.protobuf.Any extension = 15; 286 | 287 | google.rpc.Status error = 16; 288 | } 289 | 290 | // Deprecated property responses. Please use properties extension for future 291 | // requests/responses. 292 | reserved 8 to 10; 293 | } 294 | 295 | service Environment { 296 | // Process incoming environment requests. 297 | rpc Process(stream EnvironmentRequest) returns (stream EnvironmentResponse) {} 298 | } 299 | -------------------------------------------------------------------------------- /dm_env_rpc/v1/dm_env_rpc_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
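To make the Tensor shape rules in the proto above concrete, here are three hand-built messages using only the generated bindings; in practice the tensor_utils helpers in this package do this packing for you.

from dm_env_rpc.v1 import dm_env_rpc_pb2

# A 2x3 float tensor; the six payload elements are laid out in row-major order.
tensor = dm_env_rpc_pb2.Tensor()
tensor.floats.array[:] = [1, 2, 3, 4, 5, 6]
tensor.shape[:] = [2, 3]

# A single payload element is repeated to fill the declared shape.
zeros = dm_env_rpc_pb2.Tensor()
zeros.int32s.array[:] = [0]
zeros.shape[:] = [4, 4]  # Interpreted as a 4x4 tensor of zeros.

# One dimension may be negative: its size is inferred from the element count.
matrix = dm_env_rpc_pb2.Tensor()
matrix.doubles.array[:] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
matrix.shape[:] = [-1, 5]  # 5 columns, so two rows here.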
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Tests for environment_stream.proto. 16 | 17 | These aren't for testing functionality (it's assumed protobufs work) but for 18 | testing/demonstrating how the protobufs would have to be used in code. 19 | """ 20 | 21 | from absl.testing import absltest 22 | from dm_env_rpc.v1 import dm_env_rpc_pb2 23 | 24 | 25 | class TensorTests(absltest.TestCase): 26 | 27 | def test_setting_tensor_data(self): 28 | tensor = dm_env_rpc_pb2.Tensor() 29 | tensor.floats.array[:] = [1, 2] 30 | 31 | def test_setting_tensor_data_with_wrong_type(self): 32 | tensor = dm_env_rpc_pb2.Tensor() 33 | with self.assertRaises(TypeError): 34 | tensor.floats.array[:] = ['hello!'] # pytype: disable=unsupported-operands 35 | 36 | def test_which_is_set(self): 37 | tensor = dm_env_rpc_pb2.Tensor() 38 | tensor.floats.array[:] = [1, 2] 39 | self.assertEqual('floats', tensor.WhichOneof('payload')) 40 | 41 | 42 | class TensorSpec(absltest.TestCase): 43 | 44 | def test_setting_spec(self): 45 | tensor_spec = dm_env_rpc_pb2.TensorSpec() 46 | tensor_spec.name = 'Foo' 47 | tensor_spec.min.floats.array[:] = [0.0] 48 | tensor_spec.max.floats.array[:] = [0.0] 49 | tensor_spec.shape[:] = [2, 2] 50 | tensor_spec.dtype = dm_env_rpc_pb2.DataType.FLOAT 51 | 52 | 53 | class JoinWorldResponse(absltest.TestCase): 54 | 55 | def test_setting_spec(self): 56 | response = dm_env_rpc_pb2.JoinWorldResponse() 57 | tensor_spec = response.specs.actions[1] 58 | tensor_spec.shape[:] = [1] 59 | tensor_spec.dtype = dm_env_rpc_pb2.DataType.FLOAT 60 | 61 | 62 | if __name__ == '__main__': 63 | absltest.main() 64 | -------------------------------------------------------------------------------- /dm_env_rpc/v1/dm_env_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Utilities for interfacing dm_env and dm_env_rpc.""" 16 | 17 | from typing import Dict 18 | from dm_env import specs 19 | import numpy as np 20 | 21 | from dm_env_rpc.v1 import dm_env_rpc_pb2 22 | from dm_env_rpc.v1 import spec_manager as dm_env_rpc_spec_manager 23 | from dm_env_rpc.v1 import tensor_spec_utils 24 | from dm_env_rpc.v1 import tensor_utils 25 | 26 | 27 | def tensor_spec_to_dm_env_spec( 28 | tensor_spec: dm_env_rpc_pb2.TensorSpec) -> specs.Array: 29 | """Returns a dm_env spec given a dm_env_rpc TensorSpec. 30 | 31 | Args: 32 | tensor_spec: A dm_env_rpc TensorSpec protobuf. 33 | 34 | Returns: 35 | Either a DiscreteArray, BoundedArray, StringArray or Array, depending on the 36 | content of the TensorSpec. 
37 | """ 38 | np_type = tensor_utils.data_type_to_np_type(tensor_spec.dtype) 39 | if tensor_spec.HasField('min') or tensor_spec.HasField('max'): 40 | bounds = tensor_spec_utils.bounds(tensor_spec) 41 | 42 | if (not tensor_spec.shape 43 | and np.issubdtype(np_type, np.integer) 44 | and bounds.min == 0 45 | and tensor_spec.HasField('max')): 46 | return specs.DiscreteArray( 47 | num_values=bounds.max + 1, dtype=np_type, name=tensor_spec.name) 48 | else: 49 | return specs.BoundedArray( 50 | shape=tensor_spec.shape, 51 | dtype=np_type, 52 | name=tensor_spec.name, 53 | minimum=bounds.min, 54 | maximum=bounds.max) 55 | else: 56 | if tensor_spec.dtype == dm_env_rpc_pb2.DataType.STRING: 57 | return specs.StringArray(shape=tensor_spec.shape, name=tensor_spec.name) 58 | else: 59 | return specs.Array( 60 | shape=tensor_spec.shape, dtype=np_type, name=tensor_spec.name) 61 | 62 | 63 | def dm_env_spec_to_tensor_spec(spec: specs.Array) -> dm_env_rpc_pb2.TensorSpec: 64 | """Returns a dm_env_rpc TensorSpec from the provided dm_env spec type.""" 65 | dtype = np.str_ if isinstance(spec, specs.StringArray) else spec.dtype 66 | tensor_spec = dm_env_rpc_pb2.TensorSpec( 67 | name=spec.name, 68 | shape=spec.shape, 69 | dtype=tensor_utils.np_type_to_data_type(dtype)) 70 | if isinstance(spec, specs.DiscreteArray): 71 | tensor_spec_utils.set_bounds( 72 | tensor_spec, minimum=0, maximum=spec.num_values - 1) 73 | elif isinstance(spec, specs.BoundedArray): 74 | tensor_spec_utils.set_bounds(tensor_spec, spec.minimum, spec.maximum) 75 | 76 | return tensor_spec 77 | 78 | 79 | def dm_env_spec( 80 | spec_manager: dm_env_rpc_spec_manager.SpecManager 81 | ) -> Dict[str, specs.Array]: 82 | """Returns a dm_env spec for the given `spec_manager`. 83 | 84 | Args: 85 | spec_manager: An instance of SpecManager. 86 | 87 | Returns: 88 | A dict mapping names to either a dm_env Array, BoundedArray, DiscreteArray 89 | or StringArray spec for each named TensorSpec in `spec_manager`. 90 | """ 91 | return { 92 | name: tensor_spec_to_dm_env_spec(spec_manager.name_to_spec(name)) 93 | for name in spec_manager.names() 94 | } 95 | -------------------------------------------------------------------------------- /dm_env_rpc/v1/dm_env_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """Tests for dm_env_rpc/dm_env utilities.""" 16 | 17 | import typing 18 | 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | from dm_env import specs 22 | import numpy as np 23 | 24 | from google.protobuf import text_format 25 | from dm_env_rpc.v1 import dm_env_rpc_pb2 26 | from dm_env_rpc.v1 import dm_env_utils 27 | from dm_env_rpc.v1 import spec_manager 28 | 29 | 30 | class TensorSpecToDmEnvSpecTests(absltest.TestCase): 31 | 32 | def test_no_bounds_gives_arrayspec(self): 33 | tensor_spec = dm_env_rpc_pb2.TensorSpec() 34 | tensor_spec.dtype = dm_env_rpc_pb2.DataType.UINT32 35 | tensor_spec.shape[:] = [3] 36 | tensor_spec.name = 'foo' 37 | actual = dm_env_utils.tensor_spec_to_dm_env_spec(tensor_spec) 38 | self.assertEqual(specs.Array(shape=[3], dtype=np.uint32), actual) 39 | self.assertEqual('foo', actual.name) 40 | 41 | def test_string_give_string_array(self): 42 | tensor_spec = dm_env_rpc_pb2.TensorSpec() 43 | tensor_spec.dtype = dm_env_rpc_pb2.DataType.STRING 44 | tensor_spec.shape[:] = [1, 2, 3] 45 | tensor_spec.name = 'string_spec' 46 | actual = dm_env_utils.tensor_spec_to_dm_env_spec(tensor_spec) 47 | self.assertEqual(specs.StringArray(shape=[1, 2, 3]), actual) 48 | self.assertEqual('string_spec', actual.name) 49 | 50 | def test_scalar_with_0_n_bounds_gives_discrete_array(self): 51 | tensor_spec = dm_env_rpc_pb2.TensorSpec() 52 | tensor_spec.dtype = dm_env_rpc_pb2.DataType.UINT32 53 | tensor_spec.name = 'foo' 54 | 55 | max_value = 9 56 | tensor_spec.min.uint32s.array[:] = [0] 57 | tensor_spec.max.uint32s.array[:] = [max_value] 58 | actual = typing.cast(specs.DiscreteArray, 59 | dm_env_utils.tensor_spec_to_dm_env_spec(tensor_spec)) 60 | expected = specs.DiscreteArray( 61 | num_values=max_value + 1, dtype=np.uint32, name='foo') 62 | self.assertEqual(expected, actual) 63 | self.assertEqual(0, actual.minimum) 64 | self.assertEqual(max_value, actual.maximum) 65 | self.assertEqual('foo', actual.name) 66 | 67 | def test_scalar_with_1_n_bounds_gives_bounded_array(self): 68 | tensor_spec = dm_env_rpc_pb2.TensorSpec() 69 | tensor_spec.dtype = dm_env_rpc_pb2.DataType.UINT32 70 | tensor_spec.name = 'foo' 71 | tensor_spec.min.uint32s.array[:] = [1] 72 | tensor_spec.max.uint32s.array[:] = [10] 73 | actual = dm_env_utils.tensor_spec_to_dm_env_spec(tensor_spec) 74 | expected = specs.BoundedArray( 75 | shape=(), dtype=np.uint32, minimum=1, maximum=10, name='foo') 76 | self.assertEqual(expected, actual) 77 | self.assertEqual('foo', actual.name) 78 | 79 | def test_scalar_with_0_min_and_no_max_bounds_gives_bounded_array(self): 80 | tensor_spec = dm_env_rpc_pb2.TensorSpec() 81 | tensor_spec.dtype = dm_env_rpc_pb2.DataType.UINT32 82 | tensor_spec.name = 'foo' 83 | tensor_spec.min.uint32s.array[:] = [0] 84 | actual = dm_env_utils.tensor_spec_to_dm_env_spec(tensor_spec) 85 | expected = specs.BoundedArray( 86 | shape=(), dtype=np.uint32, minimum=0, maximum=2**32 - 1, name='foo') 87 | self.assertEqual(expected, actual) 88 | self.assertEqual('foo', actual.name) 89 | 90 | def test_only_min_bounds(self): 91 | tensor_spec = dm_env_rpc_pb2.TensorSpec() 92 | tensor_spec.dtype = dm_env_rpc_pb2.DataType.UINT32 93 | tensor_spec.shape[:] = [3] 94 | tensor_spec.name = 'foo' 95 | tensor_spec.min.uint32s.array[:] = [1] 96 | actual = dm_env_utils.tensor_spec_to_dm_env_spec(tensor_spec) 97 | expected = specs.BoundedArray( 98 | shape=[3], dtype=np.uint32, minimum=1, maximum=2**32 - 1) 99 | 
self.assertEqual(expected, actual) 100 | self.assertEqual('foo', actual.name) 101 | 102 | def test_only_max_bounds(self): 103 | tensor_spec = dm_env_rpc_pb2.TensorSpec() 104 | tensor_spec.dtype = dm_env_rpc_pb2.DataType.UINT32 105 | tensor_spec.shape[:] = [3] 106 | tensor_spec.name = 'foo' 107 | tensor_spec.max.uint32s.array[:] = [10] 108 | actual = dm_env_utils.tensor_spec_to_dm_env_spec(tensor_spec) 109 | expected = specs.BoundedArray( 110 | shape=[3], dtype=np.uint32, minimum=0, maximum=10) 111 | self.assertEqual(expected, actual) 112 | self.assertEqual('foo', actual.name) 113 | 114 | def test_both_bounds(self): 115 | tensor_spec = dm_env_rpc_pb2.TensorSpec() 116 | tensor_spec.dtype = dm_env_rpc_pb2.DataType.UINT32 117 | tensor_spec.shape[:] = [3] 118 | tensor_spec.name = 'foo' 119 | tensor_spec.min.uint32s.array[:] = [1] 120 | tensor_spec.max.uint32s.array[:] = [10] 121 | actual = dm_env_utils.tensor_spec_to_dm_env_spec(tensor_spec) 122 | expected = specs.BoundedArray( 123 | shape=[3], dtype=np.uint32, minimum=1, maximum=10) 124 | self.assertEqual(expected, actual) 125 | self.assertEqual('foo', actual.name) 126 | 127 | def test_bounds_oneof_not_set_gives_dtype_bounds(self): 128 | tensor_spec = dm_env_rpc_pb2.TensorSpec() 129 | tensor_spec.dtype = dm_env_rpc_pb2.DataType.UINT32 130 | tensor_spec.shape[:] = [3] 131 | tensor_spec.name = 'foo' 132 | 133 | # Just to force the message to get created. 134 | tensor_spec.min.floats.array[:] = [3.0] 135 | tensor_spec.min.ClearField('floats') 136 | 137 | actual = dm_env_utils.tensor_spec_to_dm_env_spec(tensor_spec) 138 | expected = specs.BoundedArray( 139 | shape=[3], dtype=np.uint32, minimum=0, maximum=2**32 - 1) 140 | self.assertEqual(expected, actual) 141 | self.assertEqual('foo', actual.name) 142 | 143 | def test_bounds_wrong_type_gives_error(self): 144 | tensor_spec = dm_env_rpc_pb2.TensorSpec() 145 | tensor_spec.dtype = dm_env_rpc_pb2.DataType.UINT32 146 | tensor_spec.shape[:] = [3] 147 | tensor_spec.name = 'foo' 148 | tensor_spec.min.floats.array[:] = [1.9] 149 | with self.assertRaisesRegex(ValueError, 'uint32'): 150 | dm_env_utils.tensor_spec_to_dm_env_spec(tensor_spec) 151 | 152 | def test_bounds_on_string_gives_error(self): 153 | tensor_spec = dm_env_rpc_pb2.TensorSpec() 154 | tensor_spec.dtype = dm_env_rpc_pb2.DataType.STRING 155 | tensor_spec.shape[:] = [2] 156 | tensor_spec.name = 'named' 157 | tensor_spec.min.floats.array[:] = [1.9] 158 | tensor_spec.max.floats.array[:] = [10.0] 159 | with self.assertRaisesRegex(ValueError, 'string'): 160 | dm_env_utils.tensor_spec_to_dm_env_spec(tensor_spec) 161 | 162 | 163 | class DmEnvSpecToTensorSpecTests(parameterized.TestCase): 164 | 165 | @parameterized.parameters( 166 | (specs.Array([1, 2], np.float32, 167 | 'foo'), """name: "foo" shape: 1 shape: 2 dtype: FLOAT"""), 168 | (specs.DiscreteArray(5, int, 'bar'), r"""name: "bar" dtype: INT64 169 | min { int64s { array: 0 } } max { int64s { array: 4 } }"""), 170 | (specs.BoundedArray( 171 | (), np.int32, -1, 5, 'baz'), r"""name: "baz" dtype: INT32 172 | min { int32s { array: -1 } } max { int32s { array: 5 } }"""), 173 | (specs.BoundedArray((1, 2), np.uint8, 0, 127, 174 | 'zog'), r"""name: "zog" shape: 1 shape: 2 dtype: UINT8 175 | min { uint8s { array: "\000" } } max { uint8s { array: "\177" } }"""), 176 | (specs.StringArray(shape=(5, 5), name='fux'), 177 | r"""name: "fux" shape: 5 shape: 5 dtype: STRING"""), 178 | ) 179 | def test_dm_env_spec(self, value, expected): 180 | tensor_spec = dm_env_utils.dm_env_spec_to_tensor_spec(value) 181 | expected 
= text_format.Parse(expected, dm_env_rpc_pb2.TensorSpec()) 182 | self.assertEqual(expected, tensor_spec) 183 | 184 | 185 | class DmEnvSpecTests(absltest.TestCase): 186 | 187 | def test_spec(self): 188 | dm_env_rpc_specs = { 189 | 54: 190 | dm_env_rpc_pb2.TensorSpec( 191 | name='fuzz', shape=[3], dtype=dm_env_rpc_pb2.DataType.FLOAT), 192 | 55: 193 | dm_env_rpc_pb2.TensorSpec( 194 | name='foo', shape=[2], dtype=dm_env_rpc_pb2.DataType.INT32), 195 | } 196 | manager = spec_manager.SpecManager(dm_env_rpc_specs) 197 | 198 | expected = { 199 | 'foo': specs.Array(shape=[2], dtype=np.int32), 200 | 'fuzz': specs.Array(shape=[3], dtype=np.float32) 201 | } 202 | 203 | self.assertDictEqual(expected, dm_env_utils.dm_env_spec(manager)) 204 | 205 | def test_empty_spec(self): 206 | self.assertDictEqual({}, 207 | dm_env_utils.dm_env_spec(spec_manager.SpecManager({}))) 208 | 209 | def test_spec_generate_and_validate_scalars(self): 210 | dm_env_rpc_specs = [] 211 | for name, dtype in dm_env_rpc_pb2.DataType.items(): 212 | if dtype != dm_env_rpc_pb2.DataType.INVALID_DATA_TYPE: 213 | dm_env_rpc_specs.append( 214 | dm_env_rpc_pb2.TensorSpec(name=name, shape=(), dtype=dtype)) 215 | 216 | for dm_env_rpc_spec in dm_env_rpc_specs: 217 | spec = dm_env_utils.tensor_spec_to_dm_env_spec(dm_env_rpc_spec) 218 | value = spec.generate_value() 219 | spec.validate(value) 220 | 221 | def test_spec_generate_and_validate_tensors(self): 222 | example_shape = (10, 10, 3) 223 | 224 | dm_env_rpc_specs = [] 225 | for name, dtype in dm_env_rpc_pb2.DataType.items(): 226 | if dtype != dm_env_rpc_pb2.DataType.INVALID_DATA_TYPE: 227 | dm_env_rpc_specs.append( 228 | dm_env_rpc_pb2.TensorSpec( 229 | name=name, shape=example_shape, dtype=dtype)) 230 | 231 | for dm_env_rpc_spec in dm_env_rpc_specs: 232 | spec = dm_env_utils.tensor_spec_to_dm_env_spec(dm_env_rpc_spec) 233 | value = spec.generate_value() 234 | spec.validate(value) 235 | 236 | if __name__ == '__main__': 237 | absltest.main() 238 | -------------------------------------------------------------------------------- /dm_env_rpc/v1/error.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Provides custom Pythonic errors for dm_env_rpc error messages.""" 17 | from typing import Iterable 18 | from google.protobuf import any_pb2 19 | from google.rpc import status_pb2 20 | 21 | 22 | class DmEnvRpcError(Exception): 23 | """A dm_env_rpc custom exception. 24 | 25 | Wraps a google.rpc.Status message as a Python Exception class. 
26 | """ 27 | 28 | def __init__(self, status_proto: status_pb2.Status): 29 | super().__init__() 30 | self._status_proto = status_proto 31 | 32 | @property 33 | def code(self) -> int: 34 | return self._status_proto.code 35 | 36 | @property 37 | def message(self) -> str: 38 | return self._status_proto.message 39 | 40 | @property 41 | def details(self) -> Iterable[any_pb2.Any]: 42 | return self._status_proto.details 43 | 44 | def __str__(self): 45 | return str(self._status_proto) 46 | 47 | def __repr__(self): 48 | return f'DmEnvRpcError(code={self.code}, message={self.message})' 49 | 50 | def __reduce__(self): 51 | return (DmEnvRpcError, (self._status_proto,)) 52 | -------------------------------------------------------------------------------- /dm_env_rpc/v1/error_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Tests for dm_env_rpc error module.""" 16 | 17 | import pickle 18 | 19 | from absl.testing import absltest 20 | from google.rpc import code_pb2 21 | from google.rpc import status_pb2 22 | from dm_env_rpc.v1 import error 23 | 24 | 25 | class ErrorTest(absltest.TestCase): 26 | 27 | def testSimpleError(self): 28 | message = status_pb2.Status( 29 | code=code_pb2.INVALID_ARGUMENT, message='A test error.') 30 | exception = error.DmEnvRpcError(message) 31 | 32 | self.assertEqual(code_pb2.INVALID_ARGUMENT, exception.code) 33 | self.assertEqual('A test error.', exception.message) 34 | self.assertEqual(str(message), str(exception)) 35 | 36 | def testPickleUnpickle(self): 37 | exception = error.DmEnvRpcError(status_pb2.Status( 38 | code=code_pb2.INVALID_ARGUMENT, message='foo.')) 39 | pickled = pickle.dumps(exception) 40 | unpickled = pickle.loads(pickled) 41 | 42 | self.assertEqual(code_pb2.INVALID_ARGUMENT, unpickled.code) 43 | self.assertEqual('foo.', unpickled.message) 44 | 45 | def testRepr(self): 46 | exception = error.DmEnvRpcError(status_pb2.Status( 47 | code=code_pb2.INVALID_ARGUMENT, message='foo.')) 48 | as_string = repr(exception) 49 | self.assertIn(exception.message, as_string) 50 | self.assertIn(str(exception.code), as_string) 51 | 52 | 53 | if __name__ == '__main__': 54 | absltest.main() 55 | -------------------------------------------------------------------------------- /dm_env_rpc/v1/extensions/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Extensions module for dm_env_rpc.""" 16 | -------------------------------------------------------------------------------- /dm_env_rpc/v1/extensions/properties.proto: -------------------------------------------------------------------------------- 1 | // Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | // =========================================================================== 15 | syntax = "proto3"; 16 | 17 | package dm_env_rpc.v1.extensions.properties; 18 | 19 | import "dm_env_rpc/v1/dm_env_rpc.proto"; 20 | 21 | message PropertySpec { 22 | // Required: TensorSpec name field for key value. 23 | dm_env_rpc.v1.TensorSpec spec = 1; 24 | 25 | bool is_readable = 2; 26 | bool is_writable = 3; 27 | bool is_listable = 4; 28 | 29 | string description = 5; 30 | } 31 | 32 | message ListPropertyRequest { 33 | // Key to list property for. Empty string is root level. 34 | string key = 1; 35 | } 36 | 37 | message ListPropertyResponse { 38 | repeated PropertySpec values = 1; 39 | } 40 | 41 | message ReadPropertyRequest { 42 | string key = 1; 43 | } 44 | 45 | message ReadPropertyResponse { 46 | dm_env_rpc.v1.Tensor value = 1; 47 | } 48 | 49 | message WritePropertyRequest { 50 | string key = 1; 51 | dm_env_rpc.v1.Tensor value = 2; 52 | } 53 | 54 | message WritePropertyResponse {} 55 | 56 | message PropertyRequest { 57 | oneof payload { 58 | ReadPropertyRequest read_property = 1; 59 | WritePropertyRequest write_property = 2; 60 | ListPropertyRequest list_property = 3; 61 | } 62 | } 63 | 64 | message PropertyResponse { 65 | oneof payload { 66 | ReadPropertyResponse read_property = 1; 67 | WritePropertyResponse write_property = 2; 68 | ListPropertyResponse list_property = 3; 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /dm_env_rpc/v1/extensions/properties.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
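The properties proto above only defines the wire format; how requests are served is up to the environment. The following sketch shows one plausible server-side dispatch for the extension message: unpack the Any, switch on the PropertyRequest payload and pack a PropertyResponse. The `store` mapping is an assumption, and a real server would also return NOT_FOUND or PERMISSION_DENIED statuses instead of letting exceptions escape.

from google.protobuf import any_pb2

from dm_env_rpc.v1 import dm_env_rpc_pb2
from dm_env_rpc.v1 import tensor_utils
from dm_env_rpc.v1.extensions import properties_pb2


def handle_property_request(request_any: any_pb2.Any, store) -> any_pb2.Any:
  """Applies a packed PropertyRequest to `store` and returns a packed reply."""
  request = properties_pb2.PropertyRequest()
  request_any.Unpack(request)
  response = properties_pb2.PropertyResponse()

  which = request.WhichOneof('payload')
  if which == 'read_property':
    response.read_property.value.CopyFrom(
        tensor_utils.pack_tensor(store[request.read_property.key]))
  elif which == 'write_property':
    store[request.write_property.key] = tensor_utils.unpack_tensor(
        request.write_property.value)
    response.write_property.SetInParent()
  else:  # 'list_property'
    for key in store:
      response.list_property.values.add(
          spec=dm_env_rpc_pb2.TensorSpec(name=key),
          is_readable=True, is_writable=True)

  response_any = any_pb2.Any()
  response_any.Pack(response)
  return response_any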
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """A helper class for sending and receiving property requests and responses. 16 | 17 | This helper class provides a Pythonic interface for reading, writing and listing 18 | properties. It simplifies the packing and unpacking of property requests and 19 | responses using the provided dm_env_rpc.v1.connection.Connection instance to 20 | send and receive extension messages. 21 | 22 | Example Usage: 23 | property_extension = PropertiesExtension(connection) 24 | 25 | # To read a property: 26 | value = property_extension['my_property'] 27 | 28 | # To write a property: 29 | property_extension['my_property'] = new_value 30 | 31 | # To find available properties: 32 | property_specs = property_extension.specs() 33 | 34 | spec = property_specs['my_property'] 35 | """ 36 | import contextlib 37 | from typing import Mapping, Sequence, Optional 38 | 39 | from dm_env import specs as dm_env_specs 40 | from google.protobuf import any_pb2 41 | from google.rpc import code_pb2 42 | from dm_env_rpc.v1 import connection as dm_env_rpc_connection 43 | from dm_env_rpc.v1 import dm_env_rpc_pb2 44 | from dm_env_rpc.v1 import dm_env_utils 45 | from dm_env_rpc.v1 import error 46 | from dm_env_rpc.v1 import tensor_utils 47 | from dm_env_rpc.v1.extensions import properties_pb2 48 | 49 | 50 | @contextlib.contextmanager 51 | def _convert_dm_env_rpc_error(): 52 | """Helper to convert DmEnvRpcError to a properties related exception.""" 53 | try: 54 | yield 55 | except error.DmEnvRpcError as e: 56 | if e.code == code_pb2.NOT_FOUND: 57 | raise KeyError('Property key not found!') from e 58 | elif e.code == code_pb2.PERMISSION_DENIED: 59 | raise PermissionError('Property permission denied!') from e 60 | elif e.code == code_pb2.INVALID_ARGUMENT: 61 | raise ValueError('Property value error!') from e 62 | raise 63 | 64 | 65 | class PropertySpec(object): 66 | """Class that represents a property's specification.""" 67 | 68 | def __init__(self, property_spec_proto: properties_pb2.PropertySpec): 69 | """Constructs a property specification from PropertySpec proto message. 70 | 71 | Args: 72 | property_spec_proto: A properties_pb2.PropertySpec message. 73 | """ 74 | self._property_spec_proto = property_spec_proto 75 | 76 | @property 77 | def key(self) -> str: 78 | """Return the property's key.""" 79 | return self._property_spec_proto.spec.name 80 | 81 | @property 82 | def readable(self) -> bool: 83 | """Returns True if the property is readable.""" 84 | return self._property_spec_proto.is_readable 85 | 86 | @property 87 | def writable(self) -> bool: 88 | """Returns True if the property is writable.""" 89 | return self._property_spec_proto.is_writable 90 | 91 | @property 92 | def listable(self) -> bool: 93 | """Returns True if the property is listable.""" 94 | return self._property_spec_proto.is_listable 95 | 96 | @property 97 | def spec(self) -> Optional[dm_env_specs.Array]: 98 | """Returns a dm_env spec if the property has a valid dtype. 
99 | 100 | Returns: 101 | Either a dm_env spec or, if the dtype is invalid, None. 102 | """ 103 | if self._property_spec_proto.spec.dtype != ( 104 | dm_env_rpc_pb2.DataType.INVALID_DATA_TYPE): 105 | return dm_env_utils.tensor_spec_to_dm_env_spec( 106 | self._property_spec_proto.spec) 107 | else: 108 | return None 109 | 110 | @property 111 | def description(self) -> str: 112 | """Returns the property's description.""" 113 | return self._property_spec_proto.description 114 | 115 | def __repr__(self): 116 | return (f'PropertySpec(key={self.key}, readable={self.readable}, ' 117 | f'writable={self.writable}, listable={self.listable}, ' 118 | f'spec={self.spec}, description={self.description})') 119 | 120 | 121 | class PropertiesExtension(object): 122 | """Helper class for sending and receiving property requests and responses.""" 123 | 124 | def __init__(self, connection: dm_env_rpc_connection.ConnectionType): 125 | """Construct extension with provided dm_env_rpc connection to the env. 126 | 127 | Args: 128 | connection: An instance of Connection already connected to a dm_env_rpc 129 | server. 130 | """ 131 | self._connection = connection 132 | 133 | def __getitem__(self, key: str): 134 | """Alias for PropertiesExtension read function.""" 135 | return self.read(key) 136 | 137 | def __setitem__(self, key: str, value) -> None: 138 | """Alias for PropertiesExtension write function.""" 139 | self.write(key, value) 140 | 141 | def specs(self, key: str = '') -> Mapping[str, PropertySpec]: 142 | """Helper to return sub-properties as a dict.""" 143 | return { 144 | sub_property.key: sub_property for sub_property in self.list(key) 145 | } 146 | 147 | def read(self, key: str): 148 | """Reads the value of a property. 149 | 150 | Args: 151 | key: A string key that represents the property to read. 152 | 153 | Returns: 154 | The value of the property, either as a scalar (float, int, string, etc.) 155 | or, if the response tensor has a non-empty `shape` attribute, a NumPy 156 | array of the payload with the correct type and shape. See 157 | tensor_utils.unpack for more details. 158 | """ 159 | response = properties_pb2.PropertyResponse() 160 | packed_request = any_pb2.Any() 161 | packed_request.Pack( 162 | properties_pb2.PropertyRequest( 163 | read_property=properties_pb2.ReadPropertyRequest(key=key))) 164 | with _convert_dm_env_rpc_error(): 165 | self._connection.send(packed_request).Unpack(response) 166 | 167 | return tensor_utils.unpack_tensor(response.read_property.value) 168 | 169 | def write(self, key: str, value) -> None: 170 | """Writes the provided value to a property. 171 | 172 | Args: 173 | key: A string key that represents the property to write. 174 | value: A scalar (float, int, string, etc.), NumPy array, or nested lists. 175 | See tensor_utils.pack for more details. 176 | """ 177 | packed_request = any_pb2.Any() 178 | packed_request.Pack( 179 | properties_pb2.PropertyRequest( 180 | write_property=properties_pb2.WritePropertyRequest( 181 | key=key, value=tensor_utils.pack_tensor(value)))) 182 | with _convert_dm_env_rpc_error(): 183 | self._connection.send(packed_request) 184 | 185 | def list(self, key: str = '') -> Sequence[PropertySpec]: 186 | """Lists properties residing under the provided key. 187 | 188 | Args: 189 | key: A string key to list properties at this location. If empty, returns 190 | properties registered at the root level. 191 | 192 | Returns: 193 | A sequence of PropertySpecs. 
194 | """ 195 | response = properties_pb2.PropertyResponse() 196 | packed_request = any_pb2.Any() 197 | packed_request.Pack( 198 | properties_pb2.PropertyRequest( 199 | list_property=properties_pb2.ListPropertyRequest(key=key))) 200 | with _convert_dm_env_rpc_error(): 201 | self._connection.send(packed_request).Unpack(response) 202 | 203 | return tuple( 204 | PropertySpec(sub_property) 205 | for sub_property in response.list_property.values) 206 | -------------------------------------------------------------------------------- /dm_env_rpc/v1/extensions/properties_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Tests Properties extension.""" 16 | 17 | import contextlib 18 | from unittest import mock 19 | 20 | from absl.testing import absltest 21 | from dm_env import specs 22 | import numpy as np 23 | 24 | from google.protobuf import any_pb2 25 | from google.rpc import code_pb2 26 | from google.rpc import status_pb2 27 | from google.protobuf import text_format 28 | from dm_env_rpc.v1 import connection as dm_env_rpc_connection 29 | from dm_env_rpc.v1 import dm_env_rpc_pb2 30 | from dm_env_rpc.v1 import error 31 | from dm_env_rpc.v1.extensions import properties 32 | from dm_env_rpc.v1.extensions import properties_pb2 33 | 34 | 35 | def _create_property_request_key(text_proto): 36 | extension_message = any_pb2.Any() 37 | extension_message.Pack( 38 | text_format.Parse(text_proto, properties_pb2.PropertyRequest())) 39 | return dm_env_rpc_pb2.EnvironmentRequest( 40 | extension=extension_message).SerializeToString() 41 | 42 | 43 | def _pack_property_response(text_proto): 44 | extension_message = any_pb2.Any() 45 | extension_message.Pack( 46 | text_format.Parse(text_proto, properties_pb2.PropertyResponse())) 47 | return dm_env_rpc_pb2.EnvironmentResponse(extension=extension_message) 48 | 49 | # Set of expected requests and associated responses for mock connection. 
50 | _EXPECTED_REQUEST_RESPONSE_PAIRS = { 51 | _create_property_request_key('read_property { key: "foo" }'): 52 | _pack_property_response( 53 | 'read_property { value: { int32s: { array: 1 } } }'), 54 | _create_property_request_key("""write_property { 55 | key: "bar" 56 | value: { strings { array: "some_value" } } 57 | }"""): 58 | _pack_property_response('write_property {}'), 59 | _create_property_request_key('read_property { key: "bar" }'): 60 | _pack_property_response( 61 | 'read_property { value: { strings: { array: "some_value" } } }'), 62 | _create_property_request_key('list_property { key: "baz" }'): 63 | _pack_property_response("""list_property { 64 | values: { 65 | is_readable:true 66 | spec { name: "baz.fiz" dtype:UINT32 shape: 2 shape: 2 } 67 | }}"""), 68 | _create_property_request_key('list_property {}'): 69 | _pack_property_response("""list_property { 70 | values: { is_readable:true spec { name: "foo" dtype:INT32 } 71 | description: "This is a documented integer" } 72 | values: { is_readable:true 73 | is_writable:true 74 | spec { name: "bar" dtype:STRING } } 75 | values: { is_listable:true spec { name: "baz" } } 76 | }"""), 77 | _create_property_request_key('read_property { key: "bad_property" }'): 78 | dm_env_rpc_pb2.EnvironmentResponse( 79 | error=status_pb2.Status(message='invalid property request.')), 80 | _create_property_request_key('read_property { key: "invalid_key" }'): 81 | dm_env_rpc_pb2.EnvironmentResponse( 82 | error=status_pb2.Status( 83 | code=code_pb2.NOT_FOUND, message='Invalid key.')), 84 | _create_property_request_key("""write_property { 85 | key: "argument_test" 86 | value: { strings: { array: "invalid" } } }"""): 87 | dm_env_rpc_pb2.EnvironmentResponse( 88 | error=status_pb2.Status( 89 | code=code_pb2.INVALID_ARGUMENT, message='Invalid argument.')), 90 | _create_property_request_key('read_property { key: "permission_key" }'): 91 | dm_env_rpc_pb2.EnvironmentResponse( 92 | error=status_pb2.Status( 93 | code=code_pb2.PERMISSION_DENIED, message='No permission.')) 94 | } 95 | 96 | 97 | @contextlib.contextmanager 98 | def _create_mock_connection(): 99 | """Helper to create mock dm_env_rpc connection.""" 100 | with mock.patch.object(dm_env_rpc_connection, 101 | 'dm_env_rpc_pb2_grpc') as mock_grpc: 102 | 103 | def _process(request_iterator, **kwargs): 104 | del kwargs 105 | for request in request_iterator: 106 | yield _EXPECTED_REQUEST_RESPONSE_PAIRS[request.SerializeToString()] 107 | 108 | mock_stub_class = mock.MagicMock() 109 | mock_stub_class.Process = _process 110 | mock_grpc.EnvironmentStub.return_value = mock_stub_class 111 | yield dm_env_rpc_connection.Connection(mock.MagicMock()) 112 | 113 | 114 | class PropertiesTest(absltest.TestCase): 115 | 116 | def test_read_property(self): 117 | with _create_mock_connection() as connection: 118 | extension = properties.PropertiesExtension(connection) 119 | self.assertEqual(1, extension['foo']) 120 | 121 | def test_write_property(self): 122 | with _create_mock_connection() as connection: 123 | extension = properties.PropertiesExtension(connection) 124 | extension['bar'] = 'some_value' 125 | self.assertEqual('some_value', extension['bar']) 126 | 127 | def test_list_property(self): 128 | with _create_mock_connection() as connection: 129 | extension = properties.PropertiesExtension(connection) 130 | property_specs = extension.specs('baz') 131 | self.assertLen(property_specs, 1) 132 | 133 | property_spec = property_specs['baz.fiz'] 134 | self.assertTrue(property_spec.readable) 135 | 
self.assertFalse(property_spec.writable) 136 | self.assertFalse(property_spec.listable) 137 | self.assertEqual( 138 | specs.Array(shape=(2, 2), dtype=np.uint32), property_spec.spec) 139 | 140 | def test_root_list_property(self): 141 | with _create_mock_connection() as connection: 142 | extension = properties.PropertiesExtension(connection) 143 | property_specs = extension.specs() 144 | self.assertLen(property_specs, 3) 145 | self.assertTrue(property_specs['foo'].readable) 146 | self.assertTrue(property_specs['bar'].readable) 147 | self.assertTrue(property_specs['bar'].writable) 148 | self.assertTrue(property_specs['baz'].listable) 149 | 150 | def test_invalid_spec_request_on_listable_property(self): 151 | with _create_mock_connection() as connection: 152 | extension = properties.PropertiesExtension(connection) 153 | property_specs = extension.specs() 154 | self.assertTrue(property_specs['baz'].listable) 155 | self.assertIsNone(property_specs['baz'].spec) 156 | 157 | def test_invalid_request(self): 158 | with _create_mock_connection() as connection: 159 | extension = properties.PropertiesExtension(connection) 160 | with self.assertRaisesRegex(error.DmEnvRpcError, 161 | 'invalid property request.'): 162 | _ = extension['bad_property'] 163 | 164 | def test_invalid_key_raises_key_error(self): 165 | with _create_mock_connection() as connection: 166 | extension = properties.PropertiesExtension(connection) 167 | with self.assertRaises(KeyError): 168 | _ = extension['invalid_key'] 169 | 170 | def test_invalid_argument_raises_value_error(self): 171 | with _create_mock_connection() as connection: 172 | extension = properties.PropertiesExtension(connection) 173 | with self.assertRaises(ValueError): 174 | extension['argument_test'] = 'invalid' 175 | 176 | def test_permission_denied_raises_permission_error(self): 177 | with _create_mock_connection() as connection: 178 | extension = properties.PropertiesExtension(connection) 179 | with self.assertRaises(PermissionError): 180 | _ = extension['permission_key'] 181 | 182 | def test_property_description(self): 183 | with _create_mock_connection() as connection: 184 | extension = properties.PropertiesExtension(connection) 185 | property_specs = extension.specs() 186 | self.assertEqual('This is a documented integer', 187 | property_specs['foo'].description) 188 | 189 | def test_property_print(self): 190 | with _create_mock_connection() as connection: 191 | extension = properties.PropertiesExtension(connection) 192 | property_specs = extension.specs() 193 | self.assertRegex( 194 | str(property_specs['foo']), 195 | (r'PropertySpec\(key=foo, readable=True, writable=False, ' 196 | r'listable=False, spec=.*, ' 197 | r'description=This is a documented integer\)')) 198 | 199 | 200 | if __name__ == '__main__': 201 | absltest.main() 202 | -------------------------------------------------------------------------------- /dm_env_rpc/v1/message_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Helper functions used to process dm_env_rpc request / response messages. 16 | """ 17 | 18 | import typing 19 | from typing import Iterable, NamedTuple, Type, Union 20 | 21 | import immutabledict 22 | 23 | from google.protobuf import any_pb2 24 | from google.protobuf import message 25 | from dm_env_rpc.v1 import dm_env_rpc_pb2 26 | from dm_env_rpc.v1 import error 27 | 28 | 29 | _MESSAGE_TYPE_TO_FIELD = immutabledict.immutabledict({ 30 | field.message_type.name: field.name 31 | for field in dm_env_rpc_pb2.EnvironmentRequest.DESCRIPTOR.fields 32 | }) 33 | 34 | # An unpacked extension request (anything). 35 | # As any proto message that is not a native request is accepted, this definition 36 | # is overly broad - use with care. 37 | DmEnvRpcExtensionMessage = message.Message 38 | 39 | # A packed extension request. 40 | # Wraps a DmEnvRpcExtensionMessage. 41 | DmEnvRpcPackedExtensionMessage = any_pb2.Any 42 | 43 | # A native request RPC. 44 | DmEnvRpcNativeRequest = Union[ 45 | dm_env_rpc_pb2.CreateWorldRequest, 46 | dm_env_rpc_pb2.JoinWorldRequest, 47 | dm_env_rpc_pb2.StepRequest, 48 | dm_env_rpc_pb2.ResetRequest, 49 | dm_env_rpc_pb2.ResetWorldRequest, 50 | dm_env_rpc_pb2.LeaveWorldRequest, 51 | dm_env_rpc_pb2.DestroyWorldRequest, 52 | ] 53 | 54 | # A native response RPC. 55 | DmEnvRpcNativeResponse = Union[ 56 | dm_env_rpc_pb2.CreateWorldResponse, 57 | dm_env_rpc_pb2.JoinWorldResponse, 58 | dm_env_rpc_pb2.StepResponse, 59 | dm_env_rpc_pb2.ResetResponse, 60 | dm_env_rpc_pb2.ResetWorldResponse, 61 | dm_env_rpc_pb2.LeaveWorldResponse, 62 | dm_env_rpc_pb2.DestroyWorldResponse, 63 | ] 64 | 65 | # A native request RPC, or an extension message, wrapped in Any. 66 | DmEnvRpcRequest = Union[DmEnvRpcNativeRequest, DmEnvRpcPackedExtensionMessage] 67 | # A native response RPC, or an extension message, wrapped in Any. 68 | DmEnvRpcResponse = Union[DmEnvRpcNativeResponse, DmEnvRpcPackedExtensionMessage] 69 | 70 | 71 | def pack_rpc_request( 72 | request: Union[DmEnvRpcRequest, DmEnvRpcExtensionMessage], 73 | ) -> DmEnvRpcRequest: 74 | """Returns a DmEnvRpcRequest that is suitable to send over the wire. 75 | 76 | Arguments: 77 | request: The request to pack. 78 | 79 | Returns: 80 | Native request - returned as-is. 81 | Packed extension (any_pb2.Any) - returned as-is. 82 | Everything else - assumed to be an extension; wrapped in any_pb2.Any. 83 | """ 84 | if isinstance(request, typing.get_args(DmEnvRpcNativeRequest)): 85 | return typing.cast(DmEnvRpcNativeRequest, request) 86 | else: 87 | return _pack_message(request) 88 | 89 | 90 | def unpack_rpc_request( 91 | request: Union[DmEnvRpcRequest, DmEnvRpcExtensionMessage], 92 | *, 93 | extension_type: Union[ 94 | Type[message.Message], Iterable[Type[message.Message]] 95 | ], 96 | ) -> Union[DmEnvRpcRequest, DmEnvRpcExtensionMessage]: 97 | """Returns a DmEnvRpcRequest without a wrapper around extension messages. 98 | 99 | Arguments: 100 | request: The request to unpack. 
101 | extension_type: One or more extension protobuf classes of known extension 102 | types. 103 | 104 | Returns: 105 | Native request - returned as-is. 106 | Unpacked extension in |extension_type| - returned as-is. 107 | Packed extension (any_pb2.Any) in |extension_type| - returned unpacked 108 | 109 | Raises: 110 | ValueError: The message is packed (any_pb2.Any), but not in |extension_type| 111 | or the message type is not a native request or known extension. 112 | """ 113 | if isinstance(request, typing.get_args(DmEnvRpcNativeRequest)): 114 | return request 115 | else: 116 | return _unpack_message( 117 | request, 118 | extension_type=extension_type) 119 | 120 | 121 | def pack_rpc_response( 122 | response: Union[DmEnvRpcResponse, DmEnvRpcExtensionMessage], 123 | ) -> DmEnvRpcResponse: 124 | """Returns a DmEnvRpcResponse that is suitable to send over the wire. 125 | 126 | Arguments: 127 | response: The response to pack. 128 | 129 | Returns: 130 | Native response - returned as-is. 131 | Packed extension (any_pb2.Any) - returned as-is. 132 | Everything else - assumed to be an extension; wrapped in any_pb2.Any. 133 | """ 134 | if isinstance(response, typing.get_args(DmEnvRpcNativeResponse)): 135 | return typing.cast(DmEnvRpcNativeResponse, response) 136 | else: 137 | return _pack_message(response) 138 | 139 | 140 | def unpack_rpc_response( 141 | response: Union[DmEnvRpcResponse, DmEnvRpcExtensionMessage], 142 | *, 143 | extension_type: Union[ 144 | Type[message.Message], Iterable[Type[message.Message]] 145 | ], 146 | ) -> Union[DmEnvRpcRequest, DmEnvRpcExtensionMessage]: 147 | """Returns a DmEnvRpcResponse without a wrapper around extension messages. 148 | 149 | Arguments: 150 | response: The response to unpack. 151 | extension_type: One or more extension protobuf classes of known extension 152 | types. 153 | 154 | Returns: 155 | Native response - returned as-is. 156 | Unpacked extension in |extension_type| - returned as-is. 157 | Packed extension (any_pb2.Any) in |extension_type| - returned unpacked 158 | 159 | Raises: 160 | ValueError: The message is packed (any_pb2.Any), but not in |extension_type| 161 | or the message type is not a native request or known extension. 162 | """ 163 | if isinstance(response, typing.get_args(DmEnvRpcNativeResponse)): 164 | return response 165 | else: 166 | return _unpack_message(response, extension_type=extension_type) 167 | 168 | 169 | class EnvironmentRequestAndFieldName(NamedTuple): 170 | """EnvironmentRequest and field name used when packing.""" 171 | environment_request: dm_env_rpc_pb2.EnvironmentRequest 172 | field_name: str 173 | 174 | 175 | def pack_environment_request( 176 | request: DmEnvRpcRequest) -> EnvironmentRequestAndFieldName: 177 | """Constructs an EnvironmentRequest containing a request message. 178 | 179 | Args: 180 | request: An instance of a dm_env_rpc Request type, such as 181 | CreateWorldRequest. 182 | 183 | Returns: 184 | A tuple of (environment_request, field_name) where: 185 | environment_request: dm_env_rpc.v1.EnvironmentRequest containing the input 186 | request message. 187 | field_name: Name of the environment request field holding the input 188 | request message. 
189 | """ 190 | field_name = _MESSAGE_TYPE_TO_FIELD[type(request).__name__] 191 | environment_request = dm_env_rpc_pb2.EnvironmentRequest() 192 | getattr(environment_request, field_name).CopyFrom(request) 193 | return EnvironmentRequestAndFieldName(environment_request, field_name) 194 | 195 | 196 | def unpack_environment_response( 197 | environment_response: dm_env_rpc_pb2.EnvironmentResponse, 198 | expected_field_name: str) -> DmEnvRpcResponse: 199 | """Extracts the response message contained within an EnvironmentResponse. 200 | 201 | Args: 202 | environment_response: An instance of dm_env_rpc.v1.EnvironmentResponse. 203 | expected_field_name: Name of the environment response field expected to be 204 | holding the dm_env_rpc response message. 205 | 206 | Returns: 207 | dm_env_rpc response message wrapped in the input environment response. 208 | 209 | Raises: 210 | DmEnvRpcError: The dm_env_rpc EnvironmentResponse contains an error. 211 | ValueError: The dm_env_rpc response message contained in the 212 | EnvironmentResponse is held in a different field from the one expected. 213 | """ 214 | response_field_name = environment_response.WhichOneof('payload') 215 | if response_field_name == 'error': 216 | raise error.DmEnvRpcError(environment_response.error) 217 | elif response_field_name == expected_field_name: 218 | return getattr(environment_response, expected_field_name) 219 | else: 220 | raise ValueError('Unexpected response message! expected: ' 221 | f'{expected_field_name}, actual: {response_field_name}') 222 | 223 | 224 | def _pack_message(msg) -> any_pb2.Any: 225 | """Helper to pack message into an Any proto.""" 226 | if isinstance(msg, any_pb2.Any): 227 | return msg 228 | 229 | # Assume the message is an extension. 230 | packed = any_pb2.Any() 231 | packed.Pack(msg) 232 | return packed 233 | 234 | 235 | def _unpack_message( 236 | msg: message.Message, 237 | *, 238 | extension_type: Union[ 239 | Type[message.Message], Iterable[Type[message.Message]] 240 | ], 241 | ): 242 | """Helper to unpack a message from set of possible extensions. 243 | 244 | Args: 245 | msg: The message to process. 246 | extension_type: Type or type(s) used to match extension messages. The first 247 | matching type is used. 248 | 249 | Returns: 250 | An upacked extension message with type within |extension_type|. 251 | 252 | Raises: 253 | TypeError: Raised if a return type could not be determined. 254 | """ 255 | if isinstance(extension_type, type): 256 | extension_type = (extension_type,) 257 | else: 258 | extension_type = tuple(extension_type) 259 | 260 | if isinstance(msg, extension_type): 261 | return msg 262 | 263 | if isinstance(msg, any_pb2.Any): 264 | matching_type = next( 265 | (e for e in extension_type if msg.Is(e.DESCRIPTOR)), None) 266 | 267 | if not matching_type: 268 | raise ValueError( 269 | 'Extension type could not be found to unpack message: ' 270 | f'{type(msg).__name__}.\n' 271 | f'Known Types:\n' + '\n'.join(f'- {e}' for e in extension_type)) 272 | 273 | unpacked = matching_type() 274 | msg.Unpack(unpacked) 275 | return unpacked 276 | 277 | raise ValueError( 278 | f'Cannot unpack extension message with type: {type(msg).__name__}.' 279 | ) 280 | -------------------------------------------------------------------------------- /dm_env_rpc/v1/message_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Tests for dm_env_rpc/message_utils.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | 20 | from google.rpc import status_pb2 21 | from dm_env_rpc.v1 import dm_env_rpc_pb2 22 | from dm_env_rpc.v1 import error 23 | from dm_env_rpc.v1 import message_utils 24 | from dm_env_rpc.v1 import tensor_utils 25 | from google.protobuf import any_pb2 26 | 27 | 28 | _CREATE_WORLD_REQUEST = dm_env_rpc_pb2.CreateWorldRequest( 29 | settings={'foo': tensor_utils.pack_tensor('bar')}) 30 | _CREATE_WORLD_RESPONSE = dm_env_rpc_pb2.CreateWorldResponse(world_name='qux') 31 | _CREATE_WORLD_ENVIRONMENT_RESPONSE = dm_env_rpc_pb2.EnvironmentResponse( 32 | create_world=_CREATE_WORLD_RESPONSE) 33 | 34 | # Anything that's not a "native" rpc message is an extension. 35 | _EXTENSION_TYPE = error.status_pb2.Status 36 | _EXTENSION_MULTI_TYPE = (dm_env_rpc_pb2.EnvironmentResponse, _EXTENSION_TYPE) 37 | _EXTENSION_MESSAGE = _EXTENSION_TYPE(code=42) 38 | _PACKED_EXTENSION_MESSAGE = any_pb2.Any() 39 | _PACKED_EXTENSION_MESSAGE.Pack(_EXTENSION_MESSAGE) 40 | 41 | 42 | class MessageUtilsTests(parameterized.TestCase): 43 | 44 | def test_pack_create_world_request(self): 45 | environment_request, field_name = message_utils.pack_environment_request( 46 | _CREATE_WORLD_REQUEST) 47 | self.assertEqual(field_name, 'create_world') 48 | self.assertEqual(environment_request.WhichOneof('payload'), 49 | 'create_world') 50 | self.assertEqual(environment_request.create_world, 51 | _CREATE_WORLD_REQUEST) 52 | 53 | def test_unpack_create_world_response(self): 54 | response = message_utils.unpack_environment_response( 55 | _CREATE_WORLD_ENVIRONMENT_RESPONSE, 'create_world') 56 | self.assertEqual(response, _CREATE_WORLD_RESPONSE) 57 | 58 | def test_unpack_error_response(self): 59 | with self.assertRaisesRegex(error.DmEnvRpcError, 'A test error.'): 60 | message_utils.unpack_environment_response( 61 | dm_env_rpc_pb2.EnvironmentResponse( 62 | error=status_pb2.Status(message='A test error.')), 63 | 'create_world') 64 | 65 | def test_unpack_incorrect_response(self): 66 | with self.assertRaisesWithLiteralMatch( 67 | ValueError, 68 | 'Unexpected response message! 
expected: create_world, actual: ' 69 | 'leave_world'): 70 | message_utils.unpack_environment_response( 71 | dm_env_rpc_pb2.EnvironmentResponse( 72 | leave_world=dm_env_rpc_pb2.LeaveWorldResponse()), 73 | 'create_world') 74 | 75 | @parameterized.named_parameters( 76 | dict( 77 | testcase_name='create_world_request_passes_through', 78 | message=_CREATE_WORLD_REQUEST, 79 | expected=_CREATE_WORLD_REQUEST, 80 | ), 81 | dict( 82 | testcase_name='packed_extension_message_passes_through', 83 | message=_PACKED_EXTENSION_MESSAGE, 84 | expected=_PACKED_EXTENSION_MESSAGE, 85 | ), 86 | dict( 87 | testcase_name='extension_message_is_packed', 88 | message=_EXTENSION_MESSAGE, 89 | expected=_PACKED_EXTENSION_MESSAGE, 90 | ), 91 | ) 92 | def test_pack_request(self, message, expected): 93 | self.assertEqual(message_utils.pack_rpc_request(message), expected) 94 | 95 | @parameterized.named_parameters( 96 | dict( 97 | testcase_name='create_world_response', 98 | message=_CREATE_WORLD_RESPONSE, 99 | expected=_CREATE_WORLD_RESPONSE, 100 | ), 101 | dict( 102 | testcase_name='extension_message', 103 | message=_EXTENSION_MESSAGE, 104 | expected=_PACKED_EXTENSION_MESSAGE, 105 | ), 106 | dict( 107 | testcase_name='packed_extension_message', 108 | message=_PACKED_EXTENSION_MESSAGE, 109 | expected=_PACKED_EXTENSION_MESSAGE, 110 | ), 111 | ) 112 | def test_pack_response(self, message, expected): 113 | self.assertEqual(message_utils.pack_rpc_response(message), expected) 114 | 115 | @parameterized.named_parameters( 116 | dict( 117 | testcase_name='create_world_request_passes_through', 118 | message=_CREATE_WORLD_REQUEST, 119 | extensions=_EXTENSION_TYPE, 120 | expected=_CREATE_WORLD_REQUEST, 121 | ), 122 | dict( 123 | testcase_name='extension_message_passes_through', 124 | message=_EXTENSION_MESSAGE, 125 | extensions=_EXTENSION_TYPE, 126 | expected=_EXTENSION_MESSAGE, 127 | ), 128 | dict( 129 | testcase_name='extension_message_passes_through_with_multi', 130 | message=_EXTENSION_MESSAGE, 131 | extensions=_EXTENSION_MULTI_TYPE, 132 | expected=_EXTENSION_MESSAGE, 133 | ), 134 | dict( 135 | testcase_name='packed_extension_message_is_unpacked_with_multi', 136 | message=_PACKED_EXTENSION_MESSAGE, 137 | extensions=_EXTENSION_MULTI_TYPE, 138 | expected=_EXTENSION_MESSAGE, 139 | ), 140 | ) 141 | def test_unpack_request(self, message, extensions, expected): 142 | self.assertEqual( 143 | message_utils.unpack_rpc_request(message, extension_type=extensions), 144 | expected) 145 | 146 | @parameterized.named_parameters( 147 | dict( 148 | testcase_name='create_world_response_passes_through', 149 | message=_CREATE_WORLD_RESPONSE, 150 | extensions=_EXTENSION_TYPE, 151 | expected=_CREATE_WORLD_RESPONSE, 152 | ), 153 | dict( 154 | testcase_name='extension_message_passes_through', 155 | message=_EXTENSION_MESSAGE, 156 | extensions=_EXTENSION_TYPE, 157 | expected=_EXTENSION_MESSAGE, 158 | ), 159 | dict( 160 | testcase_name='packed_extension_message_is_unpacked', 161 | message=_PACKED_EXTENSION_MESSAGE, 162 | extensions=_EXTENSION_TYPE, 163 | expected=_EXTENSION_MESSAGE, 164 | ), 165 | dict( 166 | testcase_name='extension_message_passes_through_with_multi', 167 | message=_EXTENSION_MESSAGE, 168 | extensions=_EXTENSION_MULTI_TYPE, 169 | expected=_EXTENSION_MESSAGE, 170 | ), 171 | dict( 172 | testcase_name='packed_extension_message_is_unpacked_with_multi', 173 | message=_PACKED_EXTENSION_MESSAGE, 174 | extensions=_EXTENSION_MULTI_TYPE, 175 | expected=_EXTENSION_MESSAGE, 176 | ), 177 | ) 178 | def test_unpack_response(self, message, extensions, 
expected): 179 | self.assertEqual( 180 | message_utils.unpack_rpc_response(message, extension_type=extensions), 181 | expected) 182 | 183 | if __name__ == '__main__': 184 | absltest.main() 185 | -------------------------------------------------------------------------------- /dm_env_rpc/v1/spec_manager.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Manager class to manage the dm_env_rpc UID system.""" 16 | 17 | from typing import Any, Collection, Mapping, MutableMapping 18 | import numpy as np 19 | 20 | from dm_env_rpc.v1 import dm_env_rpc_pb2 21 | from dm_env_rpc.v1 import tensor_utils 22 | 23 | 24 | def _assert_shapes_match(tensor: dm_env_rpc_pb2.Tensor, 25 | dm_env_rpc_spec: dm_env_rpc_pb2.TensorSpec): 26 | """Raises ValueError if shape of tensor and spec don't match.""" 27 | tensor_shape = np.asarray(tensor.shape) 28 | spec_shape = np.asarray(dm_env_rpc_spec.shape) 29 | 30 | # Check all elements are equal, or the spec element is -1 (variable length). 31 | if tensor_shape.size != spec_shape.size or not np.all( 32 | (tensor_shape == spec_shape) | (spec_shape < 0)): 33 | raise ValueError( 34 | 'Received dm_env_rpc tensor {} with shape {} but spec has shape {}.' 35 | .format(dm_env_rpc_spec.name, tensor_shape, spec_shape)) 36 | 37 | 38 | class SpecManager(object): 39 | """Manages transitions between Python dicts and dm_env_rpc UIDs. 40 | 41 | To make sending and receiving actions and observations easier for dm_env_rpc, 42 | this helps manage the transition between UID-keyed dicts mapping to dm_env_rpc 43 | tensors and string-keyed dicts mapping to scalars, lists, or NumPy arrays. 44 | """ 45 | 46 | def __init__(self, specs: Mapping[int, dm_env_rpc_pb2.TensorSpec]): 47 | """Builds the SpecManager from the given dm_env_rpc specs. 48 | 49 | Args: 50 | specs: A dict mapping UIDs to dm_env_rpc TensorSpecs, similar to what is 51 | stored in `actions` and `observations` in ActionObservationSpecs. 52 | """ 53 | for spec in specs.values(): 54 | if np.count_nonzero(np.asarray(spec.shape) < 0) > 1: 55 | raise ValueError( 56 | f'"{spec.name}" shape has > 1 variable length dimension. 
' 57 | f'Spec:\n{spec}') 58 | 59 | self._name_to_uid = {spec.name: uid for uid, spec in specs.items()} 60 | self._uid_to_name = {uid: spec.name for uid, spec in specs.items()} 61 | if len(self._name_to_uid) != len(self._uid_to_name): 62 | raise ValueError('There are duplicate names in the tensor specs.') 63 | 64 | self._specs_by_uid = specs 65 | self._specs_by_name = {spec.name: spec for spec in specs.values()} 66 | 67 | @property 68 | def specs_by_uid(self) -> Mapping[int, dm_env_rpc_pb2.TensorSpec]: 69 | return self._specs_by_uid 70 | 71 | @property 72 | def specs_by_name(self) -> Mapping[str, dm_env_rpc_pb2.TensorSpec]: 73 | return self._specs_by_name 74 | 75 | def name_to_uid(self, name: str) -> int: 76 | """Returns the UID for the given name.""" 77 | return self._name_to_uid[name] 78 | 79 | def uid_to_name(self, uid: int) -> str: 80 | """Returns the name for the given UID.""" 81 | return self._uid_to_name[uid] 82 | 83 | def name_to_spec(self, name: str) -> dm_env_rpc_pb2.TensorSpec: 84 | """Returns the dm_env_rpc TensorSpec named `name`.""" 85 | return self._specs_by_name[name] 86 | 87 | def uid_to_spec(self, uid: int) -> dm_env_rpc_pb2.TensorSpec: 88 | """Returns the dm_env_rpc TensorSpec for the given UID.""" 89 | return self._specs_by_uid[uid] 90 | 91 | def names(self) -> Collection[str]: 92 | """Returns the spec names in no particular order.""" 93 | return self._name_to_uid.keys() 94 | 95 | def uids(self) -> Collection[int]: 96 | """Returns the spec UIDs in no particular order.""" 97 | return self._uid_to_name.keys() 98 | 99 | def unpack( 100 | self, dm_env_rpc_tensors: Mapping[int, dm_env_rpc_pb2.Tensor] 101 | ) -> MutableMapping[str, Any]: 102 | """Unpacks a dm_env_rpc uid-to-tensor map to a name-keyed Python dict. 103 | 104 | Args: 105 | dm_env_rpc_tensors: A dict mapping UIDs to dm_env_rpc tensor protos. 106 | 107 | Returns: 108 | A dict mapping names to scalars and arrays. 109 | """ 110 | unpacked = {} 111 | for uid, tensor in dm_env_rpc_tensors.items(): 112 | name = self._uid_to_name[uid] 113 | dm_env_rpc_spec = self.name_to_spec(name) 114 | _assert_shapes_match(tensor, dm_env_rpc_spec) 115 | tensor_dtype = tensor_utils.get_tensor_type(tensor) 116 | spec_dtype = tensor_utils.data_type_to_np_type(dm_env_rpc_spec.dtype) 117 | if tensor_dtype != spec_dtype: 118 | raise ValueError( 119 | 'Received dm_env_rpc tensor {} with dtype {} but spec has dtype {}.' 120 | .format(name, tensor_dtype, spec_dtype)) 121 | tensor_unpacked = tensor_utils.unpack_tensor(tensor) 122 | unpacked[name] = tensor_unpacked 123 | return unpacked 124 | 125 | def pack( 126 | self, 127 | tensors: Mapping[str, Any]) -> MutableMapping[int, dm_env_rpc_pb2.Tensor]: 128 | """Packs a name-keyed Python dict to a dm_env_rpc uid-to-tensor map. 129 | 130 | Args: 131 | tensors: A dict mapping string names to scalars and arrays. 132 | 133 | Returns: 134 | A dict mapping UIDs to dm_env_rpc tensor protos. 135 | """ 136 | packed = {} 137 | for name, value in tensors.items(): 138 | dm_env_rpc_spec = self.name_to_spec(name) 139 | tensor = tensor_utils.pack_tensor(value, dtype=dm_env_rpc_spec.dtype) 140 | _assert_shapes_match(tensor, dm_env_rpc_spec) 141 | packed[self.name_to_uid(name)] = tensor 142 | return packed 143 | -------------------------------------------------------------------------------- /dm_env_rpc/v1/spec_manager_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Tests for SpecManager class.""" 16 | 17 | from absl.testing import absltest 18 | import numpy as np 19 | 20 | from dm_env_rpc.v1 import dm_env_rpc_pb2 21 | from dm_env_rpc.v1 import spec_manager 22 | from dm_env_rpc.v1 import tensor_utils 23 | 24 | _EXAMPLE_SPECS = { 25 | 54: 26 | dm_env_rpc_pb2.TensorSpec( 27 | name='fuzz', shape=[2], dtype=dm_env_rpc_pb2.DataType.FLOAT), 28 | 55: 29 | dm_env_rpc_pb2.TensorSpec( 30 | name='foo', shape=[3], dtype=dm_env_rpc_pb2.DataType.INT32), 31 | } 32 | 33 | 34 | class SpecManagerTests(absltest.TestCase): 35 | 36 | def setUp(self): 37 | super(SpecManagerTests, self).setUp() 38 | self._spec_manager = spec_manager.SpecManager(_EXAMPLE_SPECS) 39 | 40 | def test_specs_by_uid(self): 41 | self.assertDictEqual(_EXAMPLE_SPECS, self._spec_manager.specs_by_uid) 42 | 43 | def test_specs_by_name(self): 44 | expected = {'foo': _EXAMPLE_SPECS[55], 'fuzz': _EXAMPLE_SPECS[54]} 45 | self.assertDictEqual(expected, self._spec_manager.specs_by_name) 46 | 47 | def test_name_to_uid(self): 48 | self.assertEqual(55, self._spec_manager.name_to_uid('foo')) 49 | 50 | def test_name_to_uid_no_such_name(self): 51 | with self.assertRaisesRegex(KeyError, 'bar'): 52 | self._spec_manager.name_to_uid('bar') 53 | 54 | def test_name_to_spec(self): 55 | spec = self._spec_manager.name_to_spec('foo') 56 | self.assertEqual([3], spec.shape) 57 | 58 | def test_name_to_spec_no_such_name(self): 59 | with self.assertRaisesRegex(KeyError, 'bar'): 60 | self._spec_manager.name_to_spec('bar') 61 | 62 | def test_uid_to_name(self): 63 | self.assertEqual('foo', self._spec_manager.uid_to_name(55)) 64 | 65 | def test_uid_to_name_no_such_uid(self): 66 | with self.assertRaisesRegex(KeyError, '56'): 67 | self._spec_manager.uid_to_name(56) 68 | 69 | def test_names(self): 70 | self.assertEqual(set(['foo', 'fuzz']), self._spec_manager.names()) 71 | 72 | def test_uids(self): 73 | self.assertEqual(set([54, 55]), self._spec_manager.uids()) 74 | 75 | def test_uid_to_spec(self): 76 | spec = self._spec_manager.uid_to_spec(54) 77 | self.assertEqual([2], spec.shape) 78 | 79 | def test_pack(self): 80 | packed = self._spec_manager.pack({'fuzz': [1.0, 2.0], 'foo': [3, 4, 5]}) 81 | expected = { 82 | 54: tensor_utils.pack_tensor([1.0, 2.0], dtype=np.float32), 83 | 55: tensor_utils.pack_tensor([3, 4, 5], dtype=np.int32), 84 | } 85 | self.assertDictEqual(expected, packed) 86 | 87 | def test_partial_pack(self): 88 | packed = self._spec_manager.pack({ 89 | 'fuzz': [1.0, 2.0], 90 | }) 91 | expected = { 92 | 54: tensor_utils.pack_tensor([1.0, 2.0], dtype=np.float32), 93 | } 94 | self.assertDictEqual(expected, packed) 95 | 96 | def test_pack_unknown_key_raises_error(self): 97 | with self.assertRaisesRegex(KeyError, 'buzz'): 98 | self._spec_manager.pack({'buzz': 'hello'}) 99 | 100 | def test_pack_wrong_shape_raises_error(self): 101 | with 
self.assertRaisesRegex(ValueError, 'shape'): 102 | self._spec_manager.pack({'foo': [1, 2]}) 103 | 104 | def test_pack_wrong_dtype_raises_error(self): 105 | with self.assertRaises(ValueError): 106 | self._spec_manager.pack({'foo': 'hello'}) 107 | 108 | def test_pack_cast_int_to_float_is_ok(self): 109 | packed = self._spec_manager.pack({'fuzz': [1, 2]}) 110 | self.assertEqual([1.0, 2.0], packed[54].floats.array) 111 | 112 | def test_unpack(self): 113 | unpacked = self._spec_manager.unpack({ 114 | 54: tensor_utils.pack_tensor([1.0, 2.0], dtype=np.float32), 115 | 55: tensor_utils.pack_tensor([3, 4, 5], dtype=np.int32), 116 | }) 117 | self.assertLen(unpacked, 2) 118 | np.testing.assert_array_equal(np.asarray([1.0, 2.0]), unpacked['fuzz']) 119 | np.testing.assert_array_equal(np.asarray([3, 4, 5]), unpacked['foo']) 120 | 121 | def test_partial_unpack(self): 122 | unpacked = self._spec_manager.unpack({ 123 | 54: tensor_utils.pack_tensor([1.0, 2.0], dtype=np.float32), 124 | }) 125 | self.assertLen(unpacked, 1) 126 | np.testing.assert_array_equal(np.asarray([1.0, 2.0]), unpacked['fuzz']) 127 | 128 | def test_unpack_unknown_uid_raises_error(self): 129 | with self.assertRaisesRegex(KeyError, '53'): 130 | self._spec_manager.unpack({53: tensor_utils.pack_tensor('foo')}) 131 | 132 | def test_unpack_wrong_shape_raises_error(self): 133 | with self.assertRaisesRegex(ValueError, 'shape'): 134 | self._spec_manager.unpack({55: tensor_utils.pack_tensor([1, 2])}) 135 | 136 | def test_unpack_wrong_type_raises_error(self): 137 | with self.assertRaisesRegex(ValueError, 'dtype'): 138 | self._spec_manager.unpack( 139 | {55: tensor_utils.pack_tensor([1, 2, 3], dtype=np.float32)}) 140 | 141 | 142 | class SpecManagerVariableSpecShapeTests(absltest.TestCase): 143 | 144 | def setUp(self): 145 | super(SpecManagerVariableSpecShapeTests, self).setUp() 146 | specs = { 147 | 101: 148 | dm_env_rpc_pb2.TensorSpec( 149 | name='foo', shape=[1, -1], dtype=dm_env_rpc_pb2.DataType.INT32), 150 | } 151 | self._spec_manager = spec_manager.SpecManager(specs) 152 | 153 | def test_variable_spec_shape(self): 154 | packed = self._spec_manager.pack({'foo': [[1, 2, 3, 4]]}) 155 | expected = { 156 | 101: tensor_utils.pack_tensor([[1, 2, 3, 4]], dtype=np.int32), 157 | } 158 | self.assertDictEqual(expected, packed) 159 | 160 | def test_invalid_variable_shape(self): 161 | with self.assertRaisesRegex(ValueError, 'shape'): 162 | self._spec_manager.pack({'foo': np.ones((1, 2, 3), dtype=np.int32)}) 163 | 164 | def test_empty_variable_shape(self): 165 | manager = spec_manager.SpecManager({ 166 | 1: 167 | dm_env_rpc_pb2.TensorSpec( 168 | name='bar', shape=[], dtype=dm_env_rpc_pb2.DataType.INT32) 169 | }) 170 | with self.assertRaisesRegex(ValueError, 'shape'): 171 | manager.pack({'bar': np.ones((1), dtype=np.int32)}) 172 | 173 | def test_invalid_variable_spec_shape(self): 174 | with self.assertRaisesRegex(ValueError, 'shape has > 1 variable length'): 175 | spec_manager.SpecManager({ 176 | 1: 177 | dm_env_rpc_pb2.TensorSpec( 178 | name='bar', 179 | shape=[1, -1, -1], 180 | dtype=dm_env_rpc_pb2.DataType.INT32) 181 | }) 182 | 183 | 184 | class SpecManagerConstructorTests(absltest.TestCase): 185 | 186 | def test_duplicate_names_raise_error(self): 187 | specs = { 188 | 54: 189 | dm_env_rpc_pb2.TensorSpec( 190 | name='fuzz', shape=[3], dtype=dm_env_rpc_pb2.DataType.FLOAT), 191 | 55: 192 | dm_env_rpc_pb2.TensorSpec( 193 | name='fuzz', shape=[2], dtype=dm_env_rpc_pb2.DataType.FLOAT), 194 | } 195 | with self.assertRaisesRegex(ValueError, 'duplicate name'): 196 | 
spec_manager.SpecManager(specs) 197 | 198 | 199 | if __name__ == '__main__': 200 | absltest.main() 201 | -------------------------------------------------------------------------------- /dm_env_rpc/v1/tensor_spec_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Helper Python utilities for managing dm_env_rpc TensorSpecs.""" 16 | 17 | import dataclasses 18 | from typing import Generic, TypeVar, Union 19 | import numpy as np 20 | 21 | from dm_env_rpc.v1 import dm_env_rpc_pb2 22 | from dm_env_rpc.v1 import tensor_utils 23 | 24 | 25 | _BOUNDS_CANNOT_BE_SAFELY_CAST_TO_DTYPE = ( 26 | 'TensorSpec "{name}"\'s bounds [{minimum}, {maximum}] contain value(s) ' 27 | 'that cannot be safely cast to dtype {dtype}.') 28 | 29 | T = TypeVar('T') 30 | 31 | 32 | @dataclasses.dataclass 33 | class Bounds(Generic[T]): 34 | min: T 35 | max: T 36 | 37 | 38 | def _is_valid_bound(array_or_scalar, np_dtype: np.dtype) -> bool: 39 | """Returns whether array_or_scalar is a valid bound for the given type.""" 40 | array_or_scalar = np.asarray(array_or_scalar) 41 | if np.issubdtype(array_or_scalar.dtype, np.integer): 42 | iinfo = np.iinfo(np_dtype) 43 | for value in np.asarray(array_or_scalar).flat: 44 | if int(value) < iinfo.min or int(value) > iinfo.max: 45 | return False 46 | elif np.issubdtype(array_or_scalar.dtype, np.floating): 47 | finfo = np.finfo(np_dtype) 48 | for value in np.asarray(array_or_scalar).flat: 49 | if (float(value) < finfo.min and float(value) != -np.inf) or ( 50 | float(value) > finfo.max and float(value) != np.inf 51 | ): 52 | return False 53 | return True 54 | 55 | 56 | def np_range_info(np_type: ...) -> Union[np.finfo, np.iinfo]: 57 | """Returns type info for `np_type`, which includes min and max attributes.""" 58 | np_type = np.dtype(np_type) 59 | if issubclass(np_type.type, np.floating): 60 | return np.finfo(np_type) 61 | elif issubclass(np_type.type, np.integer): 62 | return np.iinfo(np_type) 63 | else: 64 | raise ValueError('{} does not have range info.'.format(np_type)) 65 | 66 | 67 | def _get_value(min_max_value, shape, default): 68 | """Helper function that returns the min/max bounds for a Value message. 69 | 70 | Args: 71 | min_max_value: Value protobuf message to get value from. 72 | shape: Optional dimensions to unpack payload data to. 73 | default: Value to use if min_max_value is not set. 74 | 75 | Returns: 76 | A scalar if `shape` is empty or None, or an unpacked NumPy array of either 77 | the unpacked value or provided default. 
78 | 79 | """ 80 | which = min_max_value.WhichOneof('payload') 81 | value = which and getattr(min_max_value, which) 82 | 83 | if value is None: 84 | min_max = default 85 | else: 86 | unpacked = tensor_utils.unpack_proto(min_max_value) 87 | min_max = tensor_utils.reshape_array( 88 | unpacked, shape) if len(unpacked) > 1 else unpacked[0] 89 | 90 | if (shape is not None 91 | and np.any(np.array(shape) < 0) 92 | and np.asarray(min_max).size > 1): 93 | raise ValueError( 94 | "TensorSpec's with variable length shapes can only have scalar ranges. " 95 | 'Shape: {}, value: {}'.format(shape, min_max)) 96 | return min_max 97 | 98 | 99 | def bounds(tensor_spec: dm_env_rpc_pb2.TensorSpec) -> Bounds: 100 | """Gets the inclusive bounds of `tensor_spec`. 101 | 102 | Args: 103 | tensor_spec: An instance of a dm_env_rpc TensorSpec proto. 104 | 105 | Returns: 106 | A named tuple (`min`, `max`) of inclusive bounds. 107 | 108 | Raises: 109 | ValueError: `tensor_spec` does not have a numeric dtype, or the type of its 110 | `min` or `max` does not match its dtype, or the the bounds are invalid in 111 | some way. 112 | """ 113 | np_type = tensor_utils.data_type_to_np_type(tensor_spec.dtype) 114 | tensor_spec_type = dm_env_rpc_pb2.DataType.Name(tensor_spec.dtype).lower() 115 | if not issubclass(np_type.type, np.number): 116 | raise ValueError('TensorSpec "{}" has non-numeric type {}.' 117 | .format(tensor_spec.name, tensor_spec_type)) 118 | 119 | # Check min payload type matches the tensor type. 120 | min_which = tensor_spec.min.WhichOneof('payload') 121 | if min_which and not min_which.startswith(tensor_spec_type): 122 | raise ValueError('TensorSpec "{}" has dtype {} but min type {}.'.format( 123 | tensor_spec.name, tensor_spec_type, min_which)) 124 | 125 | # Check max payload type matches the tensor type. 126 | max_which = tensor_spec.max.WhichOneof('payload') 127 | if max_which and not max_which.startswith(tensor_spec_type): 128 | raise ValueError('TensorSpec "{}" has dtype {} but max type {}.'.format( 129 | tensor_spec.name, tensor_spec_type, max_which)) 130 | 131 | dtype_bounds = np_range_info(np_type) 132 | min_bound = _get_value(tensor_spec.min, tensor_spec.shape, dtype_bounds.min) 133 | max_bound = _get_value(tensor_spec.max, tensor_spec.shape, dtype_bounds.max) 134 | 135 | if not _is_valid_bound(min_bound, np_type) or not _is_valid_bound( 136 | max_bound, np_type 137 | ): 138 | raise ValueError( 139 | _BOUNDS_CANNOT_BE_SAFELY_CAST_TO_DTYPE.format( 140 | name=tensor_spec.name, 141 | minimum=min_bound, 142 | maximum=max_bound, 143 | dtype=tensor_spec_type)) 144 | 145 | if np.any(max_bound < min_bound): 146 | raise ValueError('TensorSpec "{}" has min {} larger than max {}.'.format( 147 | tensor_spec.name, min_bound, max_bound)) 148 | 149 | return Bounds(np_type.type(min_bound), np_type.type(max_bound)) 150 | 151 | 152 | def set_bounds(tensor_spec: dm_env_rpc_pb2.TensorSpec, minimum, maximum): 153 | """Modifies `tensor_spec` to have its inclusive bounds set. 154 | 155 | Packs `minimum` in to `tensor_spec.min` and `maximum` in to `tensor_spec.max`. 156 | 157 | Args: 158 | tensor_spec: An instance of a dm_env_rpc TensorSpec proto. It should 159 | already have its `name`, `dtype` and `shape` attributes set. 160 | minimum: The minimum value that elements in the described tensor can obtain. 161 | A scalar, iterable of scalars, or None. If None, `min` will be cleared on 162 | `tensor_spec`. 163 | maximum: The maximum value that elements in the described tensor can obtain. 
164 | A scalar, iterable of scalars, or None. If None, `max` will be cleared on 165 | `tensor_spec`. 166 | """ 167 | np_type = tensor_utils.data_type_to_np_type(tensor_spec.dtype) 168 | if not issubclass(np_type.type, np.number): 169 | raise ValueError(f'TensorSpec has non-numeric type "{np_type}".') 170 | 171 | has_min = minimum is not None 172 | has_max = maximum is not None 173 | 174 | if ((has_min and not _is_valid_bound(minimum, np_type)) or 175 | (has_max and not _is_valid_bound(maximum, np_type))): 176 | raise ValueError( 177 | _BOUNDS_CANNOT_BE_SAFELY_CAST_TO_DTYPE.format( 178 | name=tensor_spec.name, 179 | minimum=minimum, 180 | maximum=maximum, 181 | dtype=dm_env_rpc_pb2.DataType.Name(tensor_spec.dtype))) 182 | 183 | if has_min: 184 | minimum = np.asarray(minimum, dtype=np_type) 185 | if minimum.size != 1 and minimum.shape != tuple(tensor_spec.shape): 186 | raise ValueError( 187 | f'minimum has shape {minimum.shape}, which is incompatible with ' 188 | f"tensor_spec {tensor_spec.name}'s shape {tensor_spec.shape}.") 189 | 190 | if has_max: 191 | maximum = np.asarray(maximum, dtype=np_type) 192 | if maximum.size != 1 and maximum.shape != tuple(tensor_spec.shape): 193 | raise ValueError( 194 | f'maximum has shape {maximum.shape}, which is incompatible with ' 195 | f"tensor_spec {tensor_spec.name}'s shape {tensor_spec.shape}.") 196 | 197 | if (has_min and has_max and np.any(maximum < minimum)): 198 | raise ValueError('TensorSpec "{}" has min {} larger than max {}.'.format( 199 | tensor_spec.name, minimum, maximum)) 200 | 201 | packer = tensor_utils.get_packer(np_type) 202 | if has_min: 203 | packer.pack(tensor_spec.min, minimum) 204 | else: 205 | tensor_spec.ClearField('min') 206 | 207 | if has_max: 208 | packer.pack(tensor_spec.max, maximum) 209 | else: 210 | tensor_spec.ClearField('max') 211 | -------------------------------------------------------------------------------- /dm_env_rpc/v1/tensor_utils_benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """Micro-benchmark for tensor_utils.pack_tensor.""" 16 | 17 | import abc 18 | import timeit 19 | 20 | from absl import app 21 | from absl import flags 22 | import numpy as np 23 | 24 | from dm_env_rpc.v1 import tensor_utils 25 | 26 | flags.DEFINE_integer('repeats', 10000, 27 | 'Number of times each benchmark will run.') 28 | FLAGS = flags.FLAGS 29 | 30 | 31 | class _AbstractBenchmark(metaclass=abc.ABCMeta): 32 | """Base class for benchmarks using timeit.""" 33 | 34 | def run(self): 35 | time = timeit.timeit(self.statement, setup=self.setup, number=FLAGS.repeats) 36 | print(f'{self.name} -- overall: {time:0.2f} s, ' 37 | f'per call: {time/FLAGS.repeats:0.1e} s') 38 | 39 | def setup(self): 40 | pass 41 | 42 | @abc.abstractmethod 43 | def statement(self): 44 | pass 45 | 46 | @abc.abstractproperty 47 | def name(self): 48 | pass 49 | 50 | 51 | class _PackBenchmark(_AbstractBenchmark): 52 | """Benchmark for packing a numpy array to a Tensor proto.""" 53 | 54 | def __init__(self, dtype, shape): 55 | self._name = f'pack {np.dtype(dtype).name}' 56 | self._dtype = dtype 57 | self._shape = shape 58 | 59 | @property 60 | def name(self): 61 | return self._name 62 | 63 | def setup(self): 64 | # Use non-zero values in case there's something special about zero arrays. 65 | self._unpacked = np.arange( 66 | np.prod(self._shape), dtype=self._dtype).reshape(self._shape) 67 | 68 | def statement(self): 69 | self._unpacked.flat[0] += 1 # prevent caching of the result 70 | tensor_utils.pack_tensor(self._unpacked, self._dtype) 71 | 72 | 73 | class _UnpackBenchmark(_AbstractBenchmark): 74 | """Benchmark for unpacking a Tensor proto to a numpy array.""" 75 | 76 | def __init__(self, dtype, shape): 77 | self._name = f'unpack {np.dtype(dtype).name}' 78 | self._shape = shape 79 | self._dtype = dtype 80 | 81 | @property 82 | def name(self): 83 | return self._name 84 | 85 | def setup(self): 86 | # Use non-zero values in case there's something special about zero arrays. 87 | tensor = np.arange( 88 | np.prod(self._shape), dtype=self._dtype).reshape(self._shape) 89 | self._packed = tensor_utils.pack_tensor(tensor, self._dtype) 90 | 91 | def statement(self): 92 | tensor_utils.unpack_tensor(self._packed) 93 | 94 | 95 | def main(argv): 96 | if len(argv) > 1: 97 | raise app.UsageError('Too many command-line arguments.') 98 | # Pick `shape` such that number of bytes is consistent between benchmarks. 
99 | benchmarks = ( 100 | _PackBenchmark(dtype=np.uint8, shape=(128, 128, 3)), 101 | _PackBenchmark(dtype=np.int32, shape=(64, 64, 3)), 102 | _PackBenchmark(dtype=np.int64, shape=(32, 64, 3)), 103 | _PackBenchmark(dtype=np.uint32, shape=(64, 64, 3)), 104 | _PackBenchmark(dtype=np.uint64, shape=(32, 64, 3)), 105 | _PackBenchmark(dtype=np.float32, shape=(64, 64, 3)), 106 | _PackBenchmark(dtype=np.float64, shape=(32, 64, 3)), 107 | _UnpackBenchmark(dtype=np.uint8, shape=(128, 128, 3)), 108 | _UnpackBenchmark(dtype=np.int32, shape=(64, 64, 3)), 109 | _UnpackBenchmark(dtype=np.int64, shape=(32, 64, 3)), 110 | _UnpackBenchmark(dtype=np.uint32, shape=(64, 64, 3)), 111 | _UnpackBenchmark(dtype=np.uint64, shape=(32, 64, 3)), 112 | _UnpackBenchmark(dtype=np.float32, shape=(64, 64, 3)), 113 | _UnpackBenchmark(dtype=np.float64, shape=(32, 64, 3)), 114 | ) 115 | for benchmark in benchmarks: 116 | benchmark.run() 117 | 118 | 119 | if __name__ == '__main__': 120 | app.run(main) 121 | -------------------------------------------------------------------------------- /docs/v1/2x2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/dm_env_rpc/dea4326e7e4e15b3abd787a5907fa12b3c6940e2/docs/v1/2x2.png -------------------------------------------------------------------------------- /docs/v1/2x3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/dm_env_rpc/dea4326e7e4e15b3abd787a5907fa12b3c6940e2/docs/v1/2x3.png -------------------------------------------------------------------------------- /docs/v1/appendix.md: -------------------------------------------------------------------------------- 1 | ## Advice for implementers 2 | 3 | A server implementation does not need to implement the full protocol to be 4 | usable. When first starting, just supporting `JoinWorld`, `LeaveWorld` and 5 | `Step` is sufficient to provide a minimum `dm_env_rpc` environment (though 6 | clients using the provided `DmEnvAdaptor` will need access to `Reset` as well). 7 | The world can be considered already created just by virtue of the server being 8 | available and listening on a port. Unsupported requests should just return an 9 | error. After this base, the following can be added: 10 | 11 | * `CreateWorld` and `DestroyWorld` can provide a way for clients to give 12 | settings and manage the lifetime of worlds. 13 | * Supporting sequences, where one sequence reaches its natural conclusion and 14 | the next begins on the next step. 15 | * `Reset` and `ResetWorld` can provide mechanisms to manage transitions 16 | between sequences and to update settings. 17 | * Read/Write/List properties for client code to query or set state outside of 18 | the normal observe/act loop. 19 | 20 | A client implementation likely has to support the full range of features and 21 | data types that the server it wants to interact with supports. However, it's 22 | unnecessary to support data types or features that the server does not 23 | implement. If the target server does not provide tensors with more than one 24 | dimension it probably isn't worth the effort to support higher dimensional 25 | tensors, for instance. 26 | 27 | For Python clients, some python utility code is provided. Specifically: 28 | 29 | * connection.py - A utility class for managing the client-side gRPC connection 30 | and the error handling from server responses. 
31 | * spec_manager.py - A utility class for managing the UID-to-name mapping, so 32 | the rest of the client code can use the more human readable tensor names. 33 | * tensor_utils.py - A few utility functions for packing and unpacking NumPy 34 | arrays to `dm_env_rpc` tensors. 35 | 36 | For other languages similar utility functions will likely be needed. It is 37 | especially recommended that client code have something similar to `SpecManager` 38 | that turns UIDs into human readable text as soon as possible. Strings are both 39 | less likely to be used wrong and more easily debugged. 40 | 41 | ### Common state transition errors 42 | 43 | The [Environment connection state machine](overview.md#states) cannot transition 44 | directly between `INTERRUPTED` and `TERMINATED` states. Servers may 45 | inadvertently cause an `INTERRUPTED` --> `TERMINATED` transition if: 46 | 47 | 1. The server environment starts in a completed state (and therefore finishes 48 | in 0 steps), or 49 | 1. The server environment setup fails but does not throw an exception until the 50 | first step. 51 | 52 | To ensure server state transitions are legal, please check the validity of the 53 | server environment when it is initialized. 54 | 55 | ## Reward functions 56 | 57 | [Reward function design is difficult](https://www.alexirpan.com/2018/02/14/rl-hard.html#reward-function-design-is-difficult), 58 | and may be the first thing a client will tweak. 59 | 60 | For instance, a simple arcade game could expose its score function as `reward`. 61 | However, the score function might not actually be a good measure of playing 62 | ability if there's a series of actions which increases score without advancing 63 | towards the actual desired goal. Alternatively, games might have an obvious 64 | reward signal only at the end (whether the agent won or not), which might be too 65 | sparse for a reinforcement learning agent. Constructing a robust reward function 66 | is an area of active research and it's perfectly reasonable for an environment 67 | to abdicate the responsibility of forming one. 68 | 69 | For instance, constructing a reward function for the game of chess is actually 70 | rather tricky. A simple reward at the end if an agent wins is going to make 71 | learning difficult, as agents won't get feedback during a game if they make 72 | mistakes or play well. Beginner chess players often use 73 | [material values](https://en.wikipedia.org/wiki/Chess_piece_relative_value) to 74 | evaluate a given position to decide who is winning, with queens being worth more 75 | than rooks, etc. This might seem like a prime candidate for a reward function, 76 | however there are 77 | [well known shortcomings](https://en.wikipedia.org/wiki/Chess_piece_relative_value#Shortcomings_of_piece_valuation_systems) 78 | to this simple system. For instance, using this as a reward function can blind 79 | an agent to a move which loses material but produces a checkmate. 80 | 81 | If a client does construct a custom reward function it may want access to data 82 | which normally would be considered hidden and unavailable. Exposing this 83 | information to clients may feel like cheating, however getting an AI agent to 84 | start learning at all is often half the battle. To this end servers should 85 | expose as much relevant information through the environment as they can to give 86 | clients room to experiment. 
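For instance, a client might shape its own reward from an exposed observation along these lines (a sketch only; the `pieces` observation name, its encoding and the piece values are illustrative assumptions, not part of any particular server's interface):

```python
# Hypothetical client-side reward: material balance computed from an exposed
# "pieces" observation, assumed to be a sequence of (piece_type, owner) pairs.
_PIECE_VALUES = {'pawn': 1, 'knight': 3, 'bishop': 3, 'rook': 5, 'queen': 9}


def material_reward(observations):
  """Returns the agent's material value minus the opponent's."""
  balance = 0
  for piece_type, owner in observations['pieces']:
    value = _PIECE_VALUES.get(piece_type, 0)
    balance += value if owner == 'self' else -value
  return balance
```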
Weaning an agent off this information down the line 87 | may be possible; just be sure to document which observables are considered 88 | hidden information so a client can strip them in the future. 89 | 90 | For client implementations, if a given server does not provide a reward or 91 | discount observable or they aren't suitable you can build your own from other 92 | observables. For instance, the provided `DmEnvAdaptor` has `reward` and 93 | `discount` functions which can be overridden in derived classes. 94 | 95 | ## Nested observations example 96 | 97 | Retrofitting an existing object-oriented codebase to be a `dm_env_rpc` server 98 | can be difficult, as the data is likely arranged in a tree-like structure of 99 | heterogeneous nodes. For instance, a game might have pickup and character 100 | classes which inherit from the same base: 101 | 102 | ``` 103 | Entity { 104 | Vector3 position; 105 | } 106 | 107 | HealthPickup : Entity { 108 | float HealthRecoveryAmount; 109 | Vector4 IconColor; 110 | } 111 | 112 | Character : Entity { 113 | string Name; 114 | } 115 | ``` 116 | 117 | It's not immediately clear how to turn this data into tensors, which can be a 118 | difficult issue for a server implementation. 119 | 120 | First, though, not all data needs to be sent to a joined connection. For a 121 | health pickup, the health recovery amount might be hidden information that 122 | agents don't generally have access to. Likewise, the IconColor might only matter 123 | for a user interface, which an agent might not have. 124 | 125 | After filtering unnecessary data the remainder may fit in a 126 | "[structure of arrays](https://en.wikipedia.org/wiki/AoS_and_SoA)" format. 127 | 128 | In our example case, we can structure the remaining data as a few different 129 | tensors. The specs for them might look like: 130 | 131 | ``` 132 | TensorSpec { 133 | name = "HealthPickups.Position", 134 | shape = [3, -1] 135 | dtype = float 136 | }, 137 | 138 | TensorSpec { 139 | name = "Characters.Position" 140 | shape = [3, -1] 141 | dtype = float 142 | }, 143 | 144 | TensorSpec { 145 | name = "Characters.Name", 146 | shape = [-1] 147 | dtype = string 148 | } 149 | ``` 150 | 151 | The use of a period "." character allows clients to reconstruct a hierarchy, 152 | which can be useful in grouping tensors into logical categories. 153 | 154 | If a structure of arrays format is infeasible, custom protocol messages can be 155 | implemented as an alternative. This allows a more natural data representation, 156 | but requires the client to also compile the protocol buffers and makes discovery 157 | of metadata, such as range for numeric types, more difficult. For our toy 158 | example, we could form the data in to this intermediate protocol buffer format: 159 | 160 | ```protobuf 161 | message Vector3 { 162 | float x = 1; 163 | float y = 2; 164 | float z = 3; 165 | } 166 | 167 | message HealthPickup { 168 | Vector3 position = 1; 169 | } 170 | 171 | message Character { 172 | Vector3 position = 1; 173 | string name = 2; 174 | } 175 | ``` 176 | 177 | The `TensorSpec` would then look like: 178 | 179 | ``` 180 | TensorSpec { 181 | name = "Characters", 182 | shape = [-1], 183 | dtype = proto 184 | }, 185 | 186 | TensorSpec { 187 | name = "HealthPickups", 188 | shape = [-1], 189 | dtype = proto 190 | }, 191 | ``` 192 | 193 | ### Rendering 194 | 195 | Often renders are desirable observations for agents, whether human or machine. 
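For illustration, a server could advertise such a render as an ordinary observation, with a spec sketched here using the repository's generated bindings (the observation name, resolution and `uint8` encoding are illustrative choices, not something the protocol mandates):

```python
from dm_env_rpc.v1 import dm_env_rpc_pb2

# Hypothetical 96x72 interleaved RGB render exposed as an observation.
rgb_spec = dm_env_rpc_pb2.TensorSpec(
    name='RGB_INTERLEAVED',  # Illustrative name only.
    shape=[72, 96, 3],       # Rows, columns, colour channels.
    dtype=dm_env_rpc_pb2.DataType.UINT8)
```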
196 | In past frameworks, such as 197 | [Atari](https://deepmind.com/research/publications/playing-atari-deep-reinforcement-learning) 198 | and 199 | [DMLab](https://deepmind.com/blog/article/impala-scalable-distributed-deeprl-dmlab-30), 200 | environments have been responsible for rendering. Servers implementing this 201 | protocol will likely (but not necessarily) continue this tradition. 202 | 203 | For performance reasons rendering should not be done unless requested by an 204 | agent, and then only the individual renders requested. Servers can batch similar 205 | renders if desirable. 206 | 207 | The exact image format for a render is at the discretion of the server 208 | implementation, but should be documented. Reasonable choices are to return an 209 | interleaved RGB buffer, similar to a GPU render texture, or a standard image 210 | format such as PNG or JPEG. 211 | 212 | Human agents generally prefer high resolution images, such as the high 213 | definition 1920x1080 format. Reinforcement agents, however, often use much 214 | smaller resolutions, such as 96x72, for performance of both environment and 215 | agent. It is up to the server implementation how or if it wants to expose 216 | resolution as a setting and what resolutions it wants to support. It could be 217 | hard coded or specified in world settings or join settings. Server implementers 218 | are advised to consider performance impacts of whatever choice they make, as 219 | rendering time often dominates the runtime of many environments. 220 | 221 | ### Per-element Min/Max Ranges {#per-element-ranges} 222 | 223 | When defining actions for an environment, it's desirable to provide users with 224 | the range of valid values that can be sent. `TensorSpec`s support this by 225 | providing min and max range fields, which can either be defined as a single 226 | scalar (see [broadcastable](overview.md#broadcastable)) for all elements in a 227 | `Tensor`, or each element can have their own range defined. 228 | 229 | For the vast majority of actions, we strongly encourage a single range only 230 | should be defined. When users start defining per-element ranges, this is 231 | typically indicative of the action being a combination of several, distinct 232 | actions. By slicing the action into separate actions, there's also the benefit 233 | of providing a more descriptive name, as well as making it easier for clients to 234 | easily compose their own action spaces. 235 | 236 | There are some examples where it may be desirable to not split the action. For 237 | example, imagine an `ABSOLUTE_MOUSE_XY` action. This would have two elements 238 | (one for the `X` and `Y` positions respectively), but would likely have 239 | different ranges based on the width and height of the screen. Splitting the 240 | action would mean agents could send only the `X` or `Y` action without the 241 | other, which might not be valid. 242 | 243 | ### Documentation 244 | 245 | Actions and observations can be self-documenting, in the sense that their 246 | `TensorSpec`s provide information about their type, shape and bounds. However, 247 | the exact meaning of actions and observations are not discoverable through the 248 | protocol. In addition, `CreateWorldRequest` and `JoinWorldRequest` settings lack 249 | even a spec. 
Therefore server implementers should be sure to document allowed 250 | settings for create and join world, along with their types and shapes, the 251 | consequences of any actions and the meaning of any observations, as well as any 252 | properties. 253 | -------------------------------------------------------------------------------- /docs/v1/extensions/index.md: -------------------------------------------------------------------------------- 1 | # Extensions 2 | 3 | `dm_env_rpc` Extensions provide a way for users to send custom messages over the 4 | same bi-directional gRPC stream as the standard messages. This can be useful for 5 | debugging or manipulating the simulation in a way that isn't appropriate through 6 | the typical `dm_env_rpc` protocol. 7 | 8 | Extensions must complement the existing `dm_env_rpc` protocol, so that agents 9 | are able to send conventional `dm_env_rpc` messages, interspersed with extension 10 | requests. 11 | 12 | To send an extension message, you must pack your custom message into an `Any` 13 | proto, assigning to the `extension` field in `EnvironmentRequest`. For example: 14 | 15 | ```python 16 | from dm_env_rpc.v1 import connection as dm_env_rpc_connection 17 | from google.protobuf import any_pb2 18 | from google.protobuf import struct_pb2 19 | 20 | def send_custom_message(connection: dm_env_rpc_connection.ConnectionType): 21 | """Send/receive custom Struct message through a dm_env_rpc connection.""" 22 | my_message = struct_pb2.Struct( 23 | fields={'foo': struct_pb2.Value(string_value='bar')}) 24 | 25 | packed_message = any_pb2.Any() 26 | packed_message.Pack(my_message) 27 | 28 | response = struct_pb2.Struct() 29 | connection.send(packed_message).Unpack(response) 30 | ``` 31 | 32 | If the simulation can respond to such a message, it must send a response using 33 | the corresponding `extension` field in `EnvironmentResponse`. The server must 34 | send an [error](../overview.md#errors) if it doesn't recognise the request. 35 | 36 | ## Common Extensions 37 | 38 | The following commonly used extensions are provided with `dm_env_rpc`: 39 | 40 | * [Properties](properties.md) 41 | 42 | ## Recommendations 43 | 44 | ### Should you use an extension? 45 | 46 | Although extensions can be powerful, if you expect an extension message to be 47 | sent by the client every step, consider making it a proper action or observable, 48 | even if it's intended as metadata. This better ensures that all mutable actions 49 | can be executed at well-ordered times. 50 | 51 | ### Creating one parent request/response for your extension 52 | 53 | If your extension has more than a couple of requests, consider creating a single 54 | parent request/response message that you can add/remove messages from. This 55 | simplifies the server code by only having to unpack a single request, and makes 56 | it easier to compartmentalize each extension. 
For example: 57 | 58 | ```proto 59 | message AwesomeRequest { 60 | string foo = 1; 61 | } 62 | 63 | message AwesomeResponse {} 64 | 65 | message AnotherAwesomeRequest { 66 | string bar = 1; 67 | } 68 | 69 | message AnotherAwesomeResponse {} 70 | 71 | message MyExtensionRequest { 72 | oneof payload { 73 | AwesomeRequest awesome = 1; 74 | AnotherAwesomeRequest another_awesome = 2; 75 | } 76 | } 77 | 78 | message MyExtensionResponse { 79 | oneof payload { 80 | AwesomeResponse awesome = 1; 81 | AnotherAwesomeResponse another_awesome = 2; 82 | } 83 | } 84 | ``` 85 | 86 | ## Alternatives 87 | 88 | An alternative to `dm_env_rpc` extension messages is to register a separate gRPC 89 | service. This has the benefit of being able to use gRPC's other 90 | [service methods](https://grpc.io/docs/guides/concepts/#service-definition). 91 | However, if you need messages to be sent at a particular point (e.g. after a 92 | specific `StepRequest`), synchronizing these disparate services will add 93 | additional complexity to the server. 94 | -------------------------------------------------------------------------------- /docs/v1/extensions/properties.md: -------------------------------------------------------------------------------- 1 | ## Properties 2 | 3 | Often side-channel data is useful for debugging or manipulating the simulation 4 | in some way which isn’t appropriate for an agent’s interface. The property 5 | system provides this capability, allowing both the reading, writing and 6 | discovery of properties. Properties are implemented as an extension to the core 7 | protocol (see [extensions](index.md) for more details). 8 | 9 | Properties queried before a `JoinWorldRequest` will correspond to universal 10 | properties that apply for all worlds. Properties queried after a 11 | JoinWorldRequest can add a layer of world-specific properties. 12 | 13 | For writing properties, it’s up to the server to determine if/when any 14 | modification should take place (e.g. changing the world seed might not take 15 | place until the next sequence). 16 | 17 | For reading properties, the exact timing of the observation is up to the world. 18 | It may occur at the previous step's time, or it may occur at some intermediate 19 | time. 20 | 21 | Properties can be laid out in a tree-like structure, as long as each node in the 22 | tree has a unique key, by having parent nodes be listable (the is_listable field 23 | in the Property message). 24 | 25 | Although properties can be powerful, if you expect a property to be read or 26 | written to every step by a normally functioning agent, it may be preferable to 27 | make it a proper action or observation, even if it’s intended to be metadata. 28 | For instance, a score which provides hidden information could be done as a 29 | property, but it might be preferable to do it as an observation. This ensures 30 | observations and actions all occur at well-ordered times. 31 | 32 | ### Read Property 33 | 34 | ```proto 35 | package dm_env_rpc.v1.extensions.properties; 36 | 37 | message ReadPropertyRequest { 38 | string key = 1; 39 | } 40 | 41 | message ReadPropertyResponse { 42 | dm_env_rpc.v1.Tensor value = 1; 43 | } 44 | ``` 45 | 46 | Returns the current value for the property with the provided `key`. 
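As a sketch of how a client might issue this request over the standard connection (the `properties_pb2` module name is assumed to be the compiled binding for this extension, and the property key is illustrative):

```python
from google.protobuf import any_pb2

from dm_env_rpc.v1 import tensor_utils
from dm_env_rpc.v1.extensions import properties_pb2  # Assumed binding name.


def read_property(connection, key):
  """Reads the property `key` and returns its value as a NumPy value."""
  request = any_pb2.Any()
  request.Pack(properties_pb2.ReadPropertyRequest(key=key))
  response = properties_pb2.ReadPropertyResponse()
  connection.send(request).Unpack(response)
  return tensor_utils.unpack_tensor(response.value)
```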
47 | 48 | ### Write Property 49 | 50 | ```proto 51 | package dm_env_rpc.v1.extensions.properties; 52 | 53 | message WritePropertyRequest { 54 | string key = 1; 55 | dm_env_rpc.v1.Tensor value = 2; 56 | } 57 | 58 | message WritePropertyResponse {} 59 | ``` 60 | 61 | Set the property referened by the `key` field to the value of the provided 62 | `Tensor`. 63 | 64 | ### List Property 65 | 66 | ```proto 67 | package dm_env_rpc.v1.extensions.properties; 68 | 69 | message ListPropertyRequest { 70 | // Key to list property for. Empty string is root level. 71 | string key = 1; 72 | } 73 | 74 | message ListPropertyResponse { 75 | message PropertySpec { 76 | // Required: TensorSpec name field for key value. 77 | dm_env_rpc.v1.TensorSpec spec = 1; 78 | 79 | bool is_readable = 2; 80 | bool is_writable = 3; 81 | bool is_listable = 4; 82 | } 83 | 84 | repeated PropertySpec values = 1; 85 | } 86 | ``` 87 | 88 | Returns an array of `PropertySpec` values for each property residing under the 89 | provided `key`. If the `key` is empty, properties registered at the root level 90 | are returned. If the `key` is not empty and the property is not listable an 91 | error is returned. 92 | 93 | Each `PropertySpec` must return a `TensorSpec` with its name field set to the 94 | property's key. For readable and writable properties, the type and shape of the 95 | property must also be set. Properties which are only listable must have the 96 | default value for type (`dm_env_rpc.v1.DataType.INVALID_DATA_TYPE`) and shape 97 | (an empty array). 98 | -------------------------------------------------------------------------------- /docs/v1/glossary.md: -------------------------------------------------------------------------------- 1 | ## Glossary 2 | 3 | ### Server 4 | 5 | A server is a process which implements the dm_env_rpc `Environment` gRPC 6 | service. A server can host one or more worlds and allow one or more client 7 | connections. 8 | 9 | ### Client 10 | 11 | A client connects to a server. It can request worlds to be created and join them 12 | to become agents. It can also query properties and reset worlds without being an 13 | agent. 14 | 15 | ### Agent 16 | 17 | A client which has joined a world. In a game sense this is a "player". It has 18 | the ability to send actions and receive observations. This could be a human or a 19 | reinforcement learning agent. 20 | 21 | ### World 22 | 23 | A system or simulation in which one or more agents interact. 24 | 25 | ### Environment 26 | 27 | An agent's view of a world. In limited information situations, the environment 28 | may expose only part of the world state to an agent that corresponds to 29 | information that agent is allowed by the simulation to have. Agents communicate 30 | directly with an environment, and the environment communicates with the world to 31 | synchronize with it. 32 | 33 | ### Sequence 34 | 35 | A series of discrete states, where one state is correlated to previous and 36 | subsequent states, possibly ending in a terminal state, and usually modified by 37 | agent actions. In the simplest case, playing an entire game until one player is 38 | declared the winner is one sequence. Also sometimes called an "episode" in 39 | reinforcement learning contexts. 
40 | -------------------------------------------------------------------------------- /docs/v1/index.md: -------------------------------------------------------------------------------- 1 | # `dm_env_rpc` documentation 2 | 3 | * [Protocol overview](overview.md) 4 | * [Protocol reference](reference.md) 5 | * [Extensions](extensions/index.md) 6 | * [Appendix](appendix.md) 7 | * [Glossary](glossary.md) 8 | -------------------------------------------------------------------------------- /docs/v1/overview.md: -------------------------------------------------------------------------------- 1 | ## Protocol overview 2 | 3 | `dm_env_rpc` is a protocol for [agents](glossary.md#agent) 4 | ([clients](glossary.md#client)) to communicate with 5 | [environments](glossary.md#environment) ([servers](glossary.md#server)). A 6 | server has a single remote procedural call (RPC) named `Process` for handling 7 | messages from clients. 8 | 9 | ```protobuf 10 | service Environment { 11 | // Process incoming environment requests. 12 | rpc Process(stream EnvironmentRequest) returns (stream EnvironmentResponse) {} 13 | } 14 | ``` 15 | 16 | Each call to `Process` creates a bidirectional streaming connection between a 17 | client and the server. It is up to each server to decide if it can support 18 | multiple simultaneous clients, if each client can instantiate its own 19 | [world](glossary.md#world), or if each client is expected to connect to the same 20 | underlying world. 21 | 22 | Each connection accepts a stream of `EnvironmentRequest` messages from the 23 | client. The server sends exactly one `EnvironmentResponse` for each request. The 24 | payload of the response always either corresponds to that of the request (e.g. a 25 | `StepResponse` in response to a `StepRequest`) or is an error `Status` message. 26 | 27 | ### Streaming 28 | 29 | Clients may send multiple requests without waiting for responses. The server 30 | processes these requests in the order that they are sent and returns the 31 | corresponding responses in the same order. It is the client's responsibility to 32 | ensure each request is valid when processed. 33 | 34 | Note: gRPC guarantees messages will be delivered in the order they are sent. 35 | 36 | ### States 37 | 38 | An Environment connection can be in one of two states: joined to a world or not. 39 | When not joined to a world, `StepRequest` and `ResetRequest` calls are 40 | unavailable (the server will send an error upon receiving them). 41 | 42 | A joined connection may be in a variety of sub-states (ie: RUNNING, TERMINATED, 43 | and INTERRUPTED). Agents transition between these states using `StepRequest`, 44 | `ResetRequest` and `ResetWorldRequest` calls, though the environment controls 45 | which state is transitioned to. 46 | 47 | ![State transitions](state_transitions.png) 48 | 49 | ### Tensors 50 | 51 | For the purposes of this protocol tensors are loosely based on NumPy arrays: 52 | n-dimensional arrays of data with the same data type. A tensor with "n" 53 | dimensions can be referred to as an n-tensor. A 0-tensor is just a scalar value, 54 | such as a single float. A 1-tensor can be thought of either as a single 55 | dimensional array or vector. A 2-tensor is a two dimensional array or a matrix. 56 | In principle there's no limit to the number of dimensions a tensor can have in 57 | the protocol, but in practice we rarely have more than 3 or 4 dimensions. 
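As a concrete sketch using the repository's Python helpers (`tensor_utils.py`), a small 2-tensor can be packed into a wire `Tensor` and recovered again; the flattening order this relies on is described below:

```python
import numpy as np

from dm_env_rpc.v1 import tensor_utils

matrix = np.array([[1.0, 2.0, 3.0],
                   [4.0, 5.0, 6.0]])  # A 2-tensor with shape [2, 3].
packed = tensor_utils.pack_tensor(matrix, np.float64)  # Values are flattened
                                                       # in row-major order.
unpacked = tensor_utils.unpack_tensor(packed)  # Back to a (2, 3) NumPy array.
assert np.array_equal(matrix, unpacked)
```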
58 | Tensors are not allowed to be 59 | [ragged](https://en.wikipedia.org/wiki/Jagged_array) (have rows with different 60 | numbers of elements), though they may have a 61 | [variable length](#variable-lengths) along one dimension. 62 | 63 | A tensor's shape represents the number of elements along each dimension. A 64 | 2-tensor with a shape of `[3, 4]` would be a 2 dimensional array with 3 rows and 65 | 4 columns. 66 | 67 | In order to pack these tensors in a way that can be sent over the network they 68 | have to be flattened to a one dimensional array. For multidimensional tensors 69 | it's expected that they will be packed in a row-major format. That is, the 70 | values at indices `[2, 3, 4]` and `[2, 3, 5]` are located next to each other in 71 | the flattened array. This is the default memory layout in C based languages, 72 | such as C/C++, Java, and C#, and NumPy in Python and TensorFlow, but is opposite 73 | to how column-major languages work, such as Fortran. 74 | 75 | Consult the 76 | [Row- and column-major order](https://en.wikipedia.org/wiki/Row-_and_column-major_order) 77 | article for more information. 78 | 79 | #### Variable lengths: 80 | 81 | Normally a tensor has a well defined shape. However, if one of the elements in a 82 | Tensor's shape is negative it represents a variable dimension. Either the client 83 | or the server, upon receiving a Tensor message with a Shape with a negative 84 | element, will attempt to infer the correct value for the shape based on the 85 | number of elements in the Tensor's array part and the rest of the shape. 86 | 87 | Note: even though this dimension has variable length, the tensor itself is still 88 | not ragged. The variable dimension has a definite length that can be inferred. 89 | 90 | For instance, a Tensor with shape `[2, -1]` represents a variable length 91 | 2-tensor with two rows and a variable number of columns. If this Tensor's array 92 | part contains 6 elements `[1, 2, 3, 4, 5, 6]` then the final produced 2-tensor 93 | will look like: 94 | 95 | ![2x3 matrix with first row 1, 2, 3 and second row 4, 5, 6](2x3.png) 96 | 97 | Variable length tensors are useful for situations where a given observation or 98 | action's length is unknowable from frame to frame. For instance, the number of 99 | words in a sentence or the number of stochastic events in a given time frame. 100 | 101 | Note: At most one dimension is allowed to be variable on a given tensor. 102 | 103 | Note: servers should provide actions and observations with non-variable length 104 | if possible, as it can reduce the complexity of agent implementations. 105 | 106 | #### Broadcastable 107 | 108 | If a tensor contains all elements of the same value, it is "broadcastable" and 109 | can be represented with a single value in the array part of the Tensor, even if 110 | the shape requires more elements. Either the client or server, upon receiving a 111 | broadcastable tensor, will unpack it to an appropriately sized multidimensional 112 | array with each element being set to the value from the lone element in the 113 | array part of the Tensor. 114 | 115 | For instance, a Tensor with Shape `[2, 2]` and a single element `1` in its array 116 | part will produce a 2-tensor that looks like: 117 | 118 | ![2x2 matrix of all 1s](2x2.png) 119 | 120 | #### TensorSpec 121 | 122 | A `TensorSpec` provides metadata about a tensor, such as its name, type, and 123 | expected shape. Tensor names must be unique within a given domain (action or 124 | observation) so clients can use them as keys. 
A period "." character in a Tensor 125 | name indicates a level of nesting. See 126 | [Nested actions/observations](#nested-actions-or-observations) for more details. 127 | 128 | #### Ranges 129 | 130 | Numerical tensors can have min and max ranges on the TensorSpec. These ranges 131 | are inclusive, and can be either: 132 | 133 | * Scalar: Indicates all `Tensor` elements must adhere to this `min/max` value. 134 | * N-dimensional: Must match the `TensorSpec` shape, where each `Tensor` 135 | element has a distinct `min/max` value. 136 | 137 | For a `TensorSpec` with a shape of [variable-length](#variable-lengths), only 138 | scalar ranges are supported. 139 | 140 | Whilst distinct, per-element `min/max` ranges are supported, we encourage 141 | implementers to instead provide separate actions. For more discussion, see the 142 | [per-element min/max range](appendix.md#per-element-ranges) appendix. 143 | 144 | Note: Range is not enforced by the protocol. Servers and clients should be 145 | careful to validate any incoming or outgoing tensors to make sure they are in 146 | range. Servers should return an error for any out of range tensors from the 147 | client. 148 | 149 | ### UIDs 150 | 151 | Unique Identifications (UIDs) are 64 bit numbers used as keys for data that is 152 | sent over the wire frequently, specifically observations and actions. This 153 | reduces the amount of data compared to string keys, since a key is needed for 154 | each action and observation every step. For data that is not intended to be 155 | referenced frequently, such as create and join settings, string keys are used 156 | for clarity. 157 | 158 | For more information on UIDs see [JoinWorld specs](reference.md#specs) 159 | 160 | ### Errors 161 | 162 | If an `EnvironmentRequest` fails for any reason, the payload will contain an 163 | error `Status` instead of the normal response message. It’s up to the server 164 | implementation to decide what error codes and messages to use. For fatal errors, 165 | the server can close the stream after sending the error. For recoverable errors 166 | the server can treat the failed request as a no-op and clients can retry. 167 | 168 | The client cannot send errors to the server. If the client has an error that it 169 | can’t recover from, it should just close the connection (gracefully, if 170 | possible). 171 | 172 | Since a server can *only* send messages in response to a given 173 | `EnvironmentRequest`, the errors should ideally be focused on problems from a 174 | specific client request. More general issues or warnings from a given server 175 | implementation should be logged through a separate mechanism. 176 | 177 | If a server implementation needs to report an error, it should send as much 178 | detail about the nature of the problem as possible and any likely remedies. A 179 | client may have difficulties debugging a server, perhaps because the server is 180 | running on a different machine, so the server should send enough information to 181 | properly diagnose the problem. Any additional relevant information about the 182 | error that would normally be logged by the server should also be included in the 183 | error sent to the client. 184 | 185 | ### Nested actions or observations 186 | 187 | Nested actions or observations are not directly supported by the protocol, 188 | however there are two ways they can be handled: 189 | 190 | 1. Flattening the hierarchy, using a period "." character to indicate a level 191 | of nesting. 
Servers can flatten the nested structure to push through the 192 | wire and the client can reconstruct the nested structure on its side. 193 | Servers need to be careful not to use "." as part of the tensor's name, 194 | except to indicate a level of nesting. 195 | 196 | 2. Defining a custom proto message type, or using the proto common type Struct, 197 | and setting it as the payload in a Tensor message’s array field. 198 | 199 | Flattening the hierarchy is easier for clients to consume, but can involve a 200 | great deal of work on the server. A custom proto message is more flexible but 201 | means every client needs to compile the custom protobuf for their desired 202 | language. 203 | 204 | Nested data structures occur commonly with object-oriented codebases, such as 205 | from an existing game, and flattening them can be difficult. For an in-depth 206 | discussion see the 207 | [nested observations example](appendix.md#nested-observations-example). 208 | 209 | ### Reward and discount 210 | 211 | `dm_env_rpc` does not provide explicit channels for reward or discount (common 212 | reinforcement learning signals). For servers where there's a sensible reward or 213 | discount already available they can be provided through a `reward` or `discount` 214 | observation respectively. For `dm_env`, the provided `DmEnvAdaptor` will 215 | properly route the reward and discount for client code if available. 216 | 217 | A server may choose not to provide reward and discount observations, however. 218 | See [reward functions](appendix.md#reward-functions) for a discussion on the 219 | pitfalls of reward design. 220 | 221 | ### Multiagent support 222 | 223 | Some servers may support multiple joined connections on the same world. These 224 | multiagent servers are responsible for coordinating how agents interact through 225 | the world, and ensuring each connection has a separate environment for each 226 | agent. 227 | -------------------------------------------------------------------------------- /docs/v1/single_agent_connect_and_step.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/dm_env_rpc/dea4326e7e4e15b3abd787a5907fa12b3c6940e2/docs/v1/single_agent_connect_and_step.png -------------------------------------------------------------------------------- /docs/v1/single_agent_sequence_transitions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/dm_env_rpc/dea4326e7e4e15b3abd787a5907fa12b3c6940e2/docs/v1/single_agent_sequence_transitions.png -------------------------------------------------------------------------------- /docs/v1/single_agent_world_destruction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/dm_env_rpc/dea4326e7e4e15b3abd787a5907fa12b3c6940e2/docs/v1/single_agent_world_destruction.png -------------------------------------------------------------------------------- /docs/v1/state_transitions.graphviz: -------------------------------------------------------------------------------- 1 | # Source to generate the state_transitions.png diagram using Graphviz. 
2 | 3 | digraph { 4 | node [shape=record]; 5 | compound=true; 6 | "" [shape=diamond] 7 | "" -> TERMINATED [label = "JoinWorld "] 8 | subgraph cluster_world_states { 9 | rank="same"; 10 | TERMINATED -> RUNNING [label = "Step", color=grey, dir=both]; 11 | RUNNING -> INTERRUPTED [label = "Step ", color=grey, dir=both]; 12 | RUNNING -> INTERRUPTED [label = "Reset", color=grey]; 13 | RUNNING -> RUNNING [label = " Step", color=grey]; 14 | TERMINATED -> TERMINATED [label = "Reset", color=grey]; 15 | INTERRUPTED -> INTERRUPTED [label = "Reset", color=grey]; 16 | } 17 | 18 | TERMINATED -> "" [label = "LeaveWorld", ltail=cluster_world_states]; 19 | } 20 | -------------------------------------------------------------------------------- /docs/v1/state_transitions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/dm_env_rpc/dea4326e7e4e15b3abd787a5907fa12b3c6940e2/docs/v1/state_transitions.png -------------------------------------------------------------------------------- /examples/catch_human_agent.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Example Catch human agent.""" 16 | 17 | from concurrent import futures 18 | 19 | from absl import app 20 | import grpc 21 | import pygame 22 | 23 | import catch_environment 24 | from dm_env_rpc.v1 import connection as dm_env_rpc_connection 25 | from dm_env_rpc.v1 import dm_env_adaptor 26 | from dm_env_rpc.v1 import dm_env_rpc_pb2 27 | from dm_env_rpc.v1 import dm_env_rpc_pb2_grpc 28 | 29 | _FRAMES_PER_SEC = 3 30 | _FRAME_DELAY_MS = int(1000.0 // _FRAMES_PER_SEC) 31 | 32 | _BLACK = (0, 0, 0) 33 | _WHITE = (255, 255, 255) 34 | 35 | _ACTION_LEFT = -1 36 | _ACTION_NOTHING = 0 37 | _ACTION_RIGHT = 1 38 | 39 | _ACTION_PADDLE = 'paddle' 40 | _OBSERVATION_REWARD = 'reward' 41 | _OBSERVATION_BOARD = 'board' 42 | 43 | 44 | def _draw_row(row_str, row_index, standard_font, window_surface): 45 | text = standard_font.render(row_str, True, _WHITE) 46 | text_rect = text.get_rect() 47 | text_rect.left = 50 48 | text_rect.top = 30 + (row_index * 30) 49 | window_surface.blit(text, text_rect) 50 | 51 | 52 | def _render_window(board, window_surface, reward): 53 | """Render the game onto the window surface.""" 54 | 55 | standard_font = pygame.font.SysFont('Courier', 24) 56 | instructions_font = pygame.font.SysFont('Courier', 16) 57 | 58 | num_rows = board.shape[0] 59 | num_cols = board.shape[1] 60 | 61 | window_surface.fill(_BLACK) 62 | 63 | # Draw board. 64 | header = '* ' * (num_cols + 2) 65 | _draw_row(header, 0, standard_font, window_surface) 66 | for board_index in range(num_rows): 67 | row = board[board_index] 68 | row_str = '* ' 69 | for c in row: 70 | row_str += 'x ' if c == 1. 
else ' ' 71 | row_str += '* ' 72 | _draw_row(row_str, board_index + 1, standard_font, window_surface) 73 | _draw_row(header, num_rows + 1, standard_font, window_surface) 74 | 75 | # Draw footer. 76 | reward_str = 'Reward: {}'.format(reward) 77 | _draw_row(reward_str, num_rows + 3, standard_font, window_surface) 78 | instructions = ('Instructions: Left/Right arrow keys to move paddle, Escape ' 79 | 'to exit.') 80 | _draw_row(instructions, num_rows + 5, instructions_font, window_surface) 81 | 82 | 83 | def _start_server(): 84 | """Starts the Catch gRPC server.""" 85 | server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) 86 | servicer = catch_environment.CatchEnvironmentService() 87 | dm_env_rpc_pb2_grpc.add_EnvironmentServicer_to_server(servicer, server) 88 | 89 | port = server.add_secure_port('localhost:0', grpc.local_server_credentials()) 90 | server.start() 91 | return server, port 92 | 93 | 94 | def main(_): 95 | pygame.init() 96 | 97 | server, port = _start_server() 98 | 99 | with dm_env_rpc_connection.create_secure_channel_and_connect( 100 | f'localhost:{port}') as connection: 101 | env, world_name = dm_env_adaptor.create_and_join_world( 102 | connection, create_world_settings={}, join_world_settings={}) 103 | with env: 104 | window_surface = pygame.display.set_mode((800, 600), 0, 32) 105 | pygame.display.set_caption('Catch Human Agent') 106 | 107 | keep_running = True 108 | while keep_running: 109 | requested_action = _ACTION_NOTHING 110 | 111 | for event in pygame.event.get(): 112 | if event.type == pygame.QUIT: 113 | keep_running = False 114 | break 115 | elif event.type == pygame.KEYDOWN: 116 | if event.key == pygame.K_LEFT: 117 | requested_action = _ACTION_LEFT 118 | elif event.key == pygame.K_RIGHT: 119 | requested_action = _ACTION_RIGHT 120 | elif event.key == pygame.K_ESCAPE: 121 | keep_running = False 122 | break 123 | 124 | actions = {_ACTION_PADDLE: requested_action} 125 | timestep = env.step(actions) 126 | 127 | board = timestep.observation[_OBSERVATION_BOARD] 128 | reward = timestep.reward 129 | 130 | _render_window(board, window_surface, reward) 131 | 132 | pygame.display.update() 133 | 134 | pygame.time.wait(_FRAME_DELAY_MS) 135 | 136 | connection.send(dm_env_rpc_pb2.DestroyWorldRequest(world_name=world_name)) 137 | 138 | server.stop(None) 139 | 140 | 141 | if __name__ == '__main__': 142 | app.run(main) 143 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Requirements necessary to run dm_env_rpc tests. 2 | # Please install other requirements via pip install. 3 | 4 | absl-py==2.1.0 5 | pytest==8.2.2 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Install script for setuptools.""" 16 | 17 | import importlib.util 18 | import os 19 | 20 | from setuptools import Command 21 | from setuptools import find_packages 22 | from setuptools import setup 23 | from setuptools.command.build_ext import build_ext 24 | from setuptools.command.build_py import build_py 25 | 26 | 27 | _ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 28 | _GOOGLE_COMMON_PROTOS_ROOT_DIR = os.path.join( 29 | _ROOT_DIR, 'third_party/api-common-protos' 30 | ) 31 | 32 | # Tuple of proto message definitions to build Python bindings for. Paths must 33 | # be relative to root directory. 34 | _DM_ENV_RPC_PROTOS = ( 35 | 'dm_env_rpc/v1/dm_env_rpc.proto', 36 | 'dm_env_rpc/v1/extensions/properties.proto', 37 | ) 38 | 39 | 40 | class _GenerateProtoFiles(Command): 41 | """Command to generate protobuf bindings for dm_env_rpc.proto.""" 42 | 43 | descriptions = 'Generates Python protobuf bindings for dm_env_rpc.proto.' 44 | user_options = [] 45 | 46 | def initialize_options(self): 47 | pass 48 | 49 | def finalize_options(self): 50 | pass 51 | 52 | def run(self): 53 | # Import grpc_tools and importlib_resources here, after setuptools has 54 | # installed setup_requires dependencies. 55 | from grpc_tools import protoc # pylint: disable=g-import-not-at-top 56 | import importlib_resources # pylint: disable=g-import-not-at-top 57 | 58 | if not os.path.exists( 59 | os.path.join(_GOOGLE_COMMON_PROTOS_ROOT_DIR, 'google/rpc/status.proto') 60 | ): 61 | raise RuntimeError( 62 | 'Cannot find third_party/api-common-protos. ' 63 | 'Please run `git submodule init && git submodule update` to install ' 64 | 'the api-common-protos submodule.' 
65 | ) 66 | 67 | with importlib_resources.as_file( 68 | importlib_resources.files('grpc_tools') / '_proto' 69 | ) as grpc_protos_include: 70 | for proto_path in _DM_ENV_RPC_PROTOS: 71 | proto_args = [ 72 | 'grpc_tools.protoc', 73 | '--proto_path={}'.format(_GOOGLE_COMMON_PROTOS_ROOT_DIR), 74 | '--proto_path={}'.format(grpc_protos_include), 75 | '--proto_path={}'.format(_ROOT_DIR), 76 | '--python_out={}'.format(_ROOT_DIR), 77 | '--grpc_python_out={}'.format(_ROOT_DIR), 78 | os.path.join(_ROOT_DIR, proto_path), 79 | ] 80 | if protoc.main(proto_args) != 0: 81 | raise RuntimeError('ERROR: {}'.format(proto_args)) 82 | 83 | 84 | class _BuildExt(build_ext): 85 | """Generate protobuf bindings in build_ext stage.""" 86 | 87 | def run(self): 88 | self.run_command('generate_protos') 89 | build_ext.run(self) 90 | 91 | 92 | class _BuildPy(build_py): 93 | """Generate protobuf bindings in build_py stage.""" 94 | 95 | def run(self): 96 | self.run_command('generate_protos') 97 | build_py.run(self) 98 | 99 | 100 | def _load_version(): 101 | """Load dm_env_rpc version.""" 102 | spec = importlib.util.spec_from_file_location( 103 | '_version', 'dm_env_rpc/_version.py' 104 | ) 105 | version_module = importlib.util.module_from_spec(spec) 106 | spec.loader.exec_module(version_module) 107 | return version_module.__version__ 108 | 109 | 110 | setup( 111 | name='dm-env-rpc', 112 | version=_load_version(), 113 | description='A networking protocol for agent-environment communication.', 114 | author='DeepMind', 115 | license='Apache License, Version 2.0', 116 | keywords='reinforcement-learning python machine learning', 117 | packages=find_packages(exclude=['examples']), 118 | install_requires=[ 119 | 'dm-env>=1.2', 120 | 'immutabledict', 121 | 'googleapis-common-protos', 122 | 'grpcio', 123 | 'numpy<2.0', 124 | 'protobuf>=3.8', 125 | ], 126 | python_requires='>=3.8', 127 | setup_requires=['grpcio-tools', 'importlib_resources'], 128 | extras_require={ 129 | 'examples': ['pygame'], 130 | }, 131 | cmdclass={ 132 | 'build_ext': _BuildExt, 133 | 'build_py': _BuildPy, 134 | 'generate_protos': _GenerateProtoFiles, 135 | }, 136 | classifiers=[ 137 | 'Development Status :: 5 - Production/Stable', 138 | 'Environment :: Console', 139 | 'Intended Audience :: Science/Research', 140 | 'License :: OSI Approved :: Apache Software License', 141 | 'Operating System :: POSIX :: Linux', 142 | 'Operating System :: Microsoft :: Windows', 143 | 'Operating System :: MacOS :: MacOS X', 144 | 'Programming Language :: Python :: 3.8', 145 | 'Programming Language :: Python :: 3.9', 146 | 'Programming Language :: Python :: 3.10', 147 | 'Programming Language :: Python :: 3.11', 148 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 149 | ], 150 | ) 151 | --------------------------------------------------------------------------------