├── .github
│   └── workflows
│       └── publish-to-pypi.yml
├── .gitignore
├── .gitmodules
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── RELEASE_NOTES.md
├── dm_env_rpc
│   ├── __init__.py
│   ├── _version.py
│   └── v1
│       ├── __init__.py
│       ├── async_connection.py
│       ├── async_connection_test.py
│       ├── compliance
│       │   ├── __init__.py
│       │   ├── create_destroy_world.py
│       │   ├── join_leave_world.py
│       │   ├── reset.py
│       │   ├── reset_world.py
│       │   └── step.py
│       ├── connection.py
│       ├── connection_test.py
│       ├── dm_env_adaptor.py
│       ├── dm_env_adaptor_test.py
│       ├── dm_env_flatten_utils.py
│       ├── dm_env_flatten_utils_test.py
│       ├── dm_env_rpc.proto
│       ├── dm_env_rpc_test.py
│       ├── dm_env_utils.py
│       ├── dm_env_utils_test.py
│       ├── error.py
│       ├── error_test.py
│       ├── extensions
│       │   ├── __init__.py
│       │   ├── properties.proto
│       │   ├── properties.py
│       │   └── properties_test.py
│       ├── message_utils.py
│       ├── message_utils_test.py
│       ├── spec_manager.py
│       ├── spec_manager_test.py
│       ├── tensor_spec_utils.py
│       ├── tensor_spec_utils_test.py
│       ├── tensor_utils.py
│       ├── tensor_utils_benchmark.py
│       └── tensor_utils_test.py
├── docs
│   └── v1
│       ├── 2x2.png
│       ├── 2x3.png
│       ├── appendix.md
│       ├── extensions
│       │   ├── index.md
│       │   └── properties.md
│       ├── glossary.md
│       ├── index.md
│       ├── overview.md
│       ├── reference.md
│       ├── single_agent_connect_and_step.png
│       ├── single_agent_sequence_transitions.png
│       ├── single_agent_world_destruction.png
│       ├── state_transitions.graphviz
│       └── state_transitions.png
├── examples
│   ├── catch_environment.py
│   ├── catch_human_agent.py
│   └── catch_test.py
├── requirements.txt
└── setup.py
/.github/workflows/publish-to-pypi.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created.
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: Upload Python Package
5 |
6 | on:
7 | release:
8 | types: [created]
9 | workflow_dispatch:
10 |
11 | jobs:
12 | deploy:
13 |
14 | runs-on: ubuntu-latest
15 |
16 | steps:
17 | - uses: actions/checkout@v2
18 | - name: Set up Python
19 | uses: actions/setup-python@v2
20 | with:
21 | python-version: '3.7'
22 | - name: Add repository sub-modules
23 | run: |
24 | git submodule init
25 | git submodule update
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | pip install setuptools wheel twine
30 | - name: Build and publish
31 | env:
32 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
33 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
34 | run: |
35 | python setup.py sdist bdist_wheel
36 | twine upload dist/*
37 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore byte-compiled Python code
2 | *.py[cod]
3 |
4 | # Ignore protobuf bindings.
5 | **/*pb2*.py
6 |
7 | # Ignore directories created during the build/installation process
8 | *.egg-info/
9 | .eggs/
10 | build/
11 | dist/
12 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "third_party/api-common-protos"]
2 | path = third_party/api-common-protos
3 | url = https://github.com/googleapis/api-common-protos
4 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. However, this
4 | library is widely used in our research code, so we are unlikely to be able to
5 | accept breaking changes to the interface.
6 |
7 | There are just a few small guidelines you need to follow.
8 |
9 | ## Contributor License Agreement
10 |
11 | Contributions to this project must be accompanied by a Contributor License
12 | Agreement. You (or your employer) retain the copyright to your contribution;
13 | this simply gives us permission to use and redistribute your contributions as
14 | part of the project. Head over to <https://cla.developers.google.com/> to see
15 | your current agreements on file or to sign a new one.
16 |
17 | You generally only need to submit a CLA once, so if you've already submitted one
18 | (even if it was for a different project), you probably don't need to do it
19 | again.
20 |
21 | ## Code reviews
22 |
23 | All submissions, including submissions by project members, require review. We
24 | use GitHub pull requests for this purpose. Consult
25 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
26 | information on using pull requests.
27 |
28 | ## Community Guidelines
29 |
30 | This project follows
31 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
32 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # `dm_env_rpc`: A networking protocol for agent-environment communication.
2 |
3 | 
4 | 
5 |
6 | `dm_env_rpc` is a remote procedure call (RPC) protocol for communicating between
7 | machine learning agents and environments. It uses [gRPC](http://www.grpc.io) as
8 | the underlying communication framework, specifically its
9 | [bidirectional streaming](http://grpc.io/docs/guides/concepts/#bidirectional-streaming-rpc)
10 | RPC variant.
11 |
12 | This package also contains an implementation of
13 | [`dm_env`](http://www.github.com/deepmind/dm_env), a Python interface for
14 | interacting with such environments.
15 |
16 | Please see the documentation for more detailed information on the semantics of
17 | the protocol and how to use it. The examples sub-directory also provides
18 | examples of RL environments implemented using the `dm_env_rpc` protocol.
19 |
20 | ## Intended audience
21 |
22 | Games can make for interesting AI research platforms, for example as
23 | reinforcement learning (RL) environments. However, exposing a game as an RL
24 | environment can be a subtle, fraught process. We aim to provide a protocol that
25 | allows agents and environments to communicate in a standardized way, without
26 | specialized knowledge about how the other side works. Game developers can expose
27 | their games as environments with minimal domain knowledge and researchers can
28 | test their agents on a large library of different games.
29 |
30 | This protocol also removes the need for agents and environments to run in the
31 | same process or even on the same machine, allowing agents and environments to
32 | have very different technology stacks and requirements.
33 |
34 | ## Documentation
35 |
36 | * [Protocol overview](docs/v1/overview.md)
37 | * [Protocol reference](docs/v1/reference.md)
38 | * [Extensions](docs/v1/extensions/index.md)
39 | * [Appendix](docs/v1/appendix.md)
40 | * [Glossary](docs/v1/glossary.md)
41 |
42 | ## Installation
43 |
44 | Note: You may optionally wish to create a
45 | [Python Virtual Environment](https://docs.python.org/3/tutorial/venv.html) to
46 | prevent conflicts with your system's Python environment.
47 |
48 | `dm_env_rpc` can be installed from [PyPI](https://pypi.org/project/dm-env-rpc/)
49 | using `pip`:
50 |
51 | ```bash
52 | $ pip install dm-env-rpc
53 | ```
54 |
55 | To also install the dependencies for the `examples/`, install with:
56 |
57 | ```bash
58 | $ pip install dm-env-rpc[examples]
59 | ```
60 |
61 | Alternatively, you can install `dm_env_rpc` by cloning a local copy of our
62 | GitHub repository:
63 |
64 | ```bash
65 | $ git clone --recursive https://github.com/deepmind/dm_env_rpc.git
66 | $ pip install ./dm_env_rpc
67 | ```
68 |
69 | ## Citing `dm_env_rpc`
70 |
71 | To cite this repository:
72 |
73 | ```bibtex
74 | @misc{dm_env_rpc2019,
75 | author = {Tom Ward and Jay Lemmon},
76 | title = {dm\_env\_rpc: A networking protocol for agent-environment communication},
77 | url = {http://github.com/deepmind/dm_env_rpc},
78 | year = {2019},
79 | }
80 | ```
81 |
82 | ## Developer Notes
83 |
84 | In order to run the included unit tests, developers must also install additional
85 | dependencies defined in `requirements.txt`, at which point the tests should be
86 | runnable via `pytest`. For example:
87 |
88 | ```bash
89 | $ git clone --recursive https://github.com/deepmind/dm_env_rpc.git
90 | $ pip install -U pip wheel
91 | $ pip install -e ./dm_env_rpc
92 | $ pip install -r ./dm_env_rpc/requirements.txt
93 | $ pytest ./dm_env_rpc
94 | ```
95 |
96 | ## Notice
97 |
98 | This is not an officially supported Google product.
99 |
--------------------------------------------------------------------------------
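To make the agent-side flow the README describes concrete, here is a minimal sketch of a session against a dm_env_rpc server. It assumes a server is already listening at `localhost:9001` and accepts an illustrative `'seed'` world setting (both placeholders); the connection helper and request messages come from `dm_env_rpc.v1`, but the settings a real environment requires are server-specific.

```python
from dm_env_rpc.v1 import connection as dm_env_rpc_connection
from dm_env_rpc.v1 import dm_env_rpc_pb2
from dm_env_rpc.v1 import tensor_utils

# Placeholder address; a real environment advertises its own.
_SERVER_ADDRESS = 'localhost:9001'

with dm_env_rpc_connection.create_secure_channel_and_connect(
    _SERVER_ADDRESS) as connection:
  # Create a world. The 'seed' setting is illustrative only.
  world_name = connection.send(
      dm_env_rpc_pb2.CreateWorldRequest(
          settings={'seed': tensor_utils.pack_tensor(123)})).world_name

  # Join the world to obtain the action/observation specs.
  specs = connection.send(
      dm_env_rpc_pb2.JoinWorldRequest(world_name=world_name)).specs

  # Step once with no actions and no requested observations.
  connection.send(dm_env_rpc_pb2.StepRequest())

  # Leave and destroy the world when finished.
  connection.send(dm_env_rpc_pb2.LeaveWorldRequest())
  connection.send(dm_env_rpc_pb2.DestroyWorldRequest(world_name=world_name))
```

Errors returned by the server surface as `DmEnvRpcError` exceptions, so no per-call error handling is shown in this sketch.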
/RELEASE_NOTES.md:
--------------------------------------------------------------------------------
1 | # Release Notes
2 |
3 | ## [1.1.6]
4 |
5 | * New `AsyncConnection` class that allows users to make asynchronous requests
6 | via `asyncio`.
7 | * Add optional `metadata` argument to `create_secure_channel_and_connect`.
8 | * Expose error status details in `DmEnvRpcError` exception.
9 | * Improvements and fixes related to type hints, in particular hints relating
10 | to NumPy.
11 | * Add utility functions for packing/unpacking custom request/responses.
12 | * Remove support for Python 3.7, now that it has reached EOL.
13 |
14 | ## [1.1.5]
15 |
16 | * Add optional `strict` argument to dictionary flattening.
17 | * Relax flattening of create/join world settings, so that users can pass
18 | already flattened settings.
19 |
20 | ## [1.1.4]
21 |
22 | * Support for nested create/join world settings in `dm_env_adaptor` helper
23 | functions.
24 | * Support scalar min/max bounds for array actions in compliance tests.
25 | * Allow `DmEnvAdaptor` to be closable multiple times.
26 | * Cleaned up various type hints, specifically removing deprecated NumPy type
27 | aliases.
28 | * Bug fixes.
29 |
30 | ## [1.1.3]
31 |
32 | * Pass all additional keyword arguments through to `DmEnvAdaptor` when using
33 | create/join helper functions.
34 | * Bug fixes and cleanup.
35 |
36 | ## [1.1.2]
37 |
38 | * Fixed Catch human agent example, raised in this
39 | [GitHub issue](https://github.com/deepmind/dm_env_rpc/issues/2).
40 | * Fixed bug when attempting to pack an empty array as a particular dtype.
41 | * Minor cleanup.
42 |
43 | ## [1.1.1]
44 |
45 | * Removed support for Python 3.6.
46 | * Updated compliance tests to support wider range of environments.
47 | * Fixed bug with packing large `np.uint64` scalars.
48 | * Various PyType fixes.
49 |
50 | ## [1.1.0]
51 |
52 | WARNING: This release removes support for some previously deprecated fields.
53 | This may mean that scalar bounds for older environments are no longer readable.
54 | Users are advised to either revert to an older version, or re-build their
55 | environments to use the newer, multi-dimensional `TensorSpec.Value` fields.
56 |
57 | * Removed scalar `TensorSpec.Value` fields, which were marked as deprecated in
58 | [v1.0.1](#101). These have been superseded by array variants, which can be
59 | used for scalar bounds by creating a single element array.
60 | * Removed deprecated Property request/responses. These are now provided
61 | through the optional Property extension.
62 | * Refactored `Connection` to expose message packing utilities.
63 |
64 | ## [1.0.6]
65 |
66 | * `tensor_spec.bounds()` no longer broadcasts scalar bounds.
67 | * Fixed bug where `reward` and `discount` were inadvertently included in the
68 | observations when using `dm_env_adaptor`, without explicitly requesting
69 | these as observations.
70 |
71 | ## [1.0.5]
72 |
73 | * Better support for string specs in `dm_env_adaptor`.
74 | * Improved Python type annotations.
75 | * Check that the server has returned the correct response in
76 | `dm_env_rpc.Connection` class.
77 | * Added `create_world` helper function to `dm_env_adaptor`.
78 |
79 | ## [1.0.4]
80 |
81 | * Better support for variable sized tensors.
82 | * Support for packing/unpacking tensors that use `Any` protobuf messages.
83 | * Bug fixes.
84 |
85 | ## [1.0.3]
86 |
87 | ### Added
88 |
89 | * Support for property descriptions.
90 | * New utility functions for creating a Connection instance from a server
91 | address.
92 | * DmEnvAdaptor helper functions for creating and joining worlds.
93 | * Additional compliance tests for resetting.
94 | * Support for optional DmEnvAdaptor extensions.
95 |
96 | ### Changed
97 |
98 | * Removed portpicker dependency, instead relying on gRPC port picking
99 | functionality.
100 | * Changed property extension API to be more amenable to being used as an
101 | extension object for DmEnvAdaptor.
102 |
103 | ## [1.0.2]
104 |
105 | * Explicitly support nested tensors by the use of a period character in the
106 | `TensorSpec` name to indicate a level of nesting. Updated `dm_env` adaptor
107 | to flatten/unflatten actions and observations.
108 | * Increased minimum Python version to 3.6.
109 | * Moved property request/responses to its own extension. This supersedes the
110 | previous property requests, which have been marked as deprecated. **These
111 | requests will be removed in a future version of dm_env_rpc**.
112 | * Speed improvements for packing and unpacking byte arrays in Python.
113 |
114 | ## [1.0.1]
115 |
116 | ### Added
117 |
118 | * Support for per-element min/max values. This supersedes the existing scalar
119 | fields, which have been marked as deprecated. **These fields will be
120 | removed in a future version of dm_env_rpc.**
121 | * Initial set of compliance tests that environment authors can use to better
122 | ensure their implementations adhere to the protocol specification.
123 | * Support for `dm_env` DiscreteArray specs.
124 |
125 | ### Changed
126 |
127 | * `dm_env_rpc` `EnvironmentResponse` errors in Python are now raised as a
128 | custom, `DmEnvRpcError` exception.
129 |
130 | ## [1.0.0]
131 |
132 | * Initial release.
133 |
134 | ## [1.0.0b2]
135 |
136 | * Updated minimum requirements for Python and protobuf.
137 |
138 | ## [1.0.0b1]
139 |
140 | * Initial beta release.
141 |
--------------------------------------------------------------------------------
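Several of the entries above mention the message packing utilities exposed via `message_utils`. As a brief sketch of what those helpers do, assuming only the two functions that also appear in `async_connection.py` below (`pack_environment_request` and `unpack_environment_response`):

```python
from dm_env_rpc.v1 import dm_env_rpc_pb2
from dm_env_rpc.v1 import message_utils

# Wrap a concrete request in the EnvironmentRequest envelope. The helper also
# reports which oneof field it used, so the matching response can be unwrapped.
request = dm_env_rpc_pb2.CreateWorldRequest()
environment_request, field_name = message_utils.pack_environment_request(
    request)
assert field_name == 'create_world'

# Unwrap the payload of an EnvironmentResponse for that field. If the response
# carried an error instead, this raises a DmEnvRpcError.
environment_response = dm_env_rpc_pb2.EnvironmentResponse(
    create_world=dm_env_rpc_pb2.CreateWorldResponse())
create_world_response = message_utils.unpack_environment_response(
    environment_response, field_name)
```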
/dm_env_rpc/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """A networking protocol for agent-environment communication."""
16 | from dm_env_rpc._version import __version__
17 |
--------------------------------------------------------------------------------
/dm_env_rpc/_version.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Package version for dm_env_rpc.
16 |
17 | Kept in separate file so it can be used during installation.
18 | """
19 |
20 | __version__ = '1.1.6' # https://www.python.org/dev/peps/pep-0440/
21 |
--------------------------------------------------------------------------------
/dm_env_rpc/v1/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Initial version of dm_env_rpc networking protocol."""
16 |
--------------------------------------------------------------------------------
/dm_env_rpc/v1/async_connection.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """A helper class to manage a connection to a dm_env_rpc server asynchronously.
16 |
17 | This helper class allows sending the Request message types and receiving the
18 | Response message as a future. The class automatically wraps and unwraps from
19 | EnvironmentRequest and EnvironmentResponse, respectively. It also turns error
20 | messages in to exceptions.
21 |
22 | For most calls (such as create, join, etc.):
23 |
24 | with async_connection.AsyncConnection(grpc_channel) as async_channel:
25 | create_response = await async_channel.send(
26 | dm_env_rpc_pb2.CreateWorldRequest(settings={
27 | 'players': 5
28 | }))
29 |
30 | For the `extension` message type, you must send an Any proto and you'll get back
31 | an Any proto. It is up to you to wrap and unwrap these to concrete proto types
32 | that you know how to handle.
33 |
34 | with async_connection.AsyncConnection(grpc_channel) as async_channel:
35 | request = struct_pb2.Struct()
36 | ...
37 | request_any = any_pb2.Any()
38 | request_any.Pack(request)
39 | response_any = await async_channel.send(request_any)
40 | response = my_type_pb2.MyType()
41 | response_any.Unpack(response)
42 |
43 |
44 | Any errors encountered in the EnvironmentResponse are turned into Python
45 | exceptions, so explicit error handling code isn't needed per call.
46 | """
47 |
48 | import asyncio
49 | from typing import Optional, Sequence, Tuple
50 |
51 | import grpc
52 |
53 | from dm_env_rpc.v1 import dm_env_rpc_pb2_grpc
54 | from dm_env_rpc.v1 import message_utils
55 |
56 |
57 | Metadata = Sequence[Tuple[str, str]]
58 |
59 |
60 | class AsyncConnection:
61 | """A helper class for interacting with dm_env_rpc servers asynchronously."""
62 |
63 | def __init__(
64 | self, channel: grpc.aio.Channel, metadata: Optional[Metadata] = None
65 | ):
66 | """Manages an async connection to a dm_env_rpc server.
67 |
68 | Args:
69 | channel: An async grpc channel to connect to the dm_env_rpc server over.
70 | metadata: Optional sequence of 2-tuples, sent to the gRPC server as
71 | metadata.
72 | """
73 | self._stream = dm_env_rpc_pb2_grpc.EnvironmentStub(channel).Process(
74 | metadata=metadata
75 | )
76 |
77 | async def send(
78 | self,
79 | request: message_utils.DmEnvRpcRequest) -> message_utils.DmEnvRpcResponse:
80 | """Sends `request` to the dm_env_rpc server and returns a response future.
81 |
82 | The request should be an instance of one of the dm_env_rpc Request messages,
83 | such as CreateWorldRequest. Based on the type the correct payload for the
84 | EnvironmentRequest will be constructed and sent to the dm_env_rpc server.
85 |
86 | Returns an awaitable future to retrieve the response.
87 |
88 | Args:
89 | request: An instance of a dm_env_rpc Request type, such as
90 | CreateWorldRequest.
91 |
92 | Returns:
93 | An asyncio Future which can be awaited to retrieve the response from the
94 | dm_env_rpc server returned for the given RPC call, unwrapped from the
95 | EnvironmentStream message. For instance if `request` had type
96 | `CreateWorldRequest` this returns a message of type `CreateWorldResponse`.
97 |
98 | Raises:
99 | DmEnvRpcError: The dm_env_rpc server responded to the request with an
100 | error.
101 | ValueError: The dm_env_rpc server responded to the request with an
102 | unexpected response message.
103 | """
104 | environment_request, field_name = (
105 | message_utils.pack_environment_request(request))
106 | if self._stream is None:
107 | raise ValueError('Cannot send request after stream is closed.')
108 | await self._stream.write(environment_request)
109 | return message_utils.unpack_environment_response(await self._stream.read(),
110 | field_name)
111 |
112 | def close(self):
113 | """Closes the connection. Call when the connection is no longer needed."""
114 | if self._stream:
115 | self._stream = None
116 |
117 | def __exit__(self, *args, **kwargs):
118 | self.close()
119 |
120 | def __enter__(self):
121 | return self
122 |
123 |
124 | async def create_secure_async_channel_and_connect(
125 | server_address: str,
126 | credentials: grpc.ChannelCredentials = grpc.local_channel_credentials(),
127 | metadata: Optional[Metadata] = None,
128 | ) -> AsyncConnection:
129 | """Creates a secure async channel from address and credentials and connects.
130 |
131 | We allow the created channel to have un-bounded message lengths, to support
132 | large observations.
133 |
134 | Args:
135 | server_address: URI server address to connect to.
136 | credentials: gRPC credentials necessary to connect to the server.
137 | metadata: Optional sequence of 2-tuples, sent to the gRPC server as
138 | metadata.
139 |
140 | Returns:
141 | An instance of dm_env_rpc.AsyncConnection, where the async channel is closed
142 | upon the connection being closed.
143 | """
144 | options = [('grpc.max_send_message_length', -1),
145 | ('grpc.max_receive_message_length', -1)]
146 | channel = grpc.aio.secure_channel(server_address, credentials,
147 | options=options)
148 | await channel.channel_ready()
149 |
150 | class _ConnectionWrapper(AsyncConnection):
151 | """Utility to ensure channel is closed when the connection is closed."""
152 |
153 | def __init__(self, channel, metadata):
154 | super().__init__(channel=channel, metadata=metadata)
155 | self._channel = channel
156 |
157 | def __del__(self):
158 | self.close()
159 |
160 | def close(self):
161 | super().close()
162 | try:
163 | loop = asyncio.get_running_loop()
164 | except RuntimeError:
165 | loop = None
166 | if loop and loop.is_running():
167 | return asyncio.ensure_future(self._channel.close())
168 | else:
169 | return asyncio.run(self._channel.close())
170 |
171 | return _ConnectionWrapper(channel=channel, metadata=metadata)
172 |
--------------------------------------------------------------------------------
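As a usage sketch for the module above, the snippet below drives an `AsyncConnection` from `asyncio`. The server address and the `'seed'` setting are placeholders; everything else uses only names defined in this module and in `dm_env_rpc_pb2`.

```python
import asyncio

from dm_env_rpc.v1 import async_connection
from dm_env_rpc.v1 import dm_env_rpc_pb2
from dm_env_rpc.v1 import tensor_utils


async def main():
  # Placeholder address of a running dm_env_rpc server.
  connection = await async_connection.create_secure_async_channel_and_connect(
      'localhost:9001')
  with connection:
    # Each send() is awaited; server errors surface as DmEnvRpcError.
    create_response = await connection.send(
        dm_env_rpc_pb2.CreateWorldRequest(
            settings={'seed': tensor_utils.pack_tensor(0)}))
    await connection.send(
        dm_env_rpc_pb2.JoinWorldRequest(world_name=create_response.world_name))
    await connection.send(dm_env_rpc_pb2.LeaveWorldRequest())
    await connection.send(
        dm_env_rpc_pb2.DestroyWorldRequest(
            world_name=create_response.world_name))


if __name__ == '__main__':
  asyncio.run(main())
```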
/dm_env_rpc/v1/async_connection_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Tests for AsyncConnection."""
16 |
17 | import asyncio
18 | import contextlib
19 | import queue
20 | import unittest
21 | from unittest import mock
22 |
23 | from absl.testing import absltest
24 | import grpc
25 |
26 | from google.protobuf import any_pb2
27 | from google.protobuf import struct_pb2
28 | from google.rpc import status_pb2
29 | from dm_env_rpc.v1 import async_connection
30 | from dm_env_rpc.v1 import dm_env_rpc_pb2
31 | from dm_env_rpc.v1 import dm_env_rpc_pb2_grpc
32 | from dm_env_rpc.v1 import error
33 | from dm_env_rpc.v1 import tensor_utils
34 |
35 | _CREATE_REQUEST = dm_env_rpc_pb2.CreateWorldRequest(
36 | settings={'foo': tensor_utils.pack_tensor('bar')})
37 | _CREATE_RESPONSE = dm_env_rpc_pb2.CreateWorldResponse()
38 |
39 | _BAD_CREATE_REQUEST = dm_env_rpc_pb2.CreateWorldRequest()
40 | _TEST_ERROR = dm_env_rpc_pb2.EnvironmentResponse(
41 | error=status_pb2.Status(message='A test error.'))
42 |
43 | _INCORRECT_RESPONSE_TEST_MSG = dm_env_rpc_pb2.DestroyWorldRequest(
44 | world_name='foo')
45 | _INCORRECT_RESPONSE = dm_env_rpc_pb2.EnvironmentResponse(
46 | leave_world=dm_env_rpc_pb2.LeaveWorldResponse())
47 |
48 | _EXTENSION_REQUEST = struct_pb2.Value(string_value='extension request')
49 | _EXTENSION_RESPONSE = struct_pb2.Value(number_value=555)
50 |
51 |
52 | def _wrap_in_any(proto):
53 | any_proto = any_pb2.Any()
54 | any_proto.Pack(proto)
55 | return any_proto
56 |
57 |
58 | _REQUEST_RESPONSE_PAIRS = {
59 | dm_env_rpc_pb2.EnvironmentRequest(
60 | create_world=_CREATE_REQUEST).SerializeToString():
61 | dm_env_rpc_pb2.EnvironmentResponse(create_world=_CREATE_RESPONSE),
62 | dm_env_rpc_pb2.EnvironmentRequest(
63 | create_world=_BAD_CREATE_REQUEST).SerializeToString():
64 | _TEST_ERROR,
65 | dm_env_rpc_pb2.EnvironmentRequest(
66 | extension=_wrap_in_any(_EXTENSION_REQUEST)).SerializeToString():
67 | dm_env_rpc_pb2.EnvironmentResponse(
68 | extension=_wrap_in_any(_EXTENSION_RESPONSE)),
69 | dm_env_rpc_pb2.EnvironmentRequest(
70 | destroy_world=_INCORRECT_RESPONSE_TEST_MSG).SerializeToString():
71 | _INCORRECT_RESPONSE,
72 | }
73 |
74 |
75 | def _process(metadata: async_connection.Metadata) -> grpc.aio.StreamStreamCall:
76 | requests = queue.Queue()
77 |
78 | async def _write(request):
79 | requests.put(request)
80 |
81 | async def _read():
82 | request = requests.get()
83 | return _REQUEST_RESPONSE_PAIRS.get(request.SerializeToString(), _TEST_ERROR)
84 |
85 | mock_stream = mock.create_autospec(grpc.aio.StreamStreamCall)
86 | mock_stream.write = _write
87 | mock_stream.read = _read
88 | mock_stream.metadata = metadata
89 | return mock_stream
90 |
91 |
92 | @contextlib.contextmanager
93 | def _create_mock_async_channel():
94 | """Mocks out gRPC and returns a channel to be passed to Connection."""
95 | with mock.patch.object(async_connection, 'dm_env_rpc_pb2_grpc') as mock_grpc:
96 | mock_stub_class = mock.create_autospec(dm_env_rpc_pb2_grpc.EnvironmentStub)
97 | mock_stub_class.Process = _process
98 | mock_grpc.EnvironmentStub.return_value = mock_stub_class
99 | yield mock.MagicMock()
100 |
101 |
102 | class AsyncConnectionAsyncTests(unittest.IsolatedAsyncioTestCase):
103 |
104 | async def test_create(self):
105 | with _create_mock_async_channel() as mock_channel:
106 | with async_connection.AsyncConnection(mock_channel) as connection:
107 | response = await connection.send(_CREATE_REQUEST)
108 | self.assertEqual(_CREATE_RESPONSE, response)
109 |
110 | async def test_send_error_after_close(self):
111 | with _create_mock_async_channel() as mock_channel:
112 | with async_connection.AsyncConnection(mock_channel) as connection:
113 | connection.close()
114 | with self.assertRaisesRegex(ValueError, 'stream is closed'):
115 | await connection.send(_CREATE_REQUEST)
116 |
117 | async def test_error(self):
118 | with _create_mock_async_channel() as mock_channel:
119 | with async_connection.AsyncConnection(mock_channel) as connection:
120 | with self.assertRaisesRegex(error.DmEnvRpcError, 'test error'):
121 | await connection.send(_BAD_CREATE_REQUEST)
122 |
123 | async def test_extension(self):
124 | with _create_mock_async_channel() as mock_channel:
125 | with async_connection.AsyncConnection(mock_channel) as connection:
126 | request = any_pb2.Any()
127 | request.Pack(_EXTENSION_REQUEST)
128 | response = await connection.send(request)
129 | expected_response = any_pb2.Any()
130 | expected_response.Pack(_EXTENSION_RESPONSE)
131 | self.assertEqual(expected_response, response)
132 |
133 | async def test_incorrect_response(self):
134 | with _create_mock_async_channel() as mock_channel:
135 | with async_connection.AsyncConnection(mock_channel) as connection:
136 | with self.assertRaisesRegex(ValueError, 'Unexpected response message'):
137 | await connection.send(_INCORRECT_RESPONSE_TEST_MSG)
138 |
139 | def test_with_metadata(self):
140 | expected_metadata = (('key', 'value'),)
141 | with mock.patch.object(async_connection,
142 | 'dm_env_rpc_pb2_grpc') as mock_grpc:
143 | mock_stub_class = mock.MagicMock()
144 | mock_grpc.EnvironmentStub.return_value = mock_stub_class
145 | _ = async_connection.AsyncConnection(
146 | mock.MagicMock(), metadata=expected_metadata)
147 | mock_stub_class.Process.assert_called_with(
148 | metadata=expected_metadata)
149 |
150 | @mock.patch.object(grpc.aio, 'secure_channel')
151 | async def test_create_secure_channel_and_connect_context(
152 | self, mock_secure_channel):
153 |
154 | mock_async_channel = mock.MagicMock()
155 | mock_async_channel.channel_ready = absltest.mock.AsyncMock()
156 | mock_async_channel.close = absltest.mock.AsyncMock()
157 | mock_secure_channel.return_value = mock_async_channel
158 |
159 | with await async_connection.create_secure_async_channel_and_connect(
160 | 'valid_address') as connection:
161 | self.assertIsNotNone(connection)
162 |
163 | await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()})
164 |
165 | mock_async_channel.close.assert_called_once()
166 | mock_async_channel.channel_ready.assert_called_once()
167 | mock_secure_channel.assert_called_once()
168 |
169 |
170 | class AsyncConnectionSyncTests(absltest.TestCase):
171 |
172 | @absltest.mock.patch.object(grpc.aio, 'secure_channel')
173 | def test_create_secure_channel_and_connect_context(self, mock_secure_channel):
174 |
175 | mock_async_channel = absltest.mock.MagicMock()
176 | mock_async_channel.channel_ready = absltest.mock.AsyncMock()
177 | mock_async_channel.close = absltest.mock.AsyncMock()
178 | mock_secure_channel.return_value = mock_async_channel
179 |
180 | asyncio.set_event_loop(asyncio.new_event_loop())
181 | loop = asyncio.get_event_loop()
182 |
183 | connection_task = asyncio.ensure_future(
184 | async_connection.create_secure_async_channel_and_connect(
185 | 'valid_address'))
186 | connection = loop.run_until_complete(connection_task)
187 |
188 | loop.stop()
189 | asyncio.set_event_loop(None)
190 |
191 | connection.close()
192 |
193 | mock_async_channel.close.assert_called_once()
194 | mock_async_channel.channel_ready.assert_called_once()
195 | mock_secure_channel.assert_called_once()
196 |
197 |
198 | if __name__ == '__main__':
199 | absltest.main()
200 |
--------------------------------------------------------------------------------
/dm_env_rpc/v1/compliance/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Compliance test base classes for dm_env_rpc."""
16 | from dm_env_rpc.v1.compliance import create_destroy_world
17 | from dm_env_rpc.v1.compliance import join_leave_world
18 | from dm_env_rpc.v1.compliance import reset
19 | from dm_env_rpc.v1.compliance import reset_world
20 | from dm_env_rpc.v1.compliance import step
21 |
22 | CreateDestroyWorld = create_destroy_world.CreateDestroyWorld
23 | JoinLeaveWorld = join_leave_world.JoinLeaveWorld
24 | Reset = reset.Reset
25 | ResetWorld = reset_world.ResetWorld
26 | Step = step.Step
27 |
--------------------------------------------------------------------------------
/dm_env_rpc/v1/compliance/create_destroy_world.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """A base class for CreateWorld and DestroyWorld tests for a server."""
16 | import abc
17 |
18 | from absl.testing import absltest
19 |
20 | from dm_env_rpc.v1 import dm_env_rpc_pb2
21 | from dm_env_rpc.v1 import error
22 |
23 |
24 | class CreateDestroyWorld(absltest.TestCase, metaclass=abc.ABCMeta):
25 | """A base class for `CreateWorld` and `DestroyWorld` compliance tests."""
26 |
27 | @abc.abstractproperty
28 | def required_world_settings(self):
29 | """A string to Tensor mapping of the minimum set of required settings."""
30 | pass
31 |
32 | @abc.abstractproperty
33 | def invalid_world_settings(self):
34 | """World creation settings which are invalid in some way."""
35 | pass
36 |
37 | @abc.abstractproperty
38 | def has_multiple_world_support(self):
39 | """Does the server support creating more than one world?"""
40 | pass
41 |
42 | @abc.abstractproperty
43 | def connection(self):
44 | """An instance of dm_env_rpc's Connection."""
45 | pass
46 |
47 | def create_world(self, settings):
48 | """Returns the world name of the world created with the given settings."""
49 | response = self.connection.send(
50 | dm_env_rpc_pb2.CreateWorldRequest(settings=settings))
51 | return response.world_name
52 |
53 | def destroy_world(self, world_name):
54 | """Destroys the world named `world_name`."""
55 | if world_name is not None:
56 | self.connection.send(
57 | dm_env_rpc_pb2.DestroyWorldRequest(world_name=world_name))
58 |
59 | # pylint: disable=missing-docstring
60 | def test_can_create_and_destroy_world(self):
61 | # If this doesn't raise an exception the test passes.
62 | world_name = self.create_world(self.required_world_settings)
63 | self.destroy_world(world_name)
64 |
65 | def test_cannot_create_world_with_less_than_required_settings(self):
66 | settings = self.required_world_settings
67 |
68 | for name, _ in settings.items():
69 | sans_setting = dict(settings)
70 | del sans_setting[name]
71 | message = f'world was created without required setting "{name}"'
72 | with self.assertRaises(error.DmEnvRpcError, msg=message):
73 | self.create_world(sans_setting)
74 |
75 | def test_cannot_create_world_with_invalid_settings(self):
76 | settings = self.required_world_settings
77 | invalid_settings = self.invalid_world_settings
78 | for name, tensor in invalid_settings.items():
79 | message = f'world was created with invalid setting "{name}"'
80 | with self.assertRaises(error.DmEnvRpcError, msg=message):
81 | self.create_world({name: tensor, **settings})
82 |
83 | def test_world_name_is_unique(self):
84 | if not self.has_multiple_world_support:
85 | return
86 | world1_name = None
87 | world2_name = None
88 | try:
89 | world1_name = self.create_world(self.required_world_settings)
90 | world2_name = self.create_world(self.required_world_settings)
91 | self.assertIsNotNone(world1_name)
92 | self.assertIsNotNone(world2_name)
93 | self.assertNotEqual(world1_name, world2_name)
94 | finally:
95 | self.destroy_world(world1_name)
96 | self.destroy_world(world2_name)
97 |
98 | def test_cannot_destroy_uncreated_world(self):
99 | with self.assertRaises(error.DmEnvRpcError):
100 | self.destroy_world('foo')
101 | # pylint: enable=missing-docstring
102 |
--------------------------------------------------------------------------------
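For orientation, here is a hedged sketch of how an environment author might bind the abstract properties above to a concrete server so that these compliance tests run against it. The address, the `'seed'` setting, and the choice of invalid value are all illustrative.

```python
from absl.testing import absltest

from dm_env_rpc.v1 import connection as dm_env_rpc_connection
from dm_env_rpc.v1 import tensor_utils
from dm_env_rpc.v1.compliance import create_destroy_world


class MyEnvCreateDestroyWorldTest(create_destroy_world.CreateDestroyWorld):
  """Runs the CreateDestroyWorld compliance tests against a local server."""

  def setUp(self):
    super().setUp()
    # Placeholder address of the server under test.
    self._connection = dm_env_rpc_connection.create_secure_channel_and_connect(
        'localhost:9001')

  def tearDown(self):
    self._connection.close()
    super().tearDown()

  @property
  def connection(self):
    return self._connection

  @property
  def required_world_settings(self):
    # Illustrative; a real environment defines its own required settings.
    return {'seed': tensor_utils.pack_tensor(123)}

  @property
  def invalid_world_settings(self):
    return {'seed': tensor_utils.pack_tensor('not a valid seed')}

  @property
  def has_multiple_world_support(self):
    return True


if __name__ == '__main__':
  absltest.main()
```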
/dm_env_rpc/v1/compliance/join_leave_world.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """A base class for JoinWorld and LeaveWord tests for a server."""
16 | import abc
17 |
18 | from absl.testing import absltest
19 | import numpy as np
20 |
21 | from dm_env_rpc.v1 import dm_env_rpc_pb2
22 | from dm_env_rpc.v1 import error
23 | from dm_env_rpc.v1 import tensor_spec_utils
24 |
25 |
26 | def _find_duplicates(iterable):
27 | """Returns a list of duplicate entries found in `iterable`."""
28 | duplicates = []
29 | seen = set()
30 | for item in iterable:
31 | if item in seen:
32 | duplicates.append(item)
33 | else:
34 | seen.add(item)
35 | return duplicates
36 |
37 |
38 | def _check_tensor_spec(tensor_spec):
39 | """Raises an error if the given `tensor_spec` is internally inconsistent."""
40 | if np.sum(np.asarray(tensor_spec.shape) < 0) > 1:
41 | raise ValueError(
42 | f'"{tensor_spec.name}" has shape {tensor_spec.shape} which has more '
43 | 'than one negative element.')
44 | min_type = tensor_spec.min and tensor_spec.min.WhichOneof('payload')
45 | max_type = tensor_spec.max and tensor_spec.max.WhichOneof('payload')
46 | if min_type or max_type:
47 | _ = tensor_spec_utils.bounds(tensor_spec)
48 |
49 |
50 | class JoinLeaveWorld(absltest.TestCase, metaclass=abc.ABCMeta):
51 | """A base class for `JoinWorld` and `LeaveWorld` compliance tests."""
52 |
53 | @property
54 | def required_join_settings(self):
55 | """A dict of required settings for a Join World call."""
56 | return {}
57 |
58 | @property
59 | def invalid_join_settings(self):
60 | """A list of dicts of Join World settings which are invalid in some way."""
61 | return {}
62 |
63 | @abc.abstractproperty
64 | def world_name(self):
65 | """A string of the world name of an already created world."""
66 | pass
67 |
68 | @property
69 | def invalid_world_name(self):
70 | """A string which doesn't correspond to any valid world_name."""
71 | return 'invalid_world_name'
72 |
73 | @property
74 | @abc.abstractmethod
75 | def connection(self):
76 | """An instance of dm_env_rpc's Connection."""
77 | pass
78 |
79 | def tearDown(self):
80 | super().tearDown()
81 | try:
82 | self.leave_world()
83 | finally:
84 | pass
85 |
86 | def join_world(self, **kwargs):
87 | """Joins the world and returns the spec."""
88 | response = self.connection.send(dm_env_rpc_pb2.JoinWorldRequest(**kwargs))
89 | return response.specs
90 |
91 | def leave_world(self):
92 | """Leaves currently joined world, if any."""
93 | self.connection.send(dm_env_rpc_pb2.LeaveWorldRequest())
94 |
95 | # pylint: disable=missing-docstring
96 | def test_can_join(self):
97 | self.join_world(
98 | world_name=self.world_name, settings=self.required_join_settings)
99 | # Success if there's no error raised.
100 |
101 | def test_cannot_join_with_wrong_world_name(self):
102 | with self.assertRaises(error.DmEnvRpcError):
103 | self.join_world(world_name=self.invalid_world_name)
104 |
105 | def test_cannot_join_world_with_invalid_settings(self):
106 | settings = self.required_join_settings
107 | for name, tensor in self.invalid_join_settings.items():
108 | with self.assertRaises(error.DmEnvRpcError):
109 | self.join_world(
110 | world_name=self.world_name, settings={
111 | name: tensor,
112 | **settings
113 | })
114 |
115 | def test_cannot_join_world_twice(self):
116 | self.join_world(
117 | world_name=self.world_name, settings=self.required_join_settings)
118 | with self.assertRaises(error.DmEnvRpcError):
119 | self.join_world(
120 | world_name=self.world_name, settings=self.required_join_settings)
121 |
122 | def test_action_specs_have_unique_names(self):
123 | specs = self.join_world(
124 | world_name=self.world_name, settings=self.required_join_settings)
125 | self.assertEmpty(_find_duplicates(
126 | spec.name for spec in specs.actions.values()))
127 |
128 | def test_action_specs_for_consistency(self):
129 | specs = self.join_world(
130 | world_name=self.world_name, settings=self.required_join_settings)
131 | for action_spec in specs.actions.values():
132 | _check_tensor_spec(action_spec)
133 |
134 | def test_observation_specs_have_unique_names(self):
135 | specs = self.join_world(
136 | world_name=self.world_name, settings=self.required_join_settings)
137 | self.assertEmpty(_find_duplicates(
138 | spec.name for spec in specs.observations.values()))
139 |
140 | def test_observation_specs_for_consistency(self):
141 | specs = self.join_world(
142 | world_name=self.world_name, settings=self.required_join_settings)
143 | for observation_spec in specs.observations.values():
144 | _check_tensor_spec(observation_spec)
145 |
146 | def test_can_leave_world_if_not_joined(self):
147 | self.leave_world()
148 | # Success if there's no error raised.
149 |
150 | def test_can_leave_world_after_joining(self):
151 | self.join_world(
152 | world_name=self.world_name, settings=self.required_join_settings)
153 | self.leave_world()
154 | # Success if there's no error raised.
155 |
156 | def test_can_rejoin_world_after_leaving(self):
157 | self.join_world(
158 | world_name=self.world_name, settings=self.required_join_settings)
159 | self.leave_world()
160 | self.join_world(
161 | world_name=self.world_name, settings=self.required_join_settings)
162 | # Success if there's no error raised.
163 |
164 | # pylint: enable=missing-docstring
165 |
--------------------------------------------------------------------------------
/dm_env_rpc/v1/compliance/reset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """A base class for Reset tests for a server."""
16 |
17 | import abc
18 |
19 | from absl.testing import absltest
20 |
21 | from dm_env_rpc.v1 import dm_env_rpc_pb2
22 | from dm_env_rpc.v1 import error
23 |
24 |
25 | class Reset(absltest.TestCase, metaclass=abc.ABCMeta):
26 | """A base class for dm_env_rpc `Reset` compliance tests."""
27 |
28 | @property
29 | @abc.abstractmethod
30 | def connection(self):
31 | """An instance of dm_env_rpc's Connection already joined to a world."""
32 | pass
33 |
34 | @property
35 | def required_reset_settings(self):
36 | return {}
37 |
38 | @abc.abstractmethod
39 | def join_world(self):
40 | """Joins a world, returning the specs."""
41 | pass
42 |
43 | def reset(self):
44 | """Resets the environment, returning the specs."""
45 | return self.connection.send(dm_env_rpc_pb2.ResetRequest(
46 | settings=self.required_reset_settings)).specs
47 |
48 | # pylint: disable=missing-docstring
49 | def test_reset_resends_the_specs(self):
50 | join_specs = self.join_world()
51 | specs = self.reset()
52 | self.assertEqual(join_specs, specs)
53 |
54 | def test_cannot_reset_if_not_joined_to_world(self):
55 | with self.assertRaises(error.DmEnvRpcError):
56 | self.reset()
57 |
58 | def test_can_reset_multiple_times(self):
59 | join_specs = self.join_world()
60 | self.reset()
61 | specs = self.reset()
62 | self.assertEqual(join_specs, specs)
63 | # pylint: enable=missing-docstring
64 |
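The `Reset` base class above is meant to be subclassed by server authors, who supply a concrete `connection` and a `join_world` implementation. Below is a minimal, illustrative sketch of such a subclass; the server address, the absence of create/join settings, and the `MyServerResetTest` name are assumptions, not part of dm_env_rpc.

# Illustrative sketch only: assumes a dm_env_rpc server is already listening
# on the (hypothetical) address below and needs no create/join settings.
from absl.testing import absltest

from dm_env_rpc.v1 import connection as dm_env_rpc_connection
from dm_env_rpc.v1 import dm_env_rpc_pb2
from dm_env_rpc.v1.compliance import reset


class MyServerResetTest(reset.Reset):
  """Runs the generic Reset compliance tests against a running server."""

  def setUp(self):
    super().setUp()
    self._connection = dm_env_rpc_connection.create_secure_channel_and_connect(
        'localhost:9001')
    self._world_name = self._connection.send(
        dm_env_rpc_pb2.CreateWorldRequest()).world_name

  def tearDown(self):
    self._connection.send(dm_env_rpc_pb2.LeaveWorldRequest())
    self._connection.send(
        dm_env_rpc_pb2.DestroyWorldRequest(world_name=self._world_name))
    self._connection.close()
    super().tearDown()

  @property
  def connection(self):
    return self._connection

  def join_world(self):
    return self._connection.send(dm_env_rpc_pb2.JoinWorldRequest(
        world_name=self._world_name)).specs


if __name__ == '__main__':
  absltest.main()
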
--------------------------------------------------------------------------------
/dm_env_rpc/v1/compliance/reset_world.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """A base class for ResetWorld tests for a server."""
16 |
17 | import abc
18 |
19 | from absl.testing import absltest
20 |
21 | from dm_env_rpc.v1 import dm_env_rpc_pb2
22 | from dm_env_rpc.v1 import error
23 |
24 |
25 | class ResetWorld(absltest.TestCase, metaclass=abc.ABCMeta):
26 | """A base class for dm_env_rpc `ResetWorld` compliance tests."""
27 |
28 | @property
29 | @abc.abstractmethod
30 | def connection(self):
31 | """An instance of dm_env_rpc's Connection already joined to a world."""
32 | pass
33 |
34 | @property
35 | def required_reset_world_settings(self):
36 | """Settings necessary to pass to ResetWorld."""
37 | return {}
38 |
39 | @property
40 | def required_join_world_settings(self):
41 | """Settings necessary to pass to JoinWorld."""
42 | return {}
43 |
44 | @property
45 | def invalid_world_name(self):
46 | """The name of a world which doesn't exist."""
47 | return 'invalid_world_name'
48 |
49 | @property
50 | @abc.abstractmethod
51 | def world_name(self):
52 | """The name of the world to attempt to call ResetWorld on."""
53 | return ''
54 |
55 | def join_world(self):
56 | """Joins the world to call ResetWorld on."""
57 | self.connection.send(dm_env_rpc_pb2.JoinWorldRequest(
58 | world_name=self.world_name, settings=self.required_join_world_settings))
59 |
60 | def reset_world(self, world_name):
61 | """Resets the world."""
62 | self.connection.send(dm_env_rpc_pb2.ResetWorldRequest(
63 | world_name=world_name, settings=self.required_reset_world_settings))
64 |
65 | def leave_world(self):
66 | """Leaves the world."""
67 | self.connection.send(dm_env_rpc_pb2.LeaveWorldRequest())
68 |
69 | # pylint: disable=missing-docstring
70 | def test_cannot_reset_invalid_world(self):
71 | with self.assertRaises(error.DmEnvRpcError):
72 | self.reset_world(self.invalid_world_name)
73 |
74 | def test_can_reset_world_not_joined_to(self):
75 | self.reset_world(self.world_name)
76 | # If there are no errors the test passes.
77 |
78 | def test_can_reset_world_when_joined_to_it(self):
79 | try:
80 | self.join_world()
81 | self.reset_world(self.world_name)
82 | # If there are no errors the test passes.
83 | finally:
84 | self.leave_world()
85 | # pylint: enable=missing-docstring
86 |
--------------------------------------------------------------------------------
/dm_env_rpc/v1/connection.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """A helper class to manage a connection to a dm_env_rpc server.
16 |
17 | This helper class allows sending the Request message types and receiving the
18 | Response message types without wrapping in an EnvironmentRequest or unwrapping
19 | from an EnvironmentResponse. It also turns error messages into exceptions.
20 |
21 | For most calls (such as create, join, etc.):
22 |
23 | with connection.Connection(grpc_channel) as channel:
24 |     create_response = channel.send(dm_env_rpc_pb2.CreateWorldRequest(settings={
25 |         'players': tensor_utils.pack_tensor(5)
26 |     }))
27 |
28 | For the `extension` message type, you must send an Any proto and you'll get back
29 | an Any proto. It is up to you to wrap and unwrap these to concrete proto types
30 | that you know how to handle.
31 |
32 | with connection.Connection(grpc_channel) as channel:
33 | request = struct_pb2.Struct()
34 | ...
35 | request_any = any_pb2.Any()
36 | request_any.Pack(request)
37 | response_any = channel.send(request_any)
38 | response = my_type_pb2.MyType()
39 | response_any.Unpack(response)
40 |
41 |
42 | Any errors encountered in the EnvironmentResponse are turned into Python
43 | exceptions, so explicit error handling code isn't needed per call.
44 | """
45 |
46 | import queue
47 | from typing import Optional, Protocol, Sequence, Tuple
48 | import grpc
49 |
50 | from dm_env_rpc.v1 import dm_env_rpc_pb2
51 | from dm_env_rpc.v1 import dm_env_rpc_pb2_grpc
52 | from dm_env_rpc.v1 import message_utils
53 |
54 | Metadata = Sequence[Tuple[str, str]]
55 |
56 |
57 | class ConnectionType(Protocol):
58 | """Connection protocol definition for interacting with dm_env_rpc servers."""
59 |
60 | def send(
61 | self,
62 | request: message_utils.DmEnvRpcRequest) -> message_utils.DmEnvRpcResponse:
63 | """Blocking call to send and receive a message from a dm_env_rpc server."""
64 |
65 | def close(self):
66 | """Closes the connection. Call when the connection is no longer needed."""
67 |
68 |
69 | class StreamReaderWriter(object):
70 | """Helper class for reading/writing gRPC streams."""
71 |
72 | def __init__(self,
73 | stub: dm_env_rpc_pb2_grpc.EnvironmentStub,
74 | metadata: Optional[Metadata] = None):
75 | self._requests = queue.Queue()
76 | self._stream = stub.Process(
77 | iter(self._requests.get, None), metadata=metadata)
78 |
79 | def write(self, request: dm_env_rpc_pb2.EnvironmentRequest):
80 | """Asynchronously sends `request` to the stream."""
81 | self._requests.put(request)
82 |
83 | def read(self) -> dm_env_rpc_pb2.EnvironmentResponse:
84 | """Returns the response from stream. Blocking."""
85 | return next(self._stream)
86 |
87 |
88 | class Connection(object):
89 | """A helper class for interacting with dm_env_rpc servers."""
90 |
91 | def __init__(self,
92 | channel: grpc.Channel,
93 | metadata: Optional[Metadata] = None):
94 | """Manages a connection to a dm_env_rpc server.
95 |
96 | Args:
97 | channel: A grpc channel to connect to the dm_env_rpc server over.
98 | metadata: Optional sequence of 2-tuples, sent to the gRPC server as
99 | metadata.
100 | """
101 | self._stream = StreamReaderWriter(
102 | dm_env_rpc_pb2_grpc.EnvironmentStub(channel), metadata)
103 |
104 | def send(
105 | self,
106 | request: message_utils.DmEnvRpcRequest) -> message_utils.DmEnvRpcResponse:
107 | """Sends the given request to the dm_env_rpc server and returns the response.
108 |
109 | The request should be an instance of one of the dm_env_rpc Request messages,
110 | such as CreateWorldRequest. Based on the type the correct payload for the
111 | EnvironmentRequest will be constructed and sent to the dm_env_rpc server.
112 |
113 | Blocks until the server sends back its response.
114 |
115 | Args:
116 | request: An instance of a dm_env_rpc Request type, such as
117 | CreateWorldRequest.
118 |
119 | Returns:
120 | The response the dm_env_rpc server returned for the given RPC call,
121 |       unwrapped from the EnvironmentResponse message. For instance, if `request`
122 | had type `CreateWorldRequest` this returns a message of type
123 | `CreateWorldResponse`.
124 |
125 | Raises:
126 | DmEnvRpcError: The dm_env_rpc server responded to the request with an
127 | error.
128 | ValueError: The dm_env_rpc server responded to the request with an
129 | unexpected response message.
130 | """
131 | environment_request, field_name = (
132 | message_utils.pack_environment_request(request))
133 | if self._stream is None:
134 | raise ValueError('Cannot send request after stream is closed.')
135 | self._stream.write(environment_request)
136 | return message_utils.unpack_environment_response(self._stream.read(),
137 | field_name)
138 |
139 | def close(self):
140 | """Closes the connection. Call when the connection is no longer needed."""
141 | if self._stream:
142 | self._stream = None
143 |
144 | def __exit__(self, *args, **kwargs):
145 | self.close()
146 |
147 | def __enter__(self):
148 | return self
149 |
150 |
151 | def create_secure_channel_and_connect(
152 | server_address: str,
153 | credentials: grpc.ChannelCredentials = grpc.local_channel_credentials(),
154 | timeout: Optional[float] = None,
155 | metadata: Optional[Metadata] = None,
156 | ) -> Connection:
157 | """Creates a secure channel from server address and credentials and connects.
158 |
159 |   We allow the created channel to have unbounded message lengths, to support
160 | large observations.
161 |
162 | Args:
163 | server_address: URI server address to connect to.
164 | credentials: gRPC credentials necessary to connect to the server.
165 | timeout: Optional timeout in seconds to wait for channel to be ready.
166 | Default to waiting indefinitely.
167 | metadata: Optional sequence of 2-tuples, sent to the gRPC server as
168 | metadata.
169 |
170 | Returns:
171 |     An instance of dm_env_rpc.Connection, where the underlying channel is
172 |     closed when the connection is closed.
173 | """
174 | options = [('grpc.max_send_message_length', -1),
175 | ('grpc.max_receive_message_length', -1)]
176 | channel = grpc.secure_channel(server_address, credentials, options=options)
177 | grpc.channel_ready_future(channel).result(timeout)
178 |
179 | class _ConnectionWrapper(Connection):
180 | """Utility to ensure channel is closed when the connection is closed."""
181 |
182 | def __init__(self, channel, metadata):
183 | super().__init__(channel=channel, metadata=metadata)
184 | self._channel = channel
185 |
186 | def __del__(self):
187 | self.close()
188 |
189 | def close(self):
190 | super().close()
191 | self._channel.close()
192 |
193 | return _ConnectionWrapper(channel=channel, metadata=metadata)
194 |
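As a complement to the docstring examples above, here is a hedged end-to-end sketch of the create/join/step/cleanup flow through `Connection`. The server address and the 'seed' setting name are assumptions; a real server defines its own settings, actions and observations.

# Sketch only: 'localhost:9001' and the 'seed' setting are hypothetical.
from dm_env_rpc.v1 import connection as dm_env_rpc_connection
from dm_env_rpc.v1 import dm_env_rpc_pb2
from dm_env_rpc.v1 import tensor_utils

with dm_env_rpc_connection.create_secure_channel_and_connect(
    'localhost:9001') as connection:
  # Settings values are packed Tensor protos.
  world_name = connection.send(dm_env_rpc_pb2.CreateWorldRequest(
      settings={'seed': tensor_utils.pack_tensor(123)})).world_name
  specs = connection.send(dm_env_rpc_pb2.JoinWorldRequest(
      world_name=world_name)).specs

  # Step once, requesting every observation declared in the joined specs.
  step_response = connection.send(dm_env_rpc_pb2.StepRequest(
      requested_observations=specs.observations.keys()))
  print(step_response.state, list(step_response.observations))

  # Leave and destroy the world once finished.
  connection.send(dm_env_rpc_pb2.LeaveWorldRequest())
  connection.send(dm_env_rpc_pb2.DestroyWorldRequest(world_name=world_name))
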
--------------------------------------------------------------------------------
/dm_env_rpc/v1/connection_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Tests for Connection."""
16 |
17 | import contextlib
18 | from unittest import mock
19 |
20 | from absl.testing import absltest
21 | import grpc
22 |
23 | from google.protobuf import any_pb2
24 | from google.protobuf import struct_pb2
25 | from google.rpc import status_pb2
26 | from dm_env_rpc.v1 import connection as dm_env_rpc_connection
27 | from dm_env_rpc.v1 import dm_env_rpc_pb2
28 | from dm_env_rpc.v1 import error
29 | from dm_env_rpc.v1 import tensor_utils
30 |
31 | _CREATE_REQUEST = dm_env_rpc_pb2.CreateWorldRequest(
32 | settings={'foo': tensor_utils.pack_tensor('bar')})
33 | _CREATE_RESPONSE = dm_env_rpc_pb2.CreateWorldResponse()
34 |
35 | _BAD_CREATE_REQUEST = dm_env_rpc_pb2.CreateWorldRequest()
36 | _TEST_ERROR = dm_env_rpc_pb2.EnvironmentResponse(
37 | error=status_pb2.Status(message='A test error.'))
38 |
39 | _INCORRECT_RESPONSE_TEST_MSG = dm_env_rpc_pb2.DestroyWorldRequest(
40 | world_name='foo')
41 | _INCORRECT_RESPONSE = dm_env_rpc_pb2.EnvironmentResponse(
42 | leave_world=dm_env_rpc_pb2.LeaveWorldResponse())
43 |
44 | _EXTENSION_REQUEST = struct_pb2.Value(string_value='extension request')
45 | _EXTENSION_RESPONSE = struct_pb2.Value(number_value=555)
46 |
47 |
48 | def _wrap_in_any(proto):
49 | any_proto = any_pb2.Any()
50 | any_proto.Pack(proto)
51 | return any_proto
52 |
53 |
54 | _REQUEST_RESPONSE_PAIRS = {
55 | dm_env_rpc_pb2.EnvironmentRequest(
56 | create_world=_CREATE_REQUEST).SerializeToString():
57 | dm_env_rpc_pb2.EnvironmentResponse(create_world=_CREATE_RESPONSE),
58 | dm_env_rpc_pb2.EnvironmentRequest(
59 | create_world=_BAD_CREATE_REQUEST).SerializeToString():
60 | _TEST_ERROR,
61 | dm_env_rpc_pb2.EnvironmentRequest(
62 | extension=_wrap_in_any(_EXTENSION_REQUEST)).SerializeToString():
63 | dm_env_rpc_pb2.EnvironmentResponse(
64 | extension=_wrap_in_any(_EXTENSION_RESPONSE)),
65 | dm_env_rpc_pb2.EnvironmentRequest(
66 | destroy_world=_INCORRECT_RESPONSE_TEST_MSG).SerializeToString():
67 | _INCORRECT_RESPONSE,
68 | }
69 |
70 |
71 | def _process(request_iterator, **kwargs):
72 | del kwargs
73 | for request in request_iterator:
74 | yield _REQUEST_RESPONSE_PAIRS.get(request.SerializeToString(), _TEST_ERROR)
75 |
76 |
77 | @contextlib.contextmanager
78 | def _create_mock_channel():
79 | """Mocks out gRPC and returns a channel to be passed to Connection."""
80 | with mock.patch.object(dm_env_rpc_connection,
81 | 'dm_env_rpc_pb2_grpc') as mock_grpc:
82 | mock_stub_class = mock.MagicMock()
83 | mock_stub_class.Process = _process
84 | mock_grpc.EnvironmentStub.return_value = mock_stub_class
85 | yield mock.MagicMock()
86 |
87 |
88 | class ConnectionTests(absltest.TestCase):
89 |
90 | def test_create(self):
91 | with _create_mock_channel() as mock_channel:
92 | with dm_env_rpc_connection.Connection(mock_channel) as connection:
93 | response = connection.send(_CREATE_REQUEST)
94 | self.assertEqual(_CREATE_RESPONSE, response)
95 |
96 | def test_send_error_after_close(self):
97 | with _create_mock_channel() as mock_channel:
98 | with dm_env_rpc_connection.Connection(mock_channel) as connection:
99 | connection.close()
100 | with self.assertRaisesRegex(ValueError, 'stream is closed'):
101 | connection.send(_CREATE_REQUEST)
102 |
103 | def test_error(self):
104 | with _create_mock_channel() as mock_channel:
105 | with dm_env_rpc_connection.Connection(mock_channel) as connection:
106 | with self.assertRaisesRegex(error.DmEnvRpcError, 'test error'):
107 | connection.send(_BAD_CREATE_REQUEST)
108 |
109 | def test_extension(self):
110 | with _create_mock_channel() as mock_channel:
111 | with dm_env_rpc_connection.Connection(mock_channel) as connection:
112 | request = any_pb2.Any()
113 | request.Pack(_EXTENSION_REQUEST)
114 | response = connection.send(request)
115 | expected_response = any_pb2.Any()
116 | expected_response.Pack(_EXTENSION_RESPONSE)
117 | self.assertEqual(expected_response, response)
118 |
119 | @mock.patch.object(grpc, 'secure_channel')
120 | @mock.patch.object(grpc, 'channel_ready_future')
121 | def test_create_secure_channel_and_connect(self, mock_channel_ready,
122 | mock_secure_channel):
123 | mock_channel = mock.MagicMock()
124 | mock_secure_channel.return_value = mock_channel
125 |
126 | self.assertIsNotNone(
127 | dm_env_rpc_connection.create_secure_channel_and_connect(
128 | 'valid_address', grpc.local_channel_credentials()))
129 |
130 | mock_channel_ready.assert_called_once_with(mock_channel)
131 | mock_secure_channel.assert_called_once()
132 | mock_channel.close.assert_called_once()
133 |
134 | @mock.patch.object(grpc, 'secure_channel')
135 | @mock.patch.object(grpc, 'channel_ready_future')
136 | def test_create_secure_channel_and_connect_context(self, mock_channel_ready,
137 | mock_secure_channel):
138 | mock_channel = mock.MagicMock()
139 | mock_secure_channel.return_value = mock_channel
140 |
141 | with dm_env_rpc_connection.create_secure_channel_and_connect(
142 | 'valid_address') as connection:
143 | self.assertIsNotNone(connection)
144 |
145 | mock_channel_ready.assert_called_once_with(mock_channel)
146 | mock_secure_channel.assert_called_once()
147 | mock_channel.close.assert_called_once()
148 |
149 | @mock.patch.object(grpc, 'secure_channel')
150 | @mock.patch.object(grpc, 'channel_ready_future')
151 | @mock.patch.object(dm_env_rpc_connection, 'StreamReaderWriter')
152 | def test_create_secure_channel_and_connect_metadata(
153 | self, mock_stream_writer, mock_channel_ready, mock_secure_channel
154 | ):
155 | mock_channel = mock.MagicMock()
156 | mock_secure_channel.return_value = mock_channel
157 | metadata = [('upstream', 'fake_address')]
158 | with dm_env_rpc_connection.create_secure_channel_and_connect(
159 | 'valid_address', metadata=metadata
160 | ) as connection:
161 | self.assertIsNotNone(connection)
162 |
163 | mock_channel_ready.assert_called_once_with(mock_channel)
164 | mock_secure_channel.assert_called_once()
165 | mock_channel.close.assert_called_once()
166 | mock_stream_writer.assert_called_once_with(mock.ANY, metadata)
167 |
168 | def test_create_secure_channel_and_connect_timeout(self):
169 | with self.assertRaises(grpc.FutureTimeoutError):
170 | dm_env_rpc_connection.create_secure_channel_and_connect(
171 | 'invalid_address', grpc.local_channel_credentials(), timeout=1.)
172 |
173 | def test_incorrect_response(self):
174 | with _create_mock_channel() as mock_channel:
175 | with dm_env_rpc_connection.Connection(mock_channel) as connection:
176 | with self.assertRaisesRegex(ValueError, 'Unexpected response message'):
177 | connection.send(_INCORRECT_RESPONSE_TEST_MSG)
178 |
179 | def test_with_metadata(self):
180 | expected_metadata = (('key', 'value'),)
181 | with mock.patch.object(dm_env_rpc_connection,
182 | 'dm_env_rpc_pb2_grpc') as mock_grpc:
183 | mock_stub_class = mock.MagicMock()
184 | mock_grpc.EnvironmentStub.return_value = mock_stub_class
185 | _ = dm_env_rpc_connection.Connection(
186 | mock.MagicMock(), metadata=expected_metadata)
187 | mock_stub_class.Process.assert_called_with(
188 | mock.ANY, metadata=expected_metadata)
189 |
190 |
191 | if __name__ == '__main__':
192 | absltest.main()
193 |
--------------------------------------------------------------------------------
/dm_env_rpc/v1/dm_env_flatten_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Python utilities for flattening and unflattening key-value mappings."""
16 |
17 | from typing import Any, Dict, Mapping
18 |
19 |
20 | def flatten_dict(
21 | input_dict: Mapping[str, Any],
22 | separator: str,
23 | *,
24 | strict: bool = True,
25 | ) -> Dict[str, Any]:
26 | """Flattens mappings by joining sub-keys using the provided separator.
27 |
28 |   Only non-empty mapping types will be flattened. All other types are deemed
29 | leaf values.
30 |
31 | Args:
32 | input_dict: Mapping of key-value pairs to flatten.
33 | separator: Delimiter used to concatenate keys.
34 |     strict: If True, raise an error for keys that already contain the
35 |       separator. False permits such keys but makes unflattening ambiguous.
36 |
37 | Returns:
38 | Flattened dictionary of key-value pairs.
39 |
40 | Raises:
41 | ValueError: If the `input_dict` has a key that contains the separator
42 |       string and strict is True.
43 | """
44 | result: Dict[str, Any] = {}
45 | for key, value in input_dict.items():
46 | if strict and separator in key:
47 | raise ValueError(
48 | f"Can not safely flatten dictionary: key '{key}' already contains "
49 | f"the separator '{separator}'!"
50 | )
51 | if isinstance(value, Mapping) and len(value):
52 | result.update({
53 | f'{key}{separator}{sub_key}': sub_value
54 | for sub_key, sub_value in flatten_dict(
55 | value, separator, strict=strict).items()
56 | })
57 | else:
58 | result[key] = value
59 | return result
60 |
61 |
62 | def unflatten_dict(input_dict: Mapping[str, Any],
63 | separator: str) -> Dict[str, Any]:
64 | """Unflatten dictionary using split keys to determine the structure.
65 |
66 |   For each key, split based on the provided separator and create a nested
67 | dictionary entry for each sub-key.
68 |
69 | Args:
70 | input_dict: Mapping of key-value pairs to un-flatten.
71 | separator: Delimiter used to split keys.
72 |
73 | Returns:
74 | Unflattened dictionary.
75 |
76 | Raises:
77 |     ValueError: If a key, or one of its sub-keys, already has a value. For
78 |       instance, unflattening `{"foo": True, "foo.bar": "baz"}` will result in
79 |       "foo" being set to both a dict and a bool.
80 | """
81 | result: Dict[str, Any] = {}
82 | for key, value in input_dict.items():
83 | sub_keys = key.split(separator)
84 | sub_tree = result
85 | for sub_key in sub_keys[:-1]:
86 | sub_tree = sub_tree.setdefault(sub_key, {})
87 | if not isinstance(sub_tree, Mapping):
88 | raise ValueError(f"Sub-tree '{sub_key}' has already been assigned a "
89 | f"leaf value {sub_tree}")
90 |
91 | if sub_keys[-1] in sub_tree:
92 | raise ValueError(f'Duplicate key {key}')
93 | sub_tree[sub_keys[-1]] = value
94 | return result
95 |
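A small round-trip example of the two helpers above; the keys and values are arbitrary.

from dm_env_rpc.v1 import dm_env_flatten_utils

nested = {'agent': {'position': [0.0, 1.0], 'health': 100}, 'seed': 7}
flat = dm_env_flatten_utils.flatten_dict(nested, '.')
# flat == {'agent.position': [0.0, 1.0], 'agent.health': 100, 'seed': 7}
assert dm_env_flatten_utils.unflatten_dict(flat, '.') == nested

# With the default strict=True, keys that already contain the separator are
# rejected, since flattening them would make the reverse mapping ambiguous.
try:
  dm_env_flatten_utils.flatten_dict({'agent.health': 1}, '.')
except ValueError:
  pass
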
--------------------------------------------------------------------------------
/dm_env_rpc/v1/dm_env_flatten_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Tests for dm_env_flatten_utils."""
16 |
17 | from absl.testing import absltest
18 | from dm_env_rpc.v1 import dm_env_flatten_utils
19 |
20 |
21 | class FlattenUtilsTest(absltest.TestCase):
22 |
23 | def test_flatten(self):
24 | input_dict = {
25 | 'foo': {
26 | 'bar': 1,
27 | 'baz': False
28 | },
29 | 'fiz': object(),
30 | }
31 | expected = {
32 | 'foo.bar': 1,
33 | 'foo.baz': False,
34 | 'fiz': object(),
35 | }
36 | self.assertSameElements(expected,
37 | dm_env_flatten_utils.flatten_dict(input_dict, '.'))
38 |
39 | def test_unflatten(self):
40 | input_dict = {
41 | 'foo.bar.baz': True,
42 | 'fiz.buz': 1,
43 | 'foo.baz': 'val',
44 | 'buz': {}
45 | }
46 | expected = {
47 | 'foo': {
48 | 'bar': {
49 | 'baz': True
50 | },
51 | 'baz': 'val'
52 | },
53 | 'fiz': {
54 | 'buz': 1
55 | },
56 | 'buz': {},
57 | }
58 | self.assertSameElements(
59 | expected, dm_env_flatten_utils.unflatten_dict(input_dict, '.'))
60 |
61 | def test_unflatten_different_separator(self):
62 | input_dict = {'foo::bar.baz': True, 'foo.bar::baz': 1}
63 | expected = {'foo': {'bar.baz': True}, 'foo.bar': {'baz': 1}}
64 | self.assertSameElements(
65 | expected, dm_env_flatten_utils.unflatten_dict(input_dict, '::'))
66 |
67 | def test_flatten_unflatten(self):
68 | input_output = {
69 | 'foo': {
70 | 'bar': 1,
71 | 'baz': False
72 | },
73 | 'fiz': object(),
74 | }
75 | self.assertSameElements(
76 | input_output,
77 | dm_env_flatten_utils.unflatten_dict(
78 | dm_env_flatten_utils.flatten_dict(input_output, '.'), '.'))
79 |
80 | def test_flatten_with_key_containing_separator(self):
81 | input_dict = {'foo.bar': {'baz': 123}, 'bar': {'foo.baz': 456}}
82 | expected = {'foo.bar.baz': 123, 'bar.foo.baz': 456}
83 |
84 | self.assertSameElements(
85 | expected,
86 | dm_env_flatten_utils.flatten_dict(input_dict, '.', strict=False),
87 | )
88 |
89 | def test_flatten_with_key_containing_separator_strict_raises_error(self):
90 | with self.assertRaisesRegex(ValueError, 'foo.bar'):
91 | dm_env_flatten_utils.flatten_dict({'foo.bar': True}, '.')
92 |
93 | def test_invalid_flattened_dict_raises_error(self):
94 | input_dict = dict((
95 | ('foo.bar', True),
96 | ('foo', 'invalid_value_for_sub_key'),
97 | ))
98 | with self.assertRaisesRegex(ValueError, 'Duplicate key'):
99 | dm_env_flatten_utils.unflatten_dict(input_dict, '.')
100 |
101 | def test_sub_tree_has_value_raises_error(self):
102 | input_dict = dict((
103 | ('branch', 'should_not_have_value'),
104 | ('branch.leaf', True),
105 | ))
106 | with self.assertRaisesRegex(ValueError,
107 | "Sub-tree 'branch' has already been assigned"):
108 | dm_env_flatten_utils.unflatten_dict(input_dict, '.')
109 |
110 | def test_empty_dict_values_flatten(self):
111 | input_dict = {
112 | 'foo': {},
113 | 'bar': {
114 | 'baz': {}
115 | },
116 | }
117 | expected = {
118 | 'foo': {},
119 | 'bar.baz': {},
120 | }
121 | self.assertSameElements(expected,
122 | dm_env_flatten_utils.flatten_dict(input_dict, '.'))
123 |
124 |
125 | if __name__ == '__main__':
126 | absltest.main()
127 |
--------------------------------------------------------------------------------
/dm_env_rpc/v1/dm_env_rpc.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | // ===========================================================================
15 | syntax = "proto3";
16 |
17 | package dm_env_rpc.v1;
18 |
19 | import "google/protobuf/any.proto";
20 | import "google/rpc/status.proto";
21 |
22 | // A potentially multi-dimensional array of data, laid out in row-major format.
23 | // Note, only one data channel should be used at a time.
24 | message Tensor {
25 | message Int8Array {
26 | bytes array = 1;
27 | }
28 | message Int32Array {
29 | repeated int32 array = 1;
30 | }
31 | message Int64Array {
32 | repeated int64 array = 1;
33 | }
34 | message Uint8Array {
35 | bytes array = 1;
36 | }
37 | message Uint32Array {
38 | repeated uint32 array = 1;
39 | }
40 | message Uint64Array {
41 | repeated uint64 array = 1;
42 | }
43 | message FloatArray {
44 | repeated float array = 1;
45 | }
46 | message DoubleArray {
47 | repeated double array = 1;
48 | }
49 | message BoolArray {
50 | repeated bool array = 1;
51 | }
52 | message StringArray {
53 | repeated string array = 1;
54 | }
55 | message ProtoArray {
56 | repeated google.protobuf.Any array = 1;
57 | }
58 |
59 | // The flattened tensor data. Data is laid out in row-major order.
60 | oneof payload {
61 | // LINT.IfChange(Tensor)
62 | FloatArray floats = 1;
63 | DoubleArray doubles = 2;
64 | Int8Array int8s = 3;
65 | Int32Array int32s = 4;
66 | Int64Array int64s = 5;
67 | Uint8Array uint8s = 6;
68 | Uint32Array uint32s = 7;
69 | Uint64Array uint64s = 8;
70 | BoolArray bools = 9;
71 | StringArray strings = 10;
72 | ProtoArray protos = 11;
73 | // LINT.ThenChange(:DataType)
74 | }
75 |
76 | // The dimensions of the repeated data fields. If empty, the data channel
77 | // will be treated as a scalar and expected to have exactly one element.
78 | //
79 | // If the payload has exactly one element, it will be repeated to fill the
80 | // shape.
81 | //
82 | // A negative element in a dimension indicates its size should be determined
83 | // based on the number of elements in the payload and the rest of the shape.
84 | // For instance, a shape of [-1, 5] means the shape is a matrix with 5 columns
85 | // and a variable number of rows. Only one element in the shape may be set to
86 | // a negative value.
87 | repeated int32 shape = 15;
88 | }
89 |
90 | // The data type of elements of a tensor. This must match the types in the
91 | // Tensor payload oneof.
92 | enum DataType {
93 | // This is the default value indicating no value was set.
94 | INVALID_DATA_TYPE = 0;
95 |
96 | // LINT.IfChange(DataType)
97 | FLOAT = 1;
98 | DOUBLE = 2;
99 | INT8 = 3;
100 | INT32 = 4;
101 | INT64 = 5;
102 | UINT8 = 6;
103 | UINT32 = 7;
104 | UINT64 = 8;
105 | BOOL = 9;
106 | STRING = 10;
107 | PROTO = 11;
108 | // LINT.ThenChange(:Tensor)
109 | }
110 |
111 | message TensorSpec {
112 | // A human-readable name describing this tensor.
113 | string name = 1;
114 |
115 | // The dimensionality of the tensor. See Tensor.shape for more information.
116 | repeated int32 shape = 2;
117 |
118 | // The data type of the elements in the tensor.
119 | DataType dtype = 3;
120 |
121 | // Value sub-message that defines the inclusive min/max bounds for numerical
122 | // types.
123 | message Value {
124 | oneof payload {
125 | Tensor.FloatArray floats = 9;
126 | Tensor.DoubleArray doubles = 10;
127 | Tensor.Int8Array int8s = 11;
128 | Tensor.Int32Array int32s = 12;
129 | Tensor.Int64Array int64s = 13;
130 | Tensor.Uint8Array uint8s = 14;
131 | Tensor.Uint32Array uint32s = 15;
132 | Tensor.Uint64Array uint64s = 16;
133 | }
134 |
135 | // Deprecated scalar Value fields. Please use array fields and create a
136 | // single element array of the relevant type.
137 | reserved 1 to 8;
138 | }
139 |
140 | // The minimum value that elements in the tensor can obtain. Inclusive.
141 | Value min = 4; // Optional
142 |
143 | // The maximum value that elements in the tensor can obtain. Inclusive.
144 | Value max = 5; // Optional
145 | }
146 |
147 | message CreateWorldRequest {
148 | // Settings to create the world with. This can define the level layout, the
149 | // number of agents, the goal or game mode, or other universal settings.
150 | // Agent-specific settings, such as anything which would change the action or
151 | // observation spec, should go in the JoinWorldRequest.
152 |   map<string, Tensor> settings = 1;
153 | }
154 | message CreateWorldResponse {
155 | // The unique name for the world just created.
156 | string world_name = 1;
157 | }
158 |
159 | message ActionObservationSpecs {
160 |   map<uint64, TensorSpec> actions = 1;
161 | 
162 |   map<uint64, TensorSpec> observations = 2;
163 | }
164 |
165 | message JoinWorldRequest {
166 | // The name of the world to join.
167 | string world_name = 1;
168 |
169 | // Agent-specific settings which define how to join the world, such as agent
170 | // name and class in an RPG.
171 |   map<string, Tensor> settings = 2;
172 | }
173 | message JoinWorldResponse {
174 | ActionObservationSpecs specs = 1;
175 | }
176 |
177 | enum EnvironmentStateType {
178 | // This is the default value indicating no value was set. It should never be
179 | // sent or received.
180 | INVALID_ENVIRONMENT_STATE = 0;
181 |
182 | // The environment is currently in the middle of a sequence.
183 | RUNNING = 1;
184 |
185 | // The previously running sequence reached its natural conclusion.
186 | TERMINATED = 2;
187 |
188 | // The sequence was interrupted by a reset.
189 | INTERRUPTED = 3;
190 | }
191 |
192 | message StepRequest {
193 | // The actions to perform on the environment. If the environment is currently
194 | // in a non-RUNNING state, whether because the agent has just called
195 |   // JoinWorld, the state from the last StepResponse was TERMINATED or
196 | // INTERRUPTED, or a ResetRequest had previously been sent, the actions will
197 | // be ignored.
198 |   map<uint64, Tensor> actions = 1;
199 |
200 | // Array of observations UIDs to return. If not set, no observations are
201 | // returned.
202 | repeated uint64 requested_observations = 2;
203 | }
204 |
205 | message StepResponse {
206 | // If state is not RUNNING, the action on the next StepRequest will be
207 | // ignored and the environment will transition to a RUNNING state.
208 | EnvironmentStateType state = 1;
209 |
210 | // The observations requested in `StepRequest`. Observations returned should
211 | // match the dimensionality and type specified in `specs.observations`.
212 |   map<uint64, Tensor> observations = 2;
213 | }
214 |
215 | // The current sequence will be interrupted. The actions on the next call to
216 | // StepRequest will be ignored and a new sequence will begin.
217 | message ResetRequest {
218 |   // Agent-specific settings to apply for the next sequence, such as changing
219 | // class in an RPG.
220 |   map<string, Tensor> settings = 1;
221 | }
222 | message ResetResponse {
223 | ActionObservationSpecs specs = 1;
224 | }
225 |
226 | // All connected agents will have their next StepResponse return INTERRUPTED.
227 | message ResetWorldRequest {
228 | string world_name = 1;
229 |
230 | // World settings to apply for the next sequence, such as changing the map or
231 | // seed.
232 |   map<string, Tensor> settings = 2;
233 | }
234 | message ResetWorldResponse {}
235 |
236 | message LeaveWorldRequest {}
237 | message LeaveWorldResponse {}
238 |
239 | message DestroyWorldRequest {
240 | string world_name = 1;
241 | }
242 | message DestroyWorldResponse {}
243 |
244 | message EnvironmentRequest {
245 | oneof payload {
246 | CreateWorldRequest create_world = 1;
247 | JoinWorldRequest join_world = 2;
248 | StepRequest step = 3;
249 | ResetRequest reset = 4;
250 | ResetWorldRequest reset_world = 5;
251 | LeaveWorldRequest leave_world = 6;
252 | DestroyWorldRequest destroy_world = 7;
253 |
254 | // If the environment supports a specialized request not covered above it
255 | // can be sent this way.
256 | //
257 | // Slot 15 is the last slot which can be encoded with one byte. See
258 | // https://developers.google.com/protocol-buffers/docs/proto3#assigning-field-numbers
259 | google.protobuf.Any extension = 15;
260 | }
261 |
262 | // Deprecated property requests. Please use properties extension for future
263 | // requests/responses.
264 | reserved 8 to 10;
265 |
266 | // Slot corresponds to `error` in the EnvironmentResponse.
267 | reserved 16;
268 | }
269 |
270 | message EnvironmentResponse {
271 | oneof payload {
272 | CreateWorldResponse create_world = 1;
273 | JoinWorldResponse join_world = 2;
274 | StepResponse step = 3;
275 | ResetResponse reset = 4;
276 | ResetWorldResponse reset_world = 5;
277 | LeaveWorldResponse leave_world = 6;
278 | DestroyWorldResponse destroy_world = 7;
279 |
280 | // If the environment supports a specialized response not covered above it
281 | // can be sent this way.
282 | //
283 | // Slot 15 is the last slot which can be encoded with one byte. See
284 | // https://developers.google.com/protocol-buffers/docs/proto3#assigning-field-numbers
285 | google.protobuf.Any extension = 15;
286 |
287 | google.rpc.Status error = 16;
288 | }
289 |
290 | // Deprecated property responses. Please use properties extension for future
291 | // requests/responses.
292 | reserved 8 to 10;
293 | }
294 |
295 | service Environment {
296 | // Process incoming environment requests.
297 | rpc Process(stream EnvironmentRequest) returns (stream EnvironmentResponse) {}
298 | }
299 |
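To make the `Tensor` shape rules above concrete, here is a short sketch of building Tensor messages through the generated `dm_env_rpc_pb2` bindings: data is flattened in row-major order into exactly one payload field, a single element may be broadcast to fill the shape, and an empty shape denotes a scalar. The values used are arbitrary.

from dm_env_rpc.v1 import dm_env_rpc_pb2

# A 2x3 float matrix: six elements laid out row by row.
matrix = dm_env_rpc_pb2.Tensor()
matrix.floats.array[:] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
matrix.shape[:] = [2, 3]

# A scalar: empty shape and exactly one element.
scalar = dm_env_rpc_pb2.Tensor()
scalar.int32s.array[:] = [42]

# A broadcast fill: one element repeated to fill a 4x4 shape.
fill = dm_env_rpc_pb2.Tensor()
fill.doubles.array[:] = [0.0]
fill.shape[:] = [4, 4]

# In practice, tensor_utils.pack_tensor / unpack_tensor handle this packing
# to and from numpy arrays.
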
--------------------------------------------------------------------------------
/dm_env_rpc/v1/dm_env_rpc_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Tests for environment_stream.proto.
16 |
17 | These aren't for testing functionality (it's assumed protobufs work) but for
18 | testing/demonstrating how the protobufs would have to be used in code.
19 | """
20 |
21 | from absl.testing import absltest
22 | from dm_env_rpc.v1 import dm_env_rpc_pb2
23 |
24 |
25 | class TensorTests(absltest.TestCase):
26 |
27 | def test_setting_tensor_data(self):
28 | tensor = dm_env_rpc_pb2.Tensor()
29 | tensor.floats.array[:] = [1, 2]
30 |
31 | def test_setting_tensor_data_with_wrong_type(self):
32 | tensor = dm_env_rpc_pb2.Tensor()
33 | with self.assertRaises(TypeError):
34 | tensor.floats.array[:] = ['hello!'] # pytype: disable=unsupported-operands
35 |
36 | def test_which_is_set(self):
37 | tensor = dm_env_rpc_pb2.Tensor()
38 | tensor.floats.array[:] = [1, 2]
39 | self.assertEqual('floats', tensor.WhichOneof('payload'))
40 |
41 |
42 | class TensorSpec(absltest.TestCase):
43 |
44 | def test_setting_spec(self):
45 | tensor_spec = dm_env_rpc_pb2.TensorSpec()
46 | tensor_spec.name = 'Foo'
47 | tensor_spec.min.floats.array[:] = [0.0]
48 | tensor_spec.max.floats.array[:] = [0.0]
49 | tensor_spec.shape[:] = [2, 2]
50 | tensor_spec.dtype = dm_env_rpc_pb2.DataType.FLOAT
51 |
52 |
53 | class JoinWorldResponse(absltest.TestCase):
54 |
55 | def test_setting_spec(self):
56 | response = dm_env_rpc_pb2.JoinWorldResponse()
57 | tensor_spec = response.specs.actions[1]
58 | tensor_spec.shape[:] = [1]
59 | tensor_spec.dtype = dm_env_rpc_pb2.DataType.FLOAT
60 |
61 |
62 | if __name__ == '__main__':
63 | absltest.main()
64 |
--------------------------------------------------------------------------------
/dm_env_rpc/v1/dm_env_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Utilities for interfacing dm_env and dm_env_rpc."""
16 |
17 | from typing import Dict
18 | from dm_env import specs
19 | import numpy as np
20 |
21 | from dm_env_rpc.v1 import dm_env_rpc_pb2
22 | from dm_env_rpc.v1 import spec_manager as dm_env_rpc_spec_manager
23 | from dm_env_rpc.v1 import tensor_spec_utils
24 | from dm_env_rpc.v1 import tensor_utils
25 |
26 |
27 | def tensor_spec_to_dm_env_spec(
28 | tensor_spec: dm_env_rpc_pb2.TensorSpec) -> specs.Array:
29 | """Returns a dm_env spec given a dm_env_rpc TensorSpec.
30 |
31 | Args:
32 | tensor_spec: A dm_env_rpc TensorSpec protobuf.
33 |
34 | Returns:
35 | Either a DiscreteArray, BoundedArray, StringArray or Array, depending on the
36 | content of the TensorSpec.
37 | """
38 | np_type = tensor_utils.data_type_to_np_type(tensor_spec.dtype)
39 | if tensor_spec.HasField('min') or tensor_spec.HasField('max'):
40 | bounds = tensor_spec_utils.bounds(tensor_spec)
41 |
42 | if (not tensor_spec.shape
43 | and np.issubdtype(np_type, np.integer)
44 | and bounds.min == 0
45 | and tensor_spec.HasField('max')):
46 | return specs.DiscreteArray(
47 | num_values=bounds.max + 1, dtype=np_type, name=tensor_spec.name)
48 | else:
49 | return specs.BoundedArray(
50 | shape=tensor_spec.shape,
51 | dtype=np_type,
52 | name=tensor_spec.name,
53 | minimum=bounds.min,
54 | maximum=bounds.max)
55 | else:
56 | if tensor_spec.dtype == dm_env_rpc_pb2.DataType.STRING:
57 | return specs.StringArray(shape=tensor_spec.shape, name=tensor_spec.name)
58 | else:
59 | return specs.Array(
60 | shape=tensor_spec.shape, dtype=np_type, name=tensor_spec.name)
61 |
62 |
63 | def dm_env_spec_to_tensor_spec(spec: specs.Array) -> dm_env_rpc_pb2.TensorSpec:
64 | """Returns a dm_env_rpc TensorSpec from the provided dm_env spec type."""
65 | dtype = np.str_ if isinstance(spec, specs.StringArray) else spec.dtype
66 | tensor_spec = dm_env_rpc_pb2.TensorSpec(
67 | name=spec.name,
68 | shape=spec.shape,
69 | dtype=tensor_utils.np_type_to_data_type(dtype))
70 | if isinstance(spec, specs.DiscreteArray):
71 | tensor_spec_utils.set_bounds(
72 | tensor_spec, minimum=0, maximum=spec.num_values - 1)
73 | elif isinstance(spec, specs.BoundedArray):
74 | tensor_spec_utils.set_bounds(tensor_spec, spec.minimum, spec.maximum)
75 |
76 | return tensor_spec
77 |
78 |
79 | def dm_env_spec(
80 | spec_manager: dm_env_rpc_spec_manager.SpecManager
81 | ) -> Dict[str, specs.Array]:
82 | """Returns a dm_env spec for the given `spec_manager`.
83 |
84 | Args:
85 | spec_manager: An instance of SpecManager.
86 |
87 | Returns:
88 | A dict mapping names to either a dm_env Array, BoundedArray, DiscreteArray
89 | or StringArray spec for each named TensorSpec in `spec_manager`.
90 | """
91 | return {
92 | name: tensor_spec_to_dm_env_spec(spec_manager.name_to_spec(name))
93 | for name in spec_manager.names()
94 | }
95 |
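A brief sketch of converting a bounded dm_env_rpc TensorSpec to a dm_env spec and back with the helpers above; the 'reward' spec is illustrative.

from dm_env import specs

from dm_env_rpc.v1 import dm_env_rpc_pb2
from dm_env_rpc.v1 import dm_env_utils
from dm_env_rpc.v1 import tensor_spec_utils

tensor_spec = dm_env_rpc_pb2.TensorSpec(
    name='reward', shape=[1], dtype=dm_env_rpc_pb2.DataType.FLOAT)
tensor_spec_utils.set_bounds(tensor_spec, minimum=-1.0, maximum=1.0)

dm_env_spec = dm_env_utils.tensor_spec_to_dm_env_spec(tensor_spec)
assert isinstance(dm_env_spec, specs.BoundedArray)

# And back again: name, shape, dtype and bounds are carried across.
round_tripped = dm_env_utils.dm_env_spec_to_tensor_spec(dm_env_spec)
assert round_tripped.name == 'reward'
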
--------------------------------------------------------------------------------
/dm_env_rpc/v1/dm_env_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Tests for dm_env_rpc/dm_env utilities."""
16 |
17 | import typing
18 |
19 | from absl.testing import absltest
20 | from absl.testing import parameterized
21 | from dm_env import specs
22 | import numpy as np
23 |
24 | from google.protobuf import text_format
25 | from dm_env_rpc.v1 import dm_env_rpc_pb2
26 | from dm_env_rpc.v1 import dm_env_utils
27 | from dm_env_rpc.v1 import spec_manager
28 |
29 |
30 | class TensorSpecToDmEnvSpecTests(absltest.TestCase):
31 |
32 | def test_no_bounds_gives_arrayspec(self):
33 | tensor_spec = dm_env_rpc_pb2.TensorSpec()
34 | tensor_spec.dtype = dm_env_rpc_pb2.DataType.UINT32
35 | tensor_spec.shape[:] = [3]
36 | tensor_spec.name = 'foo'
37 | actual = dm_env_utils.tensor_spec_to_dm_env_spec(tensor_spec)
38 | self.assertEqual(specs.Array(shape=[3], dtype=np.uint32), actual)
39 | self.assertEqual('foo', actual.name)
40 |
41 | def test_string_give_string_array(self):
42 | tensor_spec = dm_env_rpc_pb2.TensorSpec()
43 | tensor_spec.dtype = dm_env_rpc_pb2.DataType.STRING
44 | tensor_spec.shape[:] = [1, 2, 3]
45 | tensor_spec.name = 'string_spec'
46 | actual = dm_env_utils.tensor_spec_to_dm_env_spec(tensor_spec)
47 | self.assertEqual(specs.StringArray(shape=[1, 2, 3]), actual)
48 | self.assertEqual('string_spec', actual.name)
49 |
50 | def test_scalar_with_0_n_bounds_gives_discrete_array(self):
51 | tensor_spec = dm_env_rpc_pb2.TensorSpec()
52 | tensor_spec.dtype = dm_env_rpc_pb2.DataType.UINT32
53 | tensor_spec.name = 'foo'
54 |
55 | max_value = 9
56 | tensor_spec.min.uint32s.array[:] = [0]
57 | tensor_spec.max.uint32s.array[:] = [max_value]
58 | actual = typing.cast(specs.DiscreteArray,
59 | dm_env_utils.tensor_spec_to_dm_env_spec(tensor_spec))
60 | expected = specs.DiscreteArray(
61 | num_values=max_value + 1, dtype=np.uint32, name='foo')
62 | self.assertEqual(expected, actual)
63 | self.assertEqual(0, actual.minimum)
64 | self.assertEqual(max_value, actual.maximum)
65 | self.assertEqual('foo', actual.name)
66 |
67 | def test_scalar_with_1_n_bounds_gives_bounded_array(self):
68 | tensor_spec = dm_env_rpc_pb2.TensorSpec()
69 | tensor_spec.dtype = dm_env_rpc_pb2.DataType.UINT32
70 | tensor_spec.name = 'foo'
71 | tensor_spec.min.uint32s.array[:] = [1]
72 | tensor_spec.max.uint32s.array[:] = [10]
73 | actual = dm_env_utils.tensor_spec_to_dm_env_spec(tensor_spec)
74 | expected = specs.BoundedArray(
75 | shape=(), dtype=np.uint32, minimum=1, maximum=10, name='foo')
76 | self.assertEqual(expected, actual)
77 | self.assertEqual('foo', actual.name)
78 |
79 | def test_scalar_with_0_min_and_no_max_bounds_gives_bounded_array(self):
80 | tensor_spec = dm_env_rpc_pb2.TensorSpec()
81 | tensor_spec.dtype = dm_env_rpc_pb2.DataType.UINT32
82 | tensor_spec.name = 'foo'
83 | tensor_spec.min.uint32s.array[:] = [0]
84 | actual = dm_env_utils.tensor_spec_to_dm_env_spec(tensor_spec)
85 | expected = specs.BoundedArray(
86 | shape=(), dtype=np.uint32, minimum=0, maximum=2**32 - 1, name='foo')
87 | self.assertEqual(expected, actual)
88 | self.assertEqual('foo', actual.name)
89 |
90 | def test_only_min_bounds(self):
91 | tensor_spec = dm_env_rpc_pb2.TensorSpec()
92 | tensor_spec.dtype = dm_env_rpc_pb2.DataType.UINT32
93 | tensor_spec.shape[:] = [3]
94 | tensor_spec.name = 'foo'
95 | tensor_spec.min.uint32s.array[:] = [1]
96 | actual = dm_env_utils.tensor_spec_to_dm_env_spec(tensor_spec)
97 | expected = specs.BoundedArray(
98 | shape=[3], dtype=np.uint32, minimum=1, maximum=2**32 - 1)
99 | self.assertEqual(expected, actual)
100 | self.assertEqual('foo', actual.name)
101 |
102 | def test_only_max_bounds(self):
103 | tensor_spec = dm_env_rpc_pb2.TensorSpec()
104 | tensor_spec.dtype = dm_env_rpc_pb2.DataType.UINT32
105 | tensor_spec.shape[:] = [3]
106 | tensor_spec.name = 'foo'
107 | tensor_spec.max.uint32s.array[:] = [10]
108 | actual = dm_env_utils.tensor_spec_to_dm_env_spec(tensor_spec)
109 | expected = specs.BoundedArray(
110 | shape=[3], dtype=np.uint32, minimum=0, maximum=10)
111 | self.assertEqual(expected, actual)
112 | self.assertEqual('foo', actual.name)
113 |
114 | def test_both_bounds(self):
115 | tensor_spec = dm_env_rpc_pb2.TensorSpec()
116 | tensor_spec.dtype = dm_env_rpc_pb2.DataType.UINT32
117 | tensor_spec.shape[:] = [3]
118 | tensor_spec.name = 'foo'
119 | tensor_spec.min.uint32s.array[:] = [1]
120 | tensor_spec.max.uint32s.array[:] = [10]
121 | actual = dm_env_utils.tensor_spec_to_dm_env_spec(tensor_spec)
122 | expected = specs.BoundedArray(
123 | shape=[3], dtype=np.uint32, minimum=1, maximum=10)
124 | self.assertEqual(expected, actual)
125 | self.assertEqual('foo', actual.name)
126 |
127 | def test_bounds_oneof_not_set_gives_dtype_bounds(self):
128 | tensor_spec = dm_env_rpc_pb2.TensorSpec()
129 | tensor_spec.dtype = dm_env_rpc_pb2.DataType.UINT32
130 | tensor_spec.shape[:] = [3]
131 | tensor_spec.name = 'foo'
132 |
133 | # Just to force the message to get created.
134 | tensor_spec.min.floats.array[:] = [3.0]
135 | tensor_spec.min.ClearField('floats')
136 |
137 | actual = dm_env_utils.tensor_spec_to_dm_env_spec(tensor_spec)
138 | expected = specs.BoundedArray(
139 | shape=[3], dtype=np.uint32, minimum=0, maximum=2**32 - 1)
140 | self.assertEqual(expected, actual)
141 | self.assertEqual('foo', actual.name)
142 |
143 | def test_bounds_wrong_type_gives_error(self):
144 | tensor_spec = dm_env_rpc_pb2.TensorSpec()
145 | tensor_spec.dtype = dm_env_rpc_pb2.DataType.UINT32
146 | tensor_spec.shape[:] = [3]
147 | tensor_spec.name = 'foo'
148 | tensor_spec.min.floats.array[:] = [1.9]
149 | with self.assertRaisesRegex(ValueError, 'uint32'):
150 | dm_env_utils.tensor_spec_to_dm_env_spec(tensor_spec)
151 |
152 | def test_bounds_on_string_gives_error(self):
153 | tensor_spec = dm_env_rpc_pb2.TensorSpec()
154 | tensor_spec.dtype = dm_env_rpc_pb2.DataType.STRING
155 | tensor_spec.shape[:] = [2]
156 | tensor_spec.name = 'named'
157 | tensor_spec.min.floats.array[:] = [1.9]
158 | tensor_spec.max.floats.array[:] = [10.0]
159 | with self.assertRaisesRegex(ValueError, 'string'):
160 | dm_env_utils.tensor_spec_to_dm_env_spec(tensor_spec)
161 |
162 |
163 | class DmEnvSpecToTensorSpecTests(parameterized.TestCase):
164 |
165 | @parameterized.parameters(
166 | (specs.Array([1, 2], np.float32,
167 | 'foo'), """name: "foo" shape: 1 shape: 2 dtype: FLOAT"""),
168 | (specs.DiscreteArray(5, int, 'bar'), r"""name: "bar" dtype: INT64
169 | min { int64s { array: 0 } } max { int64s { array: 4 } }"""),
170 | (specs.BoundedArray(
171 | (), np.int32, -1, 5, 'baz'), r"""name: "baz" dtype: INT32
172 | min { int32s { array: -1 } } max { int32s { array: 5 } }"""),
173 | (specs.BoundedArray((1, 2), np.uint8, 0, 127,
174 | 'zog'), r"""name: "zog" shape: 1 shape: 2 dtype: UINT8
175 | min { uint8s { array: "\000" } } max { uint8s { array: "\177" } }"""),
176 | (specs.StringArray(shape=(5, 5), name='fux'),
177 | r"""name: "fux" shape: 5 shape: 5 dtype: STRING"""),
178 | )
179 | def test_dm_env_spec(self, value, expected):
180 | tensor_spec = dm_env_utils.dm_env_spec_to_tensor_spec(value)
181 | expected = text_format.Parse(expected, dm_env_rpc_pb2.TensorSpec())
182 | self.assertEqual(expected, tensor_spec)
183 |
184 |
185 | class DmEnvSpecTests(absltest.TestCase):
186 |
187 | def test_spec(self):
188 | dm_env_rpc_specs = {
189 | 54:
190 | dm_env_rpc_pb2.TensorSpec(
191 | name='fuzz', shape=[3], dtype=dm_env_rpc_pb2.DataType.FLOAT),
192 | 55:
193 | dm_env_rpc_pb2.TensorSpec(
194 | name='foo', shape=[2], dtype=dm_env_rpc_pb2.DataType.INT32),
195 | }
196 | manager = spec_manager.SpecManager(dm_env_rpc_specs)
197 |
198 | expected = {
199 | 'foo': specs.Array(shape=[2], dtype=np.int32),
200 | 'fuzz': specs.Array(shape=[3], dtype=np.float32)
201 | }
202 |
203 | self.assertDictEqual(expected, dm_env_utils.dm_env_spec(manager))
204 |
205 | def test_empty_spec(self):
206 | self.assertDictEqual({},
207 | dm_env_utils.dm_env_spec(spec_manager.SpecManager({})))
208 |
209 | def test_spec_generate_and_validate_scalars(self):
210 | dm_env_rpc_specs = []
211 | for name, dtype in dm_env_rpc_pb2.DataType.items():
212 | if dtype != dm_env_rpc_pb2.DataType.INVALID_DATA_TYPE:
213 | dm_env_rpc_specs.append(
214 | dm_env_rpc_pb2.TensorSpec(name=name, shape=(), dtype=dtype))
215 |
216 | for dm_env_rpc_spec in dm_env_rpc_specs:
217 | spec = dm_env_utils.tensor_spec_to_dm_env_spec(dm_env_rpc_spec)
218 | value = spec.generate_value()
219 | spec.validate(value)
220 |
221 | def test_spec_generate_and_validate_tensors(self):
222 | example_shape = (10, 10, 3)
223 |
224 | dm_env_rpc_specs = []
225 | for name, dtype in dm_env_rpc_pb2.DataType.items():
226 | if dtype != dm_env_rpc_pb2.DataType.INVALID_DATA_TYPE:
227 | dm_env_rpc_specs.append(
228 | dm_env_rpc_pb2.TensorSpec(
229 | name=name, shape=example_shape, dtype=dtype))
230 |
231 | for dm_env_rpc_spec in dm_env_rpc_specs:
232 | spec = dm_env_utils.tensor_spec_to_dm_env_spec(dm_env_rpc_spec)
233 | value = spec.generate_value()
234 | spec.validate(value)
235 |
236 | if __name__ == '__main__':
237 | absltest.main()
238 |
--------------------------------------------------------------------------------
/dm_env_rpc/v1/error.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Provides custom Pythonic errors for dm_env_rpc error messages."""
17 | from typing import Iterable
18 | from google.protobuf import any_pb2
19 | from google.rpc import status_pb2
20 |
21 |
22 | class DmEnvRpcError(Exception):
23 | """A dm_env_rpc custom exception.
24 |
25 | Wraps a google.rpc.Status message as a Python Exception class.
26 | """
27 |
28 | def __init__(self, status_proto: status_pb2.Status):
29 | super().__init__()
30 | self._status_proto = status_proto
31 |
32 | @property
33 | def code(self) -> int:
34 | return self._status_proto.code
35 |
36 | @property
37 | def message(self) -> str:
38 | return self._status_proto.message
39 |
40 | @property
41 | def details(self) -> Iterable[any_pb2.Any]:
42 | return self._status_proto.details
43 |
44 | def __str__(self):
45 | return str(self._status_proto)
46 |
47 | def __repr__(self):
48 | return f'DmEnvRpcError(code={self.code}, message={self.message})'
49 |
50 | def __reduce__(self):
51 | return (DmEnvRpcError, (self._status_proto,))
52 |
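Since `Connection.send` raises `DmEnvRpcError` for any error returned by the server, callers can branch on the wrapped status code. A hedged sketch, assuming `connection` is an already-constructed Connection instance:

from google.rpc import code_pb2

from dm_env_rpc.v1 import dm_env_rpc_pb2
from dm_env_rpc.v1 import error

# `connection` is assumed to be an existing connection.Connection instance.
try:
  connection.send(dm_env_rpc_pb2.CreateWorldRequest())
except error.DmEnvRpcError as e:
  if e.code == code_pb2.INVALID_ARGUMENT:
    print(f'Server rejected the request: {e.message}')
  else:
    raise
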
--------------------------------------------------------------------------------
/dm_env_rpc/v1/error_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Tests for dm_env_rpc error module."""
16 |
17 | import pickle
18 |
19 | from absl.testing import absltest
20 | from google.rpc import code_pb2
21 | from google.rpc import status_pb2
22 | from dm_env_rpc.v1 import error
23 |
24 |
25 | class ErrorTest(absltest.TestCase):
26 |
27 | def testSimpleError(self):
28 | message = status_pb2.Status(
29 | code=code_pb2.INVALID_ARGUMENT, message='A test error.')
30 | exception = error.DmEnvRpcError(message)
31 |
32 | self.assertEqual(code_pb2.INVALID_ARGUMENT, exception.code)
33 | self.assertEqual('A test error.', exception.message)
34 | self.assertEqual(str(message), str(exception))
35 |
36 | def testPickleUnpickle(self):
37 | exception = error.DmEnvRpcError(status_pb2.Status(
38 | code=code_pb2.INVALID_ARGUMENT, message='foo.'))
39 | pickled = pickle.dumps(exception)
40 | unpickled = pickle.loads(pickled)
41 |
42 | self.assertEqual(code_pb2.INVALID_ARGUMENT, unpickled.code)
43 | self.assertEqual('foo.', unpickled.message)
44 |
45 | def testRepr(self):
46 | exception = error.DmEnvRpcError(status_pb2.Status(
47 | code=code_pb2.INVALID_ARGUMENT, message='foo.'))
48 | as_string = repr(exception)
49 | self.assertIn(exception.message, as_string)
50 | self.assertIn(str(exception.code), as_string)
51 |
52 |
53 | if __name__ == '__main__':
54 | absltest.main()
55 |
--------------------------------------------------------------------------------
/dm_env_rpc/v1/extensions/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Extensions module for dm_env_rpc."""
16 |
--------------------------------------------------------------------------------
/dm_env_rpc/v1/extensions/properties.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | // ===========================================================================
15 | syntax = "proto3";
16 |
17 | package dm_env_rpc.v1.extensions.properties;
18 |
19 | import "dm_env_rpc/v1/dm_env_rpc.proto";
20 |
21 | message PropertySpec {
22 |   // Required: the TensorSpec's name field holds the property's key.
23 | dm_env_rpc.v1.TensorSpec spec = 1;
24 |
25 | bool is_readable = 2;
26 | bool is_writable = 3;
27 | bool is_listable = 4;
28 |
29 | string description = 5;
30 | }
31 |
32 | message ListPropertyRequest {
33 |   // Key to list properties under. An empty string lists root-level properties.
34 | string key = 1;
35 | }
36 |
37 | message ListPropertyResponse {
38 | repeated PropertySpec values = 1;
39 | }
40 |
41 | message ReadPropertyRequest {
42 | string key = 1;
43 | }
44 |
45 | message ReadPropertyResponse {
46 | dm_env_rpc.v1.Tensor value = 1;
47 | }
48 |
49 | message WritePropertyRequest {
50 | string key = 1;
51 | dm_env_rpc.v1.Tensor value = 2;
52 | }
53 |
54 | message WritePropertyResponse {}
55 |
56 | message PropertyRequest {
57 | oneof payload {
58 | ReadPropertyRequest read_property = 1;
59 | WritePropertyRequest write_property = 2;
60 | ListPropertyRequest list_property = 3;
61 | }
62 | }
63 |
64 | message PropertyResponse {
65 | oneof payload {
66 | ReadPropertyResponse read_property = 1;
67 | WritePropertyResponse write_property = 2;
68 | ListPropertyResponse list_property = 3;
69 | }
70 | }
71 |
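These extension messages are not sent as top-level RPCs: as the accompanying Python helper and tests show, a PropertyRequest is packed into a google.protobuf.Any and carried in the extension field of an EnvironmentRequest (responses come back the same way). A minimal sketch of that wrapping, with a hypothetical property key:

  from google.protobuf import any_pb2
  from dm_env_rpc.v1 import dm_env_rpc_pb2
  from dm_env_rpc.v1.extensions import properties_pb2

  # Wrap a ReadPropertyRequest in the PropertyRequest oneof, pack it into an
  # Any, and place it in the EnvironmentRequest extension field.
  request = properties_pb2.PropertyRequest(
      read_property=properties_pb2.ReadPropertyRequest(key='my_property'))
  packed = any_pb2.Any()
  packed.Pack(request)
  environment_request = dm_env_rpc_pb2.EnvironmentRequest(extension=packed)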
--------------------------------------------------------------------------------
/dm_env_rpc/v1/extensions/properties.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """A helper class for sending and receiving property requests and responses.
16 |
17 | This helper class provides a Pythonic interface for reading, writing and listing
18 | properties. It simplifies the packing and unpacking of property requests and
19 | responses using the provided dm_env_rpc.v1.connection.Connection instance to
20 | send and receive extension messages.
21 |
22 | Example Usage:
23 | property_extension = PropertiesExtension(connection)
24 |
25 | # To read a property:
26 | value = property_extension['my_property']
27 |
28 | # To write a property:
29 | property_extension['my_property'] = new_value
30 |
31 | # To find available properties:
32 | property_specs = property_extension.specs()
33 |
34 | spec = property_specs['my_property']
35 | """
36 | import contextlib
37 | from typing import Mapping, Sequence, Optional
38 |
39 | from dm_env import specs as dm_env_specs
40 | from google.protobuf import any_pb2
41 | from google.rpc import code_pb2
42 | from dm_env_rpc.v1 import connection as dm_env_rpc_connection
43 | from dm_env_rpc.v1 import dm_env_rpc_pb2
44 | from dm_env_rpc.v1 import dm_env_utils
45 | from dm_env_rpc.v1 import error
46 | from dm_env_rpc.v1 import tensor_utils
47 | from dm_env_rpc.v1.extensions import properties_pb2
48 |
49 |
50 | @contextlib.contextmanager
51 | def _convert_dm_env_rpc_error():
52 |   """Helper to convert DmEnvRpcError to a properties-related exception."""
53 | try:
54 | yield
55 | except error.DmEnvRpcError as e:
56 | if e.code == code_pb2.NOT_FOUND:
57 | raise KeyError('Property key not found!') from e
58 | elif e.code == code_pb2.PERMISSION_DENIED:
59 | raise PermissionError('Property permission denied!') from e
60 | elif e.code == code_pb2.INVALID_ARGUMENT:
61 | raise ValueError('Property value error!') from e
62 | raise
63 |
64 |
65 | class PropertySpec(object):
66 | """Class that represents a property's specification."""
67 |
68 | def __init__(self, property_spec_proto: properties_pb2.PropertySpec):
69 | """Constructs a property specification from PropertySpec proto message.
70 |
71 | Args:
72 | property_spec_proto: A properties_pb2.PropertySpec message.
73 | """
74 | self._property_spec_proto = property_spec_proto
75 |
76 | @property
77 | def key(self) -> str:
78 | """Return the property's key."""
79 | return self._property_spec_proto.spec.name
80 |
81 | @property
82 | def readable(self) -> bool:
83 | """Returns True if the property is readable."""
84 | return self._property_spec_proto.is_readable
85 |
86 | @property
87 | def writable(self) -> bool:
88 | """Returns True if the property is writable."""
89 | return self._property_spec_proto.is_writable
90 |
91 | @property
92 | def listable(self) -> bool:
93 | """Returns True if the property is listable."""
94 | return self._property_spec_proto.is_listable
95 |
96 | @property
97 | def spec(self) -> Optional[dm_env_specs.Array]:
98 | """Returns a dm_env spec if the property has a valid dtype.
99 |
100 | Returns:
101 | Either a dm_env spec or, if the dtype is invalid, None.
102 | """
103 | if self._property_spec_proto.spec.dtype != (
104 | dm_env_rpc_pb2.DataType.INVALID_DATA_TYPE):
105 | return dm_env_utils.tensor_spec_to_dm_env_spec(
106 | self._property_spec_proto.spec)
107 | else:
108 | return None
109 |
110 | @property
111 | def description(self) -> str:
112 | """Returns the property's description."""
113 | return self._property_spec_proto.description
114 |
115 | def __repr__(self):
116 | return (f'PropertySpec(key={self.key}, readable={self.readable}, '
117 | f'writable={self.writable}, listable={self.listable}, '
118 | f'spec={self.spec}, description={self.description})')
119 |
120 |
121 | class PropertiesExtension(object):
122 | """Helper class for sending and receiving property requests and responses."""
123 |
124 | def __init__(self, connection: dm_env_rpc_connection.ConnectionType):
125 | """Construct extension with provided dm_env_rpc connection to the env.
126 |
127 | Args:
128 | connection: An instance of Connection already connected to a dm_env_rpc
129 | server.
130 | """
131 | self._connection = connection
132 |
133 | def __getitem__(self, key: str):
134 | """Alias for PropertiesExtension read function."""
135 | return self.read(key)
136 |
137 | def __setitem__(self, key: str, value) -> None:
138 | """Alias for PropertiesExtension write function."""
139 | self.write(key, value)
140 |
141 | def specs(self, key: str = '') -> Mapping[str, PropertySpec]:
142 | """Helper to return sub-properties as a dict."""
143 | return {
144 | sub_property.key: sub_property for sub_property in self.list(key)
145 | }
146 |
147 | def read(self, key: str):
148 | """Reads the value of a property.
149 |
150 | Args:
151 | key: A string key that represents the property to read.
152 |
153 | Returns:
154 | The value of the property, either as a scalar (float, int, string, etc.)
155 | or, if the response tensor has a non-empty `shape` attribute, a NumPy
156 | array of the payload with the correct type and shape. See
157 | tensor_utils.unpack for more details.
158 | """
159 | response = properties_pb2.PropertyResponse()
160 | packed_request = any_pb2.Any()
161 | packed_request.Pack(
162 | properties_pb2.PropertyRequest(
163 | read_property=properties_pb2.ReadPropertyRequest(key=key)))
164 | with _convert_dm_env_rpc_error():
165 | self._connection.send(packed_request).Unpack(response)
166 |
167 | return tensor_utils.unpack_tensor(response.read_property.value)
168 |
169 | def write(self, key: str, value) -> None:
170 | """Writes the provided value to a property.
171 |
172 | Args:
173 | key: A string key that represents the property to write.
174 | value: A scalar (float, int, string, etc.), NumPy array, or nested lists.
175 | See tensor_utils.pack for more details.
176 | """
177 | packed_request = any_pb2.Any()
178 | packed_request.Pack(
179 | properties_pb2.PropertyRequest(
180 | write_property=properties_pb2.WritePropertyRequest(
181 | key=key, value=tensor_utils.pack_tensor(value))))
182 | with _convert_dm_env_rpc_error():
183 | self._connection.send(packed_request)
184 |
185 | def list(self, key: str = '') -> Sequence[PropertySpec]:
186 | """Lists properties residing under the provided key.
187 |
188 | Args:
189 | key: A string key to list properties at this location. If empty, returns
190 | properties registered at the root level.
191 |
192 | Returns:
193 | A sequence of PropertySpecs.
194 | """
195 | response = properties_pb2.PropertyResponse()
196 | packed_request = any_pb2.Any()
197 | packed_request.Pack(
198 | properties_pb2.PropertyRequest(
199 | list_property=properties_pb2.ListPropertyRequest(key=key)))
200 | with _convert_dm_env_rpc_error():
201 | self._connection.send(packed_request).Unpack(response)
202 |
203 | return tuple(
204 | PropertySpec(sub_property)
205 | for sub_property in response.list_property.values)
206 |
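The _convert_dm_env_rpc_error helper above maps server status codes onto built-in Python exceptions, so property access reads like ordinary dictionary use. A short sketch, assuming `connection` is an already-connected dm_env_rpc Connection and the property names are hypothetical:

  from dm_env_rpc.v1.extensions import properties

  extension = properties.PropertiesExtension(connection)
  try:
    _ = extension['missing_property']     # NOT_FOUND surfaces as KeyError.
  except KeyError:
    pass
  try:
    extension['read_only_property'] = 1   # PERMISSION_DENIED -> PermissionError.
  except PermissionError:
    pass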
--------------------------------------------------------------------------------
/dm_env_rpc/v1/extensions/properties_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Tests Properties extension."""
16 |
17 | import contextlib
18 | from unittest import mock
19 |
20 | from absl.testing import absltest
21 | from dm_env import specs
22 | import numpy as np
23 |
24 | from google.protobuf import any_pb2
25 | from google.rpc import code_pb2
26 | from google.rpc import status_pb2
27 | from google.protobuf import text_format
28 | from dm_env_rpc.v1 import connection as dm_env_rpc_connection
29 | from dm_env_rpc.v1 import dm_env_rpc_pb2
30 | from dm_env_rpc.v1 import error
31 | from dm_env_rpc.v1.extensions import properties
32 | from dm_env_rpc.v1.extensions import properties_pb2
33 |
34 |
35 | def _create_property_request_key(text_proto):
36 | extension_message = any_pb2.Any()
37 | extension_message.Pack(
38 | text_format.Parse(text_proto, properties_pb2.PropertyRequest()))
39 | return dm_env_rpc_pb2.EnvironmentRequest(
40 | extension=extension_message).SerializeToString()
41 |
42 |
43 | def _pack_property_response(text_proto):
44 | extension_message = any_pb2.Any()
45 | extension_message.Pack(
46 | text_format.Parse(text_proto, properties_pb2.PropertyResponse()))
47 | return dm_env_rpc_pb2.EnvironmentResponse(extension=extension_message)
48 |
49 | # Set of expected requests and associated responses for mock connection.
50 | _EXPECTED_REQUEST_RESPONSE_PAIRS = {
51 | _create_property_request_key('read_property { key: "foo" }'):
52 | _pack_property_response(
53 | 'read_property { value: { int32s: { array: 1 } } }'),
54 | _create_property_request_key("""write_property {
55 | key: "bar"
56 | value: { strings { array: "some_value" } }
57 | }"""):
58 | _pack_property_response('write_property {}'),
59 | _create_property_request_key('read_property { key: "bar" }'):
60 | _pack_property_response(
61 | 'read_property { value: { strings: { array: "some_value" } } }'),
62 | _create_property_request_key('list_property { key: "baz" }'):
63 | _pack_property_response("""list_property {
64 | values: {
65 | is_readable:true
66 | spec { name: "baz.fiz" dtype:UINT32 shape: 2 shape: 2 }
67 | }}"""),
68 | _create_property_request_key('list_property {}'):
69 | _pack_property_response("""list_property {
70 | values: { is_readable:true spec { name: "foo" dtype:INT32 }
71 | description: "This is a documented integer" }
72 | values: { is_readable:true
73 | is_writable:true
74 | spec { name: "bar" dtype:STRING } }
75 | values: { is_listable:true spec { name: "baz" } }
76 | }"""),
77 | _create_property_request_key('read_property { key: "bad_property" }'):
78 | dm_env_rpc_pb2.EnvironmentResponse(
79 | error=status_pb2.Status(message='invalid property request.')),
80 | _create_property_request_key('read_property { key: "invalid_key" }'):
81 | dm_env_rpc_pb2.EnvironmentResponse(
82 | error=status_pb2.Status(
83 | code=code_pb2.NOT_FOUND, message='Invalid key.')),
84 | _create_property_request_key("""write_property {
85 | key: "argument_test"
86 | value: { strings: { array: "invalid" } } }"""):
87 | dm_env_rpc_pb2.EnvironmentResponse(
88 | error=status_pb2.Status(
89 | code=code_pb2.INVALID_ARGUMENT, message='Invalid argument.')),
90 | _create_property_request_key('read_property { key: "permission_key" }'):
91 | dm_env_rpc_pb2.EnvironmentResponse(
92 | error=status_pb2.Status(
93 | code=code_pb2.PERMISSION_DENIED, message='No permission.'))
94 | }
95 |
96 |
97 | @contextlib.contextmanager
98 | def _create_mock_connection():
99 | """Helper to create mock dm_env_rpc connection."""
100 | with mock.patch.object(dm_env_rpc_connection,
101 | 'dm_env_rpc_pb2_grpc') as mock_grpc:
102 |
103 | def _process(request_iterator, **kwargs):
104 | del kwargs
105 | for request in request_iterator:
106 | yield _EXPECTED_REQUEST_RESPONSE_PAIRS[request.SerializeToString()]
107 |
108 | mock_stub_class = mock.MagicMock()
109 | mock_stub_class.Process = _process
110 | mock_grpc.EnvironmentStub.return_value = mock_stub_class
111 | yield dm_env_rpc_connection.Connection(mock.MagicMock())
112 |
113 |
114 | class PropertiesTest(absltest.TestCase):
115 |
116 | def test_read_property(self):
117 | with _create_mock_connection() as connection:
118 | extension = properties.PropertiesExtension(connection)
119 | self.assertEqual(1, extension['foo'])
120 |
121 | def test_write_property(self):
122 | with _create_mock_connection() as connection:
123 | extension = properties.PropertiesExtension(connection)
124 | extension['bar'] = 'some_value'
125 | self.assertEqual('some_value', extension['bar'])
126 |
127 | def test_list_property(self):
128 | with _create_mock_connection() as connection:
129 | extension = properties.PropertiesExtension(connection)
130 | property_specs = extension.specs('baz')
131 | self.assertLen(property_specs, 1)
132 |
133 | property_spec = property_specs['baz.fiz']
134 | self.assertTrue(property_spec.readable)
135 | self.assertFalse(property_spec.writable)
136 | self.assertFalse(property_spec.listable)
137 | self.assertEqual(
138 | specs.Array(shape=(2, 2), dtype=np.uint32), property_spec.spec)
139 |
140 | def test_root_list_property(self):
141 | with _create_mock_connection() as connection:
142 | extension = properties.PropertiesExtension(connection)
143 | property_specs = extension.specs()
144 | self.assertLen(property_specs, 3)
145 | self.assertTrue(property_specs['foo'].readable)
146 | self.assertTrue(property_specs['bar'].readable)
147 | self.assertTrue(property_specs['bar'].writable)
148 | self.assertTrue(property_specs['baz'].listable)
149 |
150 | def test_invalid_spec_request_on_listable_property(self):
151 | with _create_mock_connection() as connection:
152 | extension = properties.PropertiesExtension(connection)
153 | property_specs = extension.specs()
154 | self.assertTrue(property_specs['baz'].listable)
155 | self.assertIsNone(property_specs['baz'].spec)
156 |
157 | def test_invalid_request(self):
158 | with _create_mock_connection() as connection:
159 | extension = properties.PropertiesExtension(connection)
160 | with self.assertRaisesRegex(error.DmEnvRpcError,
161 | 'invalid property request.'):
162 | _ = extension['bad_property']
163 |
164 | def test_invalid_key_raises_key_error(self):
165 | with _create_mock_connection() as connection:
166 | extension = properties.PropertiesExtension(connection)
167 | with self.assertRaises(KeyError):
168 | _ = extension['invalid_key']
169 |
170 | def test_invalid_argument_raises_value_error(self):
171 | with _create_mock_connection() as connection:
172 | extension = properties.PropertiesExtension(connection)
173 | with self.assertRaises(ValueError):
174 | extension['argument_test'] = 'invalid'
175 |
176 | def test_permission_denied_raises_permission_error(self):
177 | with _create_mock_connection() as connection:
178 | extension = properties.PropertiesExtension(connection)
179 | with self.assertRaises(PermissionError):
180 | _ = extension['permission_key']
181 |
182 | def test_property_description(self):
183 | with _create_mock_connection() as connection:
184 | extension = properties.PropertiesExtension(connection)
185 | property_specs = extension.specs()
186 | self.assertEqual('This is a documented integer',
187 | property_specs['foo'].description)
188 |
189 | def test_property_print(self):
190 | with _create_mock_connection() as connection:
191 | extension = properties.PropertiesExtension(connection)
192 | property_specs = extension.specs()
193 | self.assertRegex(
194 | str(property_specs['foo']),
195 | (r'PropertySpec\(key=foo, readable=True, writable=False, '
196 | r'listable=False, spec=.*, '
197 | r'description=This is a documented integer\)'))
198 |
199 |
200 | if __name__ == '__main__':
201 | absltest.main()
202 |
--------------------------------------------------------------------------------
/dm_env_rpc/v1/message_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Helper functions used to process dm_env_rpc request / response messages.
16 | """
17 |
18 | import typing
19 | from typing import Iterable, NamedTuple, Type, Union
20 |
21 | import immutabledict
22 |
23 | from google.protobuf import any_pb2
24 | from google.protobuf import message
25 | from dm_env_rpc.v1 import dm_env_rpc_pb2
26 | from dm_env_rpc.v1 import error
27 |
28 |
29 | _MESSAGE_TYPE_TO_FIELD = immutabledict.immutabledict({
30 | field.message_type.name: field.name
31 | for field in dm_env_rpc_pb2.EnvironmentRequest.DESCRIPTOR.fields
32 | })
33 |
34 | # An unpacked extension request (anything).
35 | # As any proto message that is not a native request is accepted, this definition
36 | # is overly broad - use with care.
37 | DmEnvRpcExtensionMessage = message.Message
38 |
39 | # A packed extension request.
40 | # Wraps a DmEnvRpcExtensionMessage.
41 | DmEnvRpcPackedExtensionMessage = any_pb2.Any
42 |
43 | # A native request RPC.
44 | DmEnvRpcNativeRequest = Union[
45 | dm_env_rpc_pb2.CreateWorldRequest,
46 | dm_env_rpc_pb2.JoinWorldRequest,
47 | dm_env_rpc_pb2.StepRequest,
48 | dm_env_rpc_pb2.ResetRequest,
49 | dm_env_rpc_pb2.ResetWorldRequest,
50 | dm_env_rpc_pb2.LeaveWorldRequest,
51 | dm_env_rpc_pb2.DestroyWorldRequest,
52 | ]
53 |
54 | # A native response RPC.
55 | DmEnvRpcNativeResponse = Union[
56 | dm_env_rpc_pb2.CreateWorldResponse,
57 | dm_env_rpc_pb2.JoinWorldResponse,
58 | dm_env_rpc_pb2.StepResponse,
59 | dm_env_rpc_pb2.ResetResponse,
60 | dm_env_rpc_pb2.ResetWorldResponse,
61 | dm_env_rpc_pb2.LeaveWorldResponse,
62 | dm_env_rpc_pb2.DestroyWorldResponse,
63 | ]
64 |
65 | # A native request RPC, or an extension message, wrapped in Any.
66 | DmEnvRpcRequest = Union[DmEnvRpcNativeRequest, DmEnvRpcPackedExtensionMessage]
67 | # A native response RPC, or an extension message, wrapped in Any.
68 | DmEnvRpcResponse = Union[DmEnvRpcNativeResponse, DmEnvRpcPackedExtensionMessage]
69 |
70 |
71 | def pack_rpc_request(
72 | request: Union[DmEnvRpcRequest, DmEnvRpcExtensionMessage],
73 | ) -> DmEnvRpcRequest:
74 | """Returns a DmEnvRpcRequest that is suitable to send over the wire.
75 |
76 | Arguments:
77 | request: The request to pack.
78 |
79 | Returns:
80 | Native request - returned as-is.
81 | Packed extension (any_pb2.Any) - returned as-is.
82 | Everything else - assumed to be an extension; wrapped in any_pb2.Any.
83 | """
84 | if isinstance(request, typing.get_args(DmEnvRpcNativeRequest)):
85 | return typing.cast(DmEnvRpcNativeRequest, request)
86 | else:
87 | return _pack_message(request)
88 |
89 |
90 | def unpack_rpc_request(
91 | request: Union[DmEnvRpcRequest, DmEnvRpcExtensionMessage],
92 | *,
93 | extension_type: Union[
94 | Type[message.Message], Iterable[Type[message.Message]]
95 | ],
96 | ) -> Union[DmEnvRpcRequest, DmEnvRpcExtensionMessage]:
97 | """Returns a DmEnvRpcRequest without a wrapper around extension messages.
98 |
99 | Arguments:
100 | request: The request to unpack.
101 | extension_type: One or more extension protobuf classes of known extension
102 | types.
103 |
104 | Returns:
105 | Native request - returned as-is.
106 | Unpacked extension in |extension_type| - returned as-is.
107 | Packed extension (any_pb2.Any) in |extension_type| - returned unpacked
108 |
109 | Raises:
110 | ValueError: The message is packed (any_pb2.Any), but not in |extension_type|
111 | or the message type is not a native request or known extension.
112 | """
113 | if isinstance(request, typing.get_args(DmEnvRpcNativeRequest)):
114 | return request
115 | else:
116 | return _unpack_message(
117 | request,
118 | extension_type=extension_type)
119 |
120 |
121 | def pack_rpc_response(
122 | response: Union[DmEnvRpcResponse, DmEnvRpcExtensionMessage],
123 | ) -> DmEnvRpcResponse:
124 | """Returns a DmEnvRpcResponse that is suitable to send over the wire.
125 |
126 | Arguments:
127 | response: The response to pack.
128 |
129 | Returns:
130 | Native response - returned as-is.
131 | Packed extension (any_pb2.Any) - returned as-is.
132 | Everything else - assumed to be an extension; wrapped in any_pb2.Any.
133 | """
134 | if isinstance(response, typing.get_args(DmEnvRpcNativeResponse)):
135 | return typing.cast(DmEnvRpcNativeResponse, response)
136 | else:
137 | return _pack_message(response)
138 |
139 |
140 | def unpack_rpc_response(
141 | response: Union[DmEnvRpcResponse, DmEnvRpcExtensionMessage],
142 | *,
143 | extension_type: Union[
144 | Type[message.Message], Iterable[Type[message.Message]]
145 | ],
146 | ) -> Union[DmEnvRpcResponse, DmEnvRpcExtensionMessage]:
147 | """Returns a DmEnvRpcResponse without a wrapper around extension messages.
148 |
149 | Arguments:
150 | response: The response to unpack.
151 | extension_type: One or more extension protobuf classes of known extension
152 | types.
153 |
154 | Returns:
155 | Native response - returned as-is.
156 | Unpacked extension in |extension_type| - returned as-is.
157 | Packed extension (any_pb2.Any) in |extension_type| - returned unpacked
158 |
159 | Raises:
160 | ValueError: The message is packed (any_pb2.Any), but not in |extension_type|
161 |     or the message type is not a native response or known extension.
162 | """
163 | if isinstance(response, typing.get_args(DmEnvRpcNativeResponse)):
164 | return response
165 | else:
166 | return _unpack_message(response, extension_type=extension_type)
167 |
168 |
169 | class EnvironmentRequestAndFieldName(NamedTuple):
170 | """EnvironmentRequest and field name used when packing."""
171 | environment_request: dm_env_rpc_pb2.EnvironmentRequest
172 | field_name: str
173 |
174 |
175 | def pack_environment_request(
176 | request: DmEnvRpcRequest) -> EnvironmentRequestAndFieldName:
177 | """Constructs an EnvironmentRequest containing a request message.
178 |
179 | Args:
180 | request: An instance of a dm_env_rpc Request type, such as
181 | CreateWorldRequest.
182 |
183 | Returns:
184 | A tuple of (environment_request, field_name) where:
185 | environment_request: dm_env_rpc.v1.EnvironmentRequest containing the input
186 | request message.
187 | field_name: Name of the environment request field holding the input
188 | request message.
189 | """
190 | field_name = _MESSAGE_TYPE_TO_FIELD[type(request).__name__]
191 | environment_request = dm_env_rpc_pb2.EnvironmentRequest()
192 | getattr(environment_request, field_name).CopyFrom(request)
193 | return EnvironmentRequestAndFieldName(environment_request, field_name)
194 |
195 |
196 | def unpack_environment_response(
197 | environment_response: dm_env_rpc_pb2.EnvironmentResponse,
198 | expected_field_name: str) -> DmEnvRpcResponse:
199 | """Extracts the response message contained within an EnvironmentResponse.
200 |
201 | Args:
202 | environment_response: An instance of dm_env_rpc.v1.EnvironmentResponse.
203 | expected_field_name: Name of the environment response field expected to be
204 | holding the dm_env_rpc response message.
205 |
206 | Returns:
207 | dm_env_rpc response message wrapped in the input environment response.
208 |
209 | Raises:
210 | DmEnvRpcError: The dm_env_rpc EnvironmentResponse contains an error.
211 | ValueError: The dm_env_rpc response message contained in the
212 | EnvironmentResponse is held in a different field from the one expected.
213 | """
214 | response_field_name = environment_response.WhichOneof('payload')
215 | if response_field_name == 'error':
216 | raise error.DmEnvRpcError(environment_response.error)
217 | elif response_field_name == expected_field_name:
218 | return getattr(environment_response, expected_field_name)
219 | else:
220 | raise ValueError('Unexpected response message! expected: '
221 | f'{expected_field_name}, actual: {response_field_name}')
222 |
223 |
224 | def _pack_message(msg) -> any_pb2.Any:
225 | """Helper to pack message into an Any proto."""
226 | if isinstance(msg, any_pb2.Any):
227 | return msg
228 |
229 | # Assume the message is an extension.
230 | packed = any_pb2.Any()
231 | packed.Pack(msg)
232 | return packed
233 |
234 |
235 | def _unpack_message(
236 | msg: message.Message,
237 | *,
238 | extension_type: Union[
239 | Type[message.Message], Iterable[Type[message.Message]]
240 | ],
241 | ):
242 | """Helper to unpack a message from set of possible extensions.
243 |
244 | Args:
245 | msg: The message to process.
246 | extension_type: Type or type(s) used to match extension messages. The first
247 | matching type is used.
248 |
249 | Returns:
250 |     An unpacked extension message with type within |extension_type|.
251 |
252 | Raises:
253 |     ValueError: Raised if a return type could not be determined.
254 | """
255 | if isinstance(extension_type, type):
256 | extension_type = (extension_type,)
257 | else:
258 | extension_type = tuple(extension_type)
259 |
260 | if isinstance(msg, extension_type):
261 | return msg
262 |
263 | if isinstance(msg, any_pb2.Any):
264 | matching_type = next(
265 | (e for e in extension_type if msg.Is(e.DESCRIPTOR)), None)
266 |
267 | if not matching_type:
268 | raise ValueError(
269 | 'Extension type could not be found to unpack message: '
270 | f'{type(msg).__name__}.\n'
271 | f'Known Types:\n' + '\n'.join(f'- {e}' for e in extension_type))
272 |
273 | unpacked = matching_type()
274 | msg.Unpack(unpacked)
275 | return unpacked
276 |
277 | raise ValueError(
278 | f'Cannot unpack extension message with type: {type(msg).__name__}.'
279 | )
280 |
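A short sketch of the round trip these helpers perform for a native request; the action and observation UIDs are hypothetical, and the commented lines show where the matching EnvironmentResponse would be unwrapped (an error payload is raised as error.DmEnvRpcError instead):

  from dm_env_rpc.v1 import dm_env_rpc_pb2
  from dm_env_rpc.v1 import message_utils
  from dm_env_rpc.v1 import tensor_utils

  step = dm_env_rpc_pb2.StepRequest(
      requested_observations=[1],
      actions={2: tensor_utils.pack_tensor(0.5)})
  environment_request, field_name = message_utils.pack_environment_request(step)
  # field_name is 'step', the EnvironmentRequest oneof field holding the request.
  # response = message_utils.unpack_environment_response(
  #     environment_response, field_name)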
--------------------------------------------------------------------------------
/dm_env_rpc/v1/message_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Tests for dm_env_rpc/message_utils."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 |
20 | from google.rpc import status_pb2
21 | from dm_env_rpc.v1 import dm_env_rpc_pb2
22 | from dm_env_rpc.v1 import error
23 | from dm_env_rpc.v1 import message_utils
24 | from dm_env_rpc.v1 import tensor_utils
25 | from google.protobuf import any_pb2
26 |
27 |
28 | _CREATE_WORLD_REQUEST = dm_env_rpc_pb2.CreateWorldRequest(
29 | settings={'foo': tensor_utils.pack_tensor('bar')})
30 | _CREATE_WORLD_RESPONSE = dm_env_rpc_pb2.CreateWorldResponse(world_name='qux')
31 | _CREATE_WORLD_ENVIRONMENT_RESPONSE = dm_env_rpc_pb2.EnvironmentResponse(
32 | create_world=_CREATE_WORLD_RESPONSE)
33 |
34 | # Anything that's not a "native" rpc message is an extension.
35 | _EXTENSION_TYPE = error.status_pb2.Status
36 | _EXTENSION_MULTI_TYPE = (dm_env_rpc_pb2.EnvironmentResponse, _EXTENSION_TYPE)
37 | _EXTENSION_MESSAGE = _EXTENSION_TYPE(code=42)
38 | _PACKED_EXTENSION_MESSAGE = any_pb2.Any()
39 | _PACKED_EXTENSION_MESSAGE.Pack(_EXTENSION_MESSAGE)
40 |
41 |
42 | class MessageUtilsTests(parameterized.TestCase):
43 |
44 | def test_pack_create_world_request(self):
45 | environment_request, field_name = message_utils.pack_environment_request(
46 | _CREATE_WORLD_REQUEST)
47 | self.assertEqual(field_name, 'create_world')
48 | self.assertEqual(environment_request.WhichOneof('payload'),
49 | 'create_world')
50 | self.assertEqual(environment_request.create_world,
51 | _CREATE_WORLD_REQUEST)
52 |
53 | def test_unpack_create_world_response(self):
54 | response = message_utils.unpack_environment_response(
55 | _CREATE_WORLD_ENVIRONMENT_RESPONSE, 'create_world')
56 | self.assertEqual(response, _CREATE_WORLD_RESPONSE)
57 |
58 | def test_unpack_error_response(self):
59 | with self.assertRaisesRegex(error.DmEnvRpcError, 'A test error.'):
60 | message_utils.unpack_environment_response(
61 | dm_env_rpc_pb2.EnvironmentResponse(
62 | error=status_pb2.Status(message='A test error.')),
63 | 'create_world')
64 |
65 | def test_unpack_incorrect_response(self):
66 | with self.assertRaisesWithLiteralMatch(
67 | ValueError,
68 | 'Unexpected response message! expected: create_world, actual: '
69 | 'leave_world'):
70 | message_utils.unpack_environment_response(
71 | dm_env_rpc_pb2.EnvironmentResponse(
72 | leave_world=dm_env_rpc_pb2.LeaveWorldResponse()),
73 | 'create_world')
74 |
75 | @parameterized.named_parameters(
76 | dict(
77 | testcase_name='create_world_request_passes_through',
78 | message=_CREATE_WORLD_REQUEST,
79 | expected=_CREATE_WORLD_REQUEST,
80 | ),
81 | dict(
82 | testcase_name='packed_extension_message_passes_through',
83 | message=_PACKED_EXTENSION_MESSAGE,
84 | expected=_PACKED_EXTENSION_MESSAGE,
85 | ),
86 | dict(
87 | testcase_name='extension_message_is_packed',
88 | message=_EXTENSION_MESSAGE,
89 | expected=_PACKED_EXTENSION_MESSAGE,
90 | ),
91 | )
92 | def test_pack_request(self, message, expected):
93 | self.assertEqual(message_utils.pack_rpc_request(message), expected)
94 |
95 | @parameterized.named_parameters(
96 | dict(
97 | testcase_name='create_world_response',
98 | message=_CREATE_WORLD_RESPONSE,
99 | expected=_CREATE_WORLD_RESPONSE,
100 | ),
101 | dict(
102 | testcase_name='extension_message',
103 | message=_EXTENSION_MESSAGE,
104 | expected=_PACKED_EXTENSION_MESSAGE,
105 | ),
106 | dict(
107 | testcase_name='packed_extension_message',
108 | message=_PACKED_EXTENSION_MESSAGE,
109 | expected=_PACKED_EXTENSION_MESSAGE,
110 | ),
111 | )
112 | def test_pack_response(self, message, expected):
113 | self.assertEqual(message_utils.pack_rpc_response(message), expected)
114 |
115 | @parameterized.named_parameters(
116 | dict(
117 | testcase_name='create_world_request_passes_through',
118 | message=_CREATE_WORLD_REQUEST,
119 | extensions=_EXTENSION_TYPE,
120 | expected=_CREATE_WORLD_REQUEST,
121 | ),
122 | dict(
123 | testcase_name='extension_message_passes_through',
124 | message=_EXTENSION_MESSAGE,
125 | extensions=_EXTENSION_TYPE,
126 | expected=_EXTENSION_MESSAGE,
127 | ),
128 | dict(
129 | testcase_name='extension_message_passes_through_with_multi',
130 | message=_EXTENSION_MESSAGE,
131 | extensions=_EXTENSION_MULTI_TYPE,
132 | expected=_EXTENSION_MESSAGE,
133 | ),
134 | dict(
135 | testcase_name='packed_extension_message_is_unpacked_with_multi',
136 | message=_PACKED_EXTENSION_MESSAGE,
137 | extensions=_EXTENSION_MULTI_TYPE,
138 | expected=_EXTENSION_MESSAGE,
139 | ),
140 | )
141 | def test_unpack_request(self, message, extensions, expected):
142 | self.assertEqual(
143 | message_utils.unpack_rpc_request(message, extension_type=extensions),
144 | expected)
145 |
146 | @parameterized.named_parameters(
147 | dict(
148 | testcase_name='create_world_response_passes_through',
149 | message=_CREATE_WORLD_RESPONSE,
150 | extensions=_EXTENSION_TYPE,
151 | expected=_CREATE_WORLD_RESPONSE,
152 | ),
153 | dict(
154 | testcase_name='extension_message_passes_through',
155 | message=_EXTENSION_MESSAGE,
156 | extensions=_EXTENSION_TYPE,
157 | expected=_EXTENSION_MESSAGE,
158 | ),
159 | dict(
160 | testcase_name='packed_extension_message_is_unpacked',
161 | message=_PACKED_EXTENSION_MESSAGE,
162 | extensions=_EXTENSION_TYPE,
163 | expected=_EXTENSION_MESSAGE,
164 | ),
165 | dict(
166 | testcase_name='extension_message_passes_through_with_multi',
167 | message=_EXTENSION_MESSAGE,
168 | extensions=_EXTENSION_MULTI_TYPE,
169 | expected=_EXTENSION_MESSAGE,
170 | ),
171 | dict(
172 | testcase_name='packed_extension_message_is_unpacked_with_multi',
173 | message=_PACKED_EXTENSION_MESSAGE,
174 | extensions=_EXTENSION_MULTI_TYPE,
175 | expected=_EXTENSION_MESSAGE,
176 | ),
177 | )
178 | def test_unpack_response(self, message, extensions, expected):
179 | self.assertEqual(
180 | message_utils.unpack_rpc_response(message, extension_type=extensions),
181 | expected)
182 |
183 | if __name__ == '__main__':
184 | absltest.main()
185 |
--------------------------------------------------------------------------------
/dm_env_rpc/v1/spec_manager.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Manager class to manage the dm_env_rpc UID system."""
16 |
17 | from typing import Any, Collection, Mapping, MutableMapping
18 | import numpy as np
19 |
20 | from dm_env_rpc.v1 import dm_env_rpc_pb2
21 | from dm_env_rpc.v1 import tensor_utils
22 |
23 |
24 | def _assert_shapes_match(tensor: dm_env_rpc_pb2.Tensor,
25 | dm_env_rpc_spec: dm_env_rpc_pb2.TensorSpec):
26 | """Raises ValueError if shape of tensor and spec don't match."""
27 | tensor_shape = np.asarray(tensor.shape)
28 | spec_shape = np.asarray(dm_env_rpc_spec.shape)
29 |
30 | # Check all elements are equal, or the spec element is -1 (variable length).
31 | if tensor_shape.size != spec_shape.size or not np.all(
32 | (tensor_shape == spec_shape) | (spec_shape < 0)):
33 | raise ValueError(
34 | 'Received dm_env_rpc tensor {} with shape {} but spec has shape {}.'
35 | .format(dm_env_rpc_spec.name, tensor_shape, spec_shape))
36 |
37 |
38 | class SpecManager(object):
39 | """Manages transitions between Python dicts and dm_env_rpc UIDs.
40 |
41 | To make sending and receiving actions and observations easier for dm_env_rpc,
42 | this helps manage the transition between UID-keyed dicts mapping to dm_env_rpc
43 | tensors and string-keyed dicts mapping to scalars, lists, or NumPy arrays.
44 | """
45 |
46 | def __init__(self, specs: Mapping[int, dm_env_rpc_pb2.TensorSpec]):
47 | """Builds the SpecManager from the given dm_env_rpc specs.
48 |
49 | Args:
50 | specs: A dict mapping UIDs to dm_env_rpc TensorSpecs, similar to what is
51 | stored in `actions` and `observations` in ActionObservationSpecs.
52 | """
53 | for spec in specs.values():
54 | if np.count_nonzero(np.asarray(spec.shape) < 0) > 1:
55 | raise ValueError(
56 | f'"{spec.name}" shape has > 1 variable length dimension. '
57 | f'Spec:\n{spec}')
58 |
59 | self._name_to_uid = {spec.name: uid for uid, spec in specs.items()}
60 | self._uid_to_name = {uid: spec.name for uid, spec in specs.items()}
61 | if len(self._name_to_uid) != len(self._uid_to_name):
62 | raise ValueError('There are duplicate names in the tensor specs.')
63 |
64 | self._specs_by_uid = specs
65 | self._specs_by_name = {spec.name: spec for spec in specs.values()}
66 |
67 | @property
68 | def specs_by_uid(self) -> Mapping[int, dm_env_rpc_pb2.TensorSpec]:
69 | return self._specs_by_uid
70 |
71 | @property
72 | def specs_by_name(self) -> Mapping[str, dm_env_rpc_pb2.TensorSpec]:
73 | return self._specs_by_name
74 |
75 | def name_to_uid(self, name: str) -> int:
76 | """Returns the UID for the given name."""
77 | return self._name_to_uid[name]
78 |
79 | def uid_to_name(self, uid: int) -> str:
80 | """Returns the name for the given UID."""
81 | return self._uid_to_name[uid]
82 |
83 | def name_to_spec(self, name: str) -> dm_env_rpc_pb2.TensorSpec:
84 | """Returns the dm_env_rpc TensorSpec named `name`."""
85 | return self._specs_by_name[name]
86 |
87 | def uid_to_spec(self, uid: int) -> dm_env_rpc_pb2.TensorSpec:
88 | """Returns the dm_env_rpc TensorSpec for the given UID."""
89 | return self._specs_by_uid[uid]
90 |
91 | def names(self) -> Collection[str]:
92 | """Returns the spec names in no particular order."""
93 | return self._name_to_uid.keys()
94 |
95 | def uids(self) -> Collection[int]:
96 | """Returns the spec UIDs in no particular order."""
97 | return self._uid_to_name.keys()
98 |
99 | def unpack(
100 | self, dm_env_rpc_tensors: Mapping[int, dm_env_rpc_pb2.Tensor]
101 | ) -> MutableMapping[str, Any]:
102 | """Unpacks a dm_env_rpc uid-to-tensor map to a name-keyed Python dict.
103 |
104 | Args:
105 | dm_env_rpc_tensors: A dict mapping UIDs to dm_env_rpc tensor protos.
106 |
107 | Returns:
108 | A dict mapping names to scalars and arrays.
109 | """
110 | unpacked = {}
111 | for uid, tensor in dm_env_rpc_tensors.items():
112 | name = self._uid_to_name[uid]
113 | dm_env_rpc_spec = self.name_to_spec(name)
114 | _assert_shapes_match(tensor, dm_env_rpc_spec)
115 | tensor_dtype = tensor_utils.get_tensor_type(tensor)
116 | spec_dtype = tensor_utils.data_type_to_np_type(dm_env_rpc_spec.dtype)
117 | if tensor_dtype != spec_dtype:
118 | raise ValueError(
119 | 'Received dm_env_rpc tensor {} with dtype {} but spec has dtype {}.'
120 | .format(name, tensor_dtype, spec_dtype))
121 | tensor_unpacked = tensor_utils.unpack_tensor(tensor)
122 | unpacked[name] = tensor_unpacked
123 | return unpacked
124 |
125 | def pack(
126 | self,
127 | tensors: Mapping[str, Any]) -> MutableMapping[int, dm_env_rpc_pb2.Tensor]:
128 | """Packs a name-keyed Python dict to a dm_env_rpc uid-to-tensor map.
129 |
130 | Args:
131 | tensors: A dict mapping string names to scalars and arrays.
132 |
133 | Returns:
134 | A dict mapping UIDs to dm_env_rpc tensor protos.
135 | """
136 | packed = {}
137 | for name, value in tensors.items():
138 | dm_env_rpc_spec = self.name_to_spec(name)
139 | tensor = tensor_utils.pack_tensor(value, dtype=dm_env_rpc_spec.dtype)
140 | _assert_shapes_match(tensor, dm_env_rpc_spec)
141 | packed[self.name_to_uid(name)] = tensor
142 | return packed
143 |
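A minimal sketch of the name/UID round trip SpecManager provides; the UID and TensorSpec below are made up:

  import numpy as np

  from dm_env_rpc.v1 import dm_env_rpc_pb2
  from dm_env_rpc.v1 import spec_manager

  action_specs = {
      1: dm_env_rpc_pb2.TensorSpec(
          name='paddle', shape=[1], dtype=dm_env_rpc_pb2.DataType.INT8),
  }
  manager = spec_manager.SpecManager(action_specs)

  packed = manager.pack({'paddle': [1]})   # name-keyed values -> UID-keyed tensors
  unpacked = manager.unpack(packed)        # UID-keyed tensors -> name-keyed arrays
  assert np.array_equal(unpacked['paddle'], [1])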
--------------------------------------------------------------------------------
/dm_env_rpc/v1/spec_manager_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Tests for SpecManager class."""
16 |
17 | from absl.testing import absltest
18 | import numpy as np
19 |
20 | from dm_env_rpc.v1 import dm_env_rpc_pb2
21 | from dm_env_rpc.v1 import spec_manager
22 | from dm_env_rpc.v1 import tensor_utils
23 |
24 | _EXAMPLE_SPECS = {
25 | 54:
26 | dm_env_rpc_pb2.TensorSpec(
27 | name='fuzz', shape=[2], dtype=dm_env_rpc_pb2.DataType.FLOAT),
28 | 55:
29 | dm_env_rpc_pb2.TensorSpec(
30 | name='foo', shape=[3], dtype=dm_env_rpc_pb2.DataType.INT32),
31 | }
32 |
33 |
34 | class SpecManagerTests(absltest.TestCase):
35 |
36 | def setUp(self):
37 | super(SpecManagerTests, self).setUp()
38 | self._spec_manager = spec_manager.SpecManager(_EXAMPLE_SPECS)
39 |
40 | def test_specs_by_uid(self):
41 | self.assertDictEqual(_EXAMPLE_SPECS, self._spec_manager.specs_by_uid)
42 |
43 | def test_specs_by_name(self):
44 | expected = {'foo': _EXAMPLE_SPECS[55], 'fuzz': _EXAMPLE_SPECS[54]}
45 | self.assertDictEqual(expected, self._spec_manager.specs_by_name)
46 |
47 | def test_name_to_uid(self):
48 | self.assertEqual(55, self._spec_manager.name_to_uid('foo'))
49 |
50 | def test_name_to_uid_no_such_name(self):
51 | with self.assertRaisesRegex(KeyError, 'bar'):
52 | self._spec_manager.name_to_uid('bar')
53 |
54 | def test_name_to_spec(self):
55 | spec = self._spec_manager.name_to_spec('foo')
56 | self.assertEqual([3], spec.shape)
57 |
58 | def test_name_to_spec_no_such_name(self):
59 | with self.assertRaisesRegex(KeyError, 'bar'):
60 | self._spec_manager.name_to_spec('bar')
61 |
62 | def test_uid_to_name(self):
63 | self.assertEqual('foo', self._spec_manager.uid_to_name(55))
64 |
65 | def test_uid_to_name_no_such_uid(self):
66 | with self.assertRaisesRegex(KeyError, '56'):
67 | self._spec_manager.uid_to_name(56)
68 |
69 | def test_names(self):
70 | self.assertEqual(set(['foo', 'fuzz']), self._spec_manager.names())
71 |
72 | def test_uids(self):
73 | self.assertEqual(set([54, 55]), self._spec_manager.uids())
74 |
75 | def test_uid_to_spec(self):
76 | spec = self._spec_manager.uid_to_spec(54)
77 | self.assertEqual([2], spec.shape)
78 |
79 | def test_pack(self):
80 | packed = self._spec_manager.pack({'fuzz': [1.0, 2.0], 'foo': [3, 4, 5]})
81 | expected = {
82 | 54: tensor_utils.pack_tensor([1.0, 2.0], dtype=np.float32),
83 | 55: tensor_utils.pack_tensor([3, 4, 5], dtype=np.int32),
84 | }
85 | self.assertDictEqual(expected, packed)
86 |
87 | def test_partial_pack(self):
88 | packed = self._spec_manager.pack({
89 | 'fuzz': [1.0, 2.0],
90 | })
91 | expected = {
92 | 54: tensor_utils.pack_tensor([1.0, 2.0], dtype=np.float32),
93 | }
94 | self.assertDictEqual(expected, packed)
95 |
96 | def test_pack_unknown_key_raises_error(self):
97 | with self.assertRaisesRegex(KeyError, 'buzz'):
98 | self._spec_manager.pack({'buzz': 'hello'})
99 |
100 | def test_pack_wrong_shape_raises_error(self):
101 | with self.assertRaisesRegex(ValueError, 'shape'):
102 | self._spec_manager.pack({'foo': [1, 2]})
103 |
104 | def test_pack_wrong_dtype_raises_error(self):
105 | with self.assertRaises(ValueError):
106 | self._spec_manager.pack({'foo': 'hello'})
107 |
108 | def test_pack_cast_int_to_float_is_ok(self):
109 | packed = self._spec_manager.pack({'fuzz': [1, 2]})
110 | self.assertEqual([1.0, 2.0], packed[54].floats.array)
111 |
112 | def test_unpack(self):
113 | unpacked = self._spec_manager.unpack({
114 | 54: tensor_utils.pack_tensor([1.0, 2.0], dtype=np.float32),
115 | 55: tensor_utils.pack_tensor([3, 4, 5], dtype=np.int32),
116 | })
117 | self.assertLen(unpacked, 2)
118 | np.testing.assert_array_equal(np.asarray([1.0, 2.0]), unpacked['fuzz'])
119 | np.testing.assert_array_equal(np.asarray([3, 4, 5]), unpacked['foo'])
120 |
121 | def test_partial_unpack(self):
122 | unpacked = self._spec_manager.unpack({
123 | 54: tensor_utils.pack_tensor([1.0, 2.0], dtype=np.float32),
124 | })
125 | self.assertLen(unpacked, 1)
126 | np.testing.assert_array_equal(np.asarray([1.0, 2.0]), unpacked['fuzz'])
127 |
128 | def test_unpack_unknown_uid_raises_error(self):
129 | with self.assertRaisesRegex(KeyError, '53'):
130 | self._spec_manager.unpack({53: tensor_utils.pack_tensor('foo')})
131 |
132 | def test_unpack_wrong_shape_raises_error(self):
133 | with self.assertRaisesRegex(ValueError, 'shape'):
134 | self._spec_manager.unpack({55: tensor_utils.pack_tensor([1, 2])})
135 |
136 | def test_unpack_wrong_type_raises_error(self):
137 | with self.assertRaisesRegex(ValueError, 'dtype'):
138 | self._spec_manager.unpack(
139 | {55: tensor_utils.pack_tensor([1, 2, 3], dtype=np.float32)})
140 |
141 |
142 | class SpecManagerVariableSpecShapeTests(absltest.TestCase):
143 |
144 | def setUp(self):
145 | super(SpecManagerVariableSpecShapeTests, self).setUp()
146 | specs = {
147 | 101:
148 | dm_env_rpc_pb2.TensorSpec(
149 | name='foo', shape=[1, -1], dtype=dm_env_rpc_pb2.DataType.INT32),
150 | }
151 | self._spec_manager = spec_manager.SpecManager(specs)
152 |
153 | def test_variable_spec_shape(self):
154 | packed = self._spec_manager.pack({'foo': [[1, 2, 3, 4]]})
155 | expected = {
156 | 101: tensor_utils.pack_tensor([[1, 2, 3, 4]], dtype=np.int32),
157 | }
158 | self.assertDictEqual(expected, packed)
159 |
160 | def test_invalid_variable_shape(self):
161 | with self.assertRaisesRegex(ValueError, 'shape'):
162 | self._spec_manager.pack({'foo': np.ones((1, 2, 3), dtype=np.int32)})
163 |
164 | def test_empty_variable_shape(self):
165 | manager = spec_manager.SpecManager({
166 | 1:
167 | dm_env_rpc_pb2.TensorSpec(
168 | name='bar', shape=[], dtype=dm_env_rpc_pb2.DataType.INT32)
169 | })
170 | with self.assertRaisesRegex(ValueError, 'shape'):
171 | manager.pack({'bar': np.ones((1), dtype=np.int32)})
172 |
173 | def test_invalid_variable_spec_shape(self):
174 | with self.assertRaisesRegex(ValueError, 'shape has > 1 variable length'):
175 | spec_manager.SpecManager({
176 | 1:
177 | dm_env_rpc_pb2.TensorSpec(
178 | name='bar',
179 | shape=[1, -1, -1],
180 | dtype=dm_env_rpc_pb2.DataType.INT32)
181 | })
182 |
183 |
184 | class SpecManagerConstructorTests(absltest.TestCase):
185 |
186 | def test_duplicate_names_raise_error(self):
187 | specs = {
188 | 54:
189 | dm_env_rpc_pb2.TensorSpec(
190 | name='fuzz', shape=[3], dtype=dm_env_rpc_pb2.DataType.FLOAT),
191 | 55:
192 | dm_env_rpc_pb2.TensorSpec(
193 | name='fuzz', shape=[2], dtype=dm_env_rpc_pb2.DataType.FLOAT),
194 | }
195 | with self.assertRaisesRegex(ValueError, 'duplicate name'):
196 | spec_manager.SpecManager(specs)
197 |
198 |
199 | if __name__ == '__main__':
200 | absltest.main()
201 |
--------------------------------------------------------------------------------
/dm_env_rpc/v1/tensor_spec_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Helper Python utilities for managing dm_env_rpc TensorSpecs."""
16 |
17 | import dataclasses
18 | from typing import Generic, TypeVar, Union
19 | import numpy as np
20 |
21 | from dm_env_rpc.v1 import dm_env_rpc_pb2
22 | from dm_env_rpc.v1 import tensor_utils
23 |
24 |
25 | _BOUNDS_CANNOT_BE_SAFELY_CAST_TO_DTYPE = (
26 | 'TensorSpec "{name}"\'s bounds [{minimum}, {maximum}] contain value(s) '
27 | 'that cannot be safely cast to dtype {dtype}.')
28 |
29 | T = TypeVar('T')
30 |
31 |
32 | @dataclasses.dataclass
33 | class Bounds(Generic[T]):
34 | min: T
35 | max: T
36 |
37 |
38 | def _is_valid_bound(array_or_scalar, np_dtype: np.dtype) -> bool:
39 | """Returns whether array_or_scalar is a valid bound for the given type."""
40 | array_or_scalar = np.asarray(array_or_scalar)
41 | if np.issubdtype(array_or_scalar.dtype, np.integer):
42 | iinfo = np.iinfo(np_dtype)
43 | for value in np.asarray(array_or_scalar).flat:
44 | if int(value) < iinfo.min or int(value) > iinfo.max:
45 | return False
46 | elif np.issubdtype(array_or_scalar.dtype, np.floating):
47 | finfo = np.finfo(np_dtype)
48 | for value in np.asarray(array_or_scalar).flat:
49 | if (float(value) < finfo.min and float(value) != -np.inf) or (
50 | float(value) > finfo.max and float(value) != np.inf
51 | ):
52 | return False
53 | return True
54 |
55 |
56 | def np_range_info(np_type: ...) -> Union[np.finfo, np.iinfo]:
57 | """Returns type info for `np_type`, which includes min and max attributes."""
58 | np_type = np.dtype(np_type)
59 | if issubclass(np_type.type, np.floating):
60 | return np.finfo(np_type)
61 | elif issubclass(np_type.type, np.integer):
62 | return np.iinfo(np_type)
63 | else:
64 | raise ValueError('{} does not have range info.'.format(np_type))
65 |
66 |
67 | def _get_value(min_max_value, shape, default):
68 | """Helper function that returns the min/max bounds for a Value message.
69 |
70 | Args:
71 | min_max_value: Value protobuf message to get value from.
72 | shape: Optional dimensions to unpack payload data to.
73 | default: Value to use if min_max_value is not set.
74 |
75 | Returns:
76 |     A scalar if `shape` is empty or None, otherwise a NumPy array holding
77 |     either the unpacked value or the provided default.
78 |
79 | """
80 | which = min_max_value.WhichOneof('payload')
81 | value = which and getattr(min_max_value, which)
82 |
83 | if value is None:
84 | min_max = default
85 | else:
86 | unpacked = tensor_utils.unpack_proto(min_max_value)
87 | min_max = tensor_utils.reshape_array(
88 | unpacked, shape) if len(unpacked) > 1 else unpacked[0]
89 |
90 | if (shape is not None
91 | and np.any(np.array(shape) < 0)
92 | and np.asarray(min_max).size > 1):
93 | raise ValueError(
94 |         'TensorSpecs with variable length shapes can only have scalar ranges. '
95 | 'Shape: {}, value: {}'.format(shape, min_max))
96 | return min_max
97 |
98 |
99 | def bounds(tensor_spec: dm_env_rpc_pb2.TensorSpec) -> Bounds:
100 | """Gets the inclusive bounds of `tensor_spec`.
101 |
102 | Args:
103 | tensor_spec: An instance of a dm_env_rpc TensorSpec proto.
104 |
105 | Returns:
106 | A named tuple (`min`, `max`) of inclusive bounds.
107 |
108 | Raises:
109 | ValueError: `tensor_spec` does not have a numeric dtype, or the type of its
110 |     `min` or `max` does not match its dtype, or the bounds are invalid in
111 | some way.
112 | """
113 | np_type = tensor_utils.data_type_to_np_type(tensor_spec.dtype)
114 | tensor_spec_type = dm_env_rpc_pb2.DataType.Name(tensor_spec.dtype).lower()
115 | if not issubclass(np_type.type, np.number):
116 | raise ValueError('TensorSpec "{}" has non-numeric type {}.'
117 | .format(tensor_spec.name, tensor_spec_type))
118 |
119 | # Check min payload type matches the tensor type.
120 | min_which = tensor_spec.min.WhichOneof('payload')
121 | if min_which and not min_which.startswith(tensor_spec_type):
122 | raise ValueError('TensorSpec "{}" has dtype {} but min type {}.'.format(
123 | tensor_spec.name, tensor_spec_type, min_which))
124 |
125 | # Check max payload type matches the tensor type.
126 | max_which = tensor_spec.max.WhichOneof('payload')
127 | if max_which and not max_which.startswith(tensor_spec_type):
128 | raise ValueError('TensorSpec "{}" has dtype {} but max type {}.'.format(
129 | tensor_spec.name, tensor_spec_type, max_which))
130 |
131 | dtype_bounds = np_range_info(np_type)
132 | min_bound = _get_value(tensor_spec.min, tensor_spec.shape, dtype_bounds.min)
133 | max_bound = _get_value(tensor_spec.max, tensor_spec.shape, dtype_bounds.max)
134 |
135 | if not _is_valid_bound(min_bound, np_type) or not _is_valid_bound(
136 | max_bound, np_type
137 | ):
138 | raise ValueError(
139 | _BOUNDS_CANNOT_BE_SAFELY_CAST_TO_DTYPE.format(
140 | name=tensor_spec.name,
141 | minimum=min_bound,
142 | maximum=max_bound,
143 | dtype=tensor_spec_type))
144 |
145 | if np.any(max_bound < min_bound):
146 | raise ValueError('TensorSpec "{}" has min {} larger than max {}.'.format(
147 | tensor_spec.name, min_bound, max_bound))
148 |
149 | return Bounds(np_type.type(min_bound), np_type.type(max_bound))
150 |
151 |
152 | def set_bounds(tensor_spec: dm_env_rpc_pb2.TensorSpec, minimum, maximum):
153 | """Modifies `tensor_spec` to have its inclusive bounds set.
154 |
155 |   Packs `minimum` into `tensor_spec.min` and `maximum` into `tensor_spec.max`.
156 |
157 | Args:
158 | tensor_spec: An instance of a dm_env_rpc TensorSpec proto. It should
159 | already have its `name`, `dtype` and `shape` attributes set.
160 | minimum: The minimum value that elements in the described tensor can obtain.
161 | A scalar, iterable of scalars, or None. If None, `min` will be cleared on
162 | `tensor_spec`.
163 | maximum: The maximum value that elements in the described tensor can obtain.
164 | A scalar, iterable of scalars, or None. If None, `max` will be cleared on
165 | `tensor_spec`.
166 | """
167 | np_type = tensor_utils.data_type_to_np_type(tensor_spec.dtype)
168 | if not issubclass(np_type.type, np.number):
169 | raise ValueError(f'TensorSpec has non-numeric type "{np_type}".')
170 |
171 | has_min = minimum is not None
172 | has_max = maximum is not None
173 |
174 | if ((has_min and not _is_valid_bound(minimum, np_type)) or
175 | (has_max and not _is_valid_bound(maximum, np_type))):
176 | raise ValueError(
177 | _BOUNDS_CANNOT_BE_SAFELY_CAST_TO_DTYPE.format(
178 | name=tensor_spec.name,
179 | minimum=minimum,
180 | maximum=maximum,
181 | dtype=dm_env_rpc_pb2.DataType.Name(tensor_spec.dtype)))
182 |
183 | if has_min:
184 | minimum = np.asarray(minimum, dtype=np_type)
185 | if minimum.size != 1 and minimum.shape != tuple(tensor_spec.shape):
186 | raise ValueError(
187 | f'minimum has shape {minimum.shape}, which is incompatible with '
188 | f"tensor_spec {tensor_spec.name}'s shape {tensor_spec.shape}.")
189 |
190 | if has_max:
191 | maximum = np.asarray(maximum, dtype=np_type)
192 | if maximum.size != 1 and maximum.shape != tuple(tensor_spec.shape):
193 | raise ValueError(
194 | f'maximum has shape {maximum.shape}, which is incompatible with '
195 | f"tensor_spec {tensor_spec.name}'s shape {tensor_spec.shape}.")
196 |
197 | if (has_min and has_max and np.any(maximum < minimum)):
198 | raise ValueError('TensorSpec "{}" has min {} larger than max {}.'.format(
199 | tensor_spec.name, minimum, maximum))
200 |
201 | packer = tensor_utils.get_packer(np_type)
202 | if has_min:
203 | packer.pack(tensor_spec.min, minimum)
204 | else:
205 | tensor_spec.ClearField('min')
206 |
207 | if has_max:
208 | packer.pack(tensor_spec.max, maximum)
209 | else:
210 | tensor_spec.ClearField('max')
211 |
--------------------------------------------------------------------------------
/dm_env_rpc/v1/tensor_utils_benchmark.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Micro-benchmark for tensor_utils.pack_tensor."""
16 |
17 | import abc
18 | import timeit
19 |
20 | from absl import app
21 | from absl import flags
22 | import numpy as np
23 |
24 | from dm_env_rpc.v1 import tensor_utils
25 |
26 | flags.DEFINE_integer('repeats', 10000,
27 | 'Number of times each benchmark will run.')
28 | FLAGS = flags.FLAGS
29 |
30 |
31 | class _AbstractBenchmark(metaclass=abc.ABCMeta):
32 | """Base class for benchmarks using timeit."""
33 |
34 | def run(self):
35 | time = timeit.timeit(self.statement, setup=self.setup, number=FLAGS.repeats)
36 | print(f'{self.name} -- overall: {time:0.2f} s, '
37 | f'per call: {time/FLAGS.repeats:0.1e} s')
38 |
39 | def setup(self):
40 | pass
41 |
42 | @abc.abstractmethod
43 | def statement(self):
44 | pass
45 |
46 | @abc.abstractproperty
47 | def name(self):
48 | pass
49 |
50 |
51 | class _PackBenchmark(_AbstractBenchmark):
52 | """Benchmark for packing a numpy array to a Tensor proto."""
53 |
54 | def __init__(self, dtype, shape):
55 | self._name = f'pack {np.dtype(dtype).name}'
56 | self._dtype = dtype
57 | self._shape = shape
58 |
59 | @property
60 | def name(self):
61 | return self._name
62 |
63 | def setup(self):
64 | # Use non-zero values in case there's something special about zero arrays.
65 | self._unpacked = np.arange(
66 | np.prod(self._shape), dtype=self._dtype).reshape(self._shape)
67 |
68 | def statement(self):
69 | self._unpacked.flat[0] += 1 # prevent caching of the result
70 | tensor_utils.pack_tensor(self._unpacked, self._dtype)
71 |
72 |
73 | class _UnpackBenchmark(_AbstractBenchmark):
74 | """Benchmark for unpacking a Tensor proto to a numpy array."""
75 |
76 | def __init__(self, dtype, shape):
77 | self._name = f'unpack {np.dtype(dtype).name}'
78 | self._shape = shape
79 | self._dtype = dtype
80 |
81 | @property
82 | def name(self):
83 | return self._name
84 |
85 | def setup(self):
86 | # Use non-zero values in case there's something special about zero arrays.
87 | tensor = np.arange(
88 | np.prod(self._shape), dtype=self._dtype).reshape(self._shape)
89 | self._packed = tensor_utils.pack_tensor(tensor, self._dtype)
90 |
91 | def statement(self):
92 | tensor_utils.unpack_tensor(self._packed)
93 |
94 |
95 | def main(argv):
96 | if len(argv) > 1:
97 | raise app.UsageError('Too many command-line arguments.')
98 | # Pick `shape` such that number of bytes is consistent between benchmarks.
99 | benchmarks = (
100 | _PackBenchmark(dtype=np.uint8, shape=(128, 128, 3)),
101 | _PackBenchmark(dtype=np.int32, shape=(64, 64, 3)),
102 | _PackBenchmark(dtype=np.int64, shape=(32, 64, 3)),
103 | _PackBenchmark(dtype=np.uint32, shape=(64, 64, 3)),
104 | _PackBenchmark(dtype=np.uint64, shape=(32, 64, 3)),
105 | _PackBenchmark(dtype=np.float32, shape=(64, 64, 3)),
106 | _PackBenchmark(dtype=np.float64, shape=(32, 64, 3)),
107 | _UnpackBenchmark(dtype=np.uint8, shape=(128, 128, 3)),
108 | _UnpackBenchmark(dtype=np.int32, shape=(64, 64, 3)),
109 | _UnpackBenchmark(dtype=np.int64, shape=(32, 64, 3)),
110 | _UnpackBenchmark(dtype=np.uint32, shape=(64, 64, 3)),
111 | _UnpackBenchmark(dtype=np.uint64, shape=(32, 64, 3)),
112 | _UnpackBenchmark(dtype=np.float32, shape=(64, 64, 3)),
113 | _UnpackBenchmark(dtype=np.float64, shape=(32, 64, 3)),
114 | )
115 | for benchmark in benchmarks:
116 | benchmark.run()
117 |
118 |
119 | if __name__ == '__main__':
120 | app.run(main)
121 |
--------------------------------------------------------------------------------
/docs/v1/2x2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/dm_env_rpc/dea4326e7e4e15b3abd787a5907fa12b3c6940e2/docs/v1/2x2.png
--------------------------------------------------------------------------------
/docs/v1/2x3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/dm_env_rpc/dea4326e7e4e15b3abd787a5907fa12b3c6940e2/docs/v1/2x3.png
--------------------------------------------------------------------------------
/docs/v1/appendix.md:
--------------------------------------------------------------------------------
1 | ## Advice for implementers
2 |
3 | A server implementation does not need to implement the full protocol to be
4 | usable. When first starting, just supporting `JoinWorld`, `LeaveWorld` and
5 | `Step` is sufficient to provide a minimum `dm_env_rpc` environment (though
6 | clients using the provided `DmEnvAdaptor` will need access to `Reset` as well).
7 | The world can be considered already created just by virtue of the server being
8 | available and listening on a port. Unsupported requests should just return an
9 | error. After this base, the following can be added:
10 |
11 | * `CreateWorld` and `DestroyWorld` can provide a way for clients to give
12 | settings and manage the lifetime of worlds.
13 | * Supporting sequences, where one sequence reaches its natural conclusion and
14 | the next begins on the next step.
15 | * `Reset` and `ResetWorld` can provide mechanisms to manage transitions
16 | between sequences and to update settings.
17 | * Read/Write/List properties for client code to query or set state outside of
18 | the normal observe/act loop.
19 |
20 | A client implementation likely has to support the full range of features and
21 | data types that the server it wants to interact with supports. However, it's
22 | unnecessary to support data types or features that the server does not
23 | implement. If the target server does not provide tensors with more than one
24 | dimension it probably isn't worth the effort to support higher dimensional
25 | tensors, for instance.
26 |
27 | For Python clients, some utility code is provided. Specifically:
28 |
29 | * connection.py - A utility class for managing the client-side gRPC connection
30 | and the error handling from server responses.
31 | * spec_manager.py - A utility class for managing the UID-to-name mapping, so
32 | the rest of the client code can use the more human readable tensor names.
33 | * tensor_utils.py - A few utility functions for packing and unpacking NumPy
34 | arrays to `dm_env_rpc` tensors.
35 |
36 | For other languages similar utility functions will likely be needed. It is
37 | especially recommended that client code have something similar to `SpecManager`
38 | that turns UIDs into human readable text as soon as possible. Strings are both
39 | less likely to be used wrong and more easily debugged.
40 |
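To make the role of these helpers concrete, here is a minimal illustrative sketch (not part of the repository) of how a Python client might combine them. The `action_specs` dict and the `SpecManager.pack` call reflect one reading of `spec_manager.py` and should be treated as assumptions rather than documented API.

```python
import numpy as np

from dm_env_rpc.v1 import spec_manager
from dm_env_rpc.v1 import tensor_utils


def pack_actions(action_specs, actions_by_name):
  """Maps {action name: value} to {UID: Tensor} for a StepRequest.

  `action_specs` is assumed to be a dict of UID -> TensorSpec, for example
  taken from the action spec returned when joining a world.
  """
  manager = spec_manager.SpecManager(action_specs)
  # SpecManager owns the UID <-> name mapping, so the rest of the client code
  # can deal exclusively in human readable names.
  return manager.pack(actions_by_name)


# tensor_utils converts between NumPy arrays and dm_env_rpc Tensor protos.
packed = tensor_utils.pack_tensor(np.arange(6).reshape(2, 3), np.int32)
unpacked = tensor_utils.unpack_tensor(packed)  # back to a NumPy array
```
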
41 | ### Common state transition errors
42 |
43 | The [Environment connection state machine](overview.md#states) cannot transition
44 | directly between `INTERRUPTED` and `TERMINATED` states. Servers may
45 | inadvertently cause an `INTERRUPTED` --> `TERMINATED` transition if:
46 |
47 | 1. The server environment starts in a completed state (and therefore finishes
48 | in 0 steps), or
49 | 1. The server environment setup fails but does not throw an exception until the
50 | first step.
51 |
52 | To ensure server state transitions are legal, please check the validity of the
53 | server environment when it is initialized.
54 |
55 | ## Reward functions
56 |
57 | [Reward function design is difficult](https://www.alexirpan.com/2018/02/14/rl-hard.html#reward-function-design-is-difficult),
58 | and may be the first thing a client will tweak.
59 |
60 | For instance, a simple arcade game could expose its score function as `reward`.
61 | However, the score function might not actually be a good measure of playing
62 | ability if there's a series of actions which increases score without advancing
63 | towards the actual desired goal. Alternatively, games might have an obvious
64 | reward signal only at the end (whether the agent won or not), which might be too
65 | sparse for a reinforcement learning agent. Constructing a robust reward function
66 | is an area of active research and it's perfectly reasonable for an environment
67 | to abdicate the responsibility of forming one.
68 |
69 | For instance, constructing a reward function for the game of chess is actually
70 | rather tricky. A simple reward at the end if an agent wins is going to make
71 | learning difficult, as agents won't get feedback during a game if they make
72 | mistakes or play well. Beginner chess players often use
73 | [material values](https://en.wikipedia.org/wiki/Chess_piece_relative_value) to
74 | evaluate a given position to decide who is winning, with queens being worth more
75 | than rooks, etc. This might seem like a prime candidate for a reward function,
76 | however there are
77 | [well known shortcomings](https://en.wikipedia.org/wiki/Chess_piece_relative_value#Shortcomings_of_piece_valuation_systems)
78 | to this simple system. For instance, using this as a reward function can blind
79 | an agent to a move which loses material but produces a checkmate.
80 |
81 | If a client does construct a custom reward function it may want access to data
82 | which normally would be considered hidden and unavailable. Exposing this
83 | information to clients may feel like cheating, however getting an AI agent to
84 | start learning at all is often half the battle. To this end servers should
85 | expose as much relevant information through the environment as they can to give
86 | clients room to experiment. Weaning an agent off this information down the line
87 | may be possible; just be sure to document which observables are considered
88 | hidden information so a client can strip them in the future.
89 |
90 | For client implementations, if a given server does not provide a reward or
91 | discount observable or they aren't suitable you can build your own from other
92 | observables. For instance, the provided `DmEnvAdaptor` has `reward` and
93 | `discount` functions which can be overridden in derived classes.
94 |
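As a purely hypothetical illustration of building a client-side reward from other observables (the `score` observation name is invented and is not provided by `DmEnvAdaptor`):

```python
def score_delta_reward(previous_obs, current_obs):
  """Illustrative client-side reward: change in a hypothetical 'score'."""
  return float(current_obs['score'] - previous_obs['score'])
```
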
95 | ## Nested observations example
96 |
97 | Retrofitting an existing object-oriented codebase to be a `dm_env_rpc` server
98 | can be difficult, as the data is likely arranged in a tree-like structure of
99 | heterogeneous nodes. For instance, a game might have pickup and character
100 | classes which inherit from the same base:
101 |
102 | ```
103 | Entity {
104 | Vector3 position;
105 | }
106 |
107 | HealthPickup : Entity {
108 | float HealthRecoveryAmount;
109 | Vector4 IconColor;
110 | }
111 |
112 | Character : Entity {
113 | string Name;
114 | }
115 | ```
116 |
117 | It's not immediately clear how to turn this data into tensors, which can be a
118 | difficult issue for a server implementation.
119 |
120 | First, though, not all data needs to be sent to a joined connection. For a
121 | health pickup, the health recovery amount might be hidden information that
122 | agents don't generally have access to. Likewise, the IconColor might only matter
123 | for a user interface, which an agent might not have.
124 |
125 | After filtering unnecessary data the remainder may fit in a
126 | "[structure of arrays](https://en.wikipedia.org/wiki/AoS_and_SoA)" format.
127 |
128 | In our example case, we can structure the remaining data as a few different
129 | tensors. The specs for them might look like:
130 |
131 | ```
132 | TensorSpec {
133 | name = "HealthPickups.Position",
134 | shape = [3, -1]
135 | dtype = float
136 | },
137 |
138 | TensorSpec {
139 |   name = "Characters.Position",
140 | shape = [3, -1]
141 | dtype = float
142 | },
143 |
144 | TensorSpec {
145 | name = "Characters.Name",
146 | shape = [-1]
147 | dtype = string
148 | }
149 | ```
150 |
151 | The use of a period "." character allows clients to reconstruct a hierarchy,
152 | which can be useful in grouping tensors into logical categories.
153 |
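The repository also ships `dm_env_flatten_utils.py` for working with flattened structures; the generic sketch below only illustrates the idea of rebuilding nesting from "."-separated names and does not reflect that module's actual API.

```python
def unflatten(flat_observations):
  """Rebuilds a nested dict from '.'-separated observation names.

  For example {'Characters.Position': p, 'Characters.Name': n} becomes
  {'Characters': {'Position': p, 'Name': n}}.
  """
  nested = {}
  for name, value in flat_observations.items():
    node = nested
    *parents, leaf = name.split('.')
    for part in parents:
      node = node.setdefault(part, {})
    node[leaf] = value
  return nested
```
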
154 | If a structure of arrays format is infeasible, custom protocol messages can be
155 | implemented as an alternative. This allows a more natural data representation,
156 | but requires the client to also compile the protocol buffers and makes discovery
157 | of metadata, such as range for numeric types, more difficult. For our toy
158 | example, we could form the data into this intermediate protocol buffer format:
159 |
160 | ```protobuf
161 | message Vector3 {
162 | float x = 1;
163 | float y = 2;
164 | float z = 3;
165 | }
166 |
167 | message HealthPickup {
168 | Vector3 position = 1;
169 | }
170 |
171 | message Character {
172 | Vector3 position = 1;
173 | string name = 2;
174 | }
175 | ```
176 |
177 | The `TensorSpec` would then look like:
178 |
179 | ```
180 | TensorSpec {
181 | name = "Characters",
182 | shape = [-1],
183 | dtype = proto
184 | },
185 |
186 | TensorSpec {
187 | name = "HealthPickups",
188 | shape = [-1],
189 | dtype = proto
190 | },
191 | ```
192 |
193 | ### Rendering
194 |
195 | Often renders are desirable observations for agents, whether human or machine.
196 | In past frameworks, such as
197 | [Atari](https://deepmind.com/research/publications/playing-atari-deep-reinforcement-learning)
198 | and
199 | [DMLab](https://deepmind.com/blog/article/impala-scalable-distributed-deeprl-dmlab-30),
200 | environments have been responsible for rendering. Servers implementing this
201 | protocol will likely (but not necessarily) continue this tradition.
202 |
203 | For performance reasons rendering should not be done unless requested by an
204 | agent, and then only the individual renders requested. Servers can batch similar
205 | renders if desirable.
206 |
207 | The exact image format for a render is at the discretion of the server
208 | implementation, but should be documented. Reasonable choices are to return an
209 | interleaved RGB buffer, similar to a GPU render texture, or a standard image
210 | format such as PNG or JPEG.
211 |
212 | Human agents generally prefer high resolution images, such as the high
213 | definition 1920x1080 format. Reinforcement agents, however, often use much
214 | smaller resolutions, such as 96x72, for performance of both environment and
215 | agent. It is up to the server implementation how or if it wants to expose
216 | resolution as a setting and what resolutions it wants to support. It could be
217 | hard coded or specified in world settings or join settings. Server implementers
218 | are advised to consider performance impacts of whatever choice they make, as
219 | rendering time often dominates the runtime of many environments.
220 |
221 | ### Per-element Min/Max Ranges {#per-element-ranges}
222 |
223 | When defining actions for an environment, it's desirable to provide users with
224 | the range of valid values that can be sent. `TensorSpec`s support this by
225 | providing min and max range fields, which can either be defined as a single
226 | scalar (see [broadcastable](overview.md#broadcastable)) for all elements in a
227 | `Tensor`, or each element can have their own range defined.
228 |
229 | For the vast majority of actions, we strongly encourage defining only a single
230 | range. When users start defining per-element ranges, it is typically a sign
231 | that the action is really a combination of several distinct actions. Slicing
232 | the action into separate actions also has the benefit of giving each a more
233 | descriptive name, and makes it easier for clients to compose their own action
234 | spaces.
235 |
236 | There are some examples where it may be desirable to not split the action. For
237 | example, imagine an `ABSOLUTE_MOUSE_XY` action. This would have two elements
238 | (one for the `X` and `Y` positions respectively), but would likely have
239 | different ranges based on the width and height of the screen. Splitting the
240 | action would mean agents could send only the `X` or `Y` action without the
241 | other, which might not be valid.
242 |
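For completeness, a per-element range such as the hypothetical `ABSOLUTE_MOUSE_XY` action above could be expressed with this repository's `tensor_spec_utils.set_bounds` helper; the action name and screen dimensions are made up for illustration.

```python
from dm_env_rpc.v1 import dm_env_rpc_pb2
from dm_env_rpc.v1 import tensor_spec_utils

spec = dm_env_rpc_pb2.TensorSpec(
    name='ABSOLUTE_MOUSE_XY', shape=[2], dtype=dm_env_rpc_pb2.DataType.UINT32)
# Element 0 (X) is bounded by the screen width, element 1 (Y) by its height.
tensor_spec_utils.set_bounds(spec, minimum=[0, 0], maximum=[1919, 1079])
```
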
243 | ### Documentation
244 |
245 | Actions and observations can be self-documenting, in the sense that their
246 | `TensorSpec`s provide information about their type, shape and bounds. However,
247 | the exact meaning of actions and observations are not discoverable through the
248 | protocol. In addition, `CreateWorldRequest` and `JoinWorldRequest` settings lack
249 | even a spec. Therefore server implementers should be sure to document allowed
250 | settings for create and join world, along with their types and shapes, the
251 | consequences of any actions and the meaning of any observations, as well as any
252 | properties.
253 |
--------------------------------------------------------------------------------
/docs/v1/extensions/index.md:
--------------------------------------------------------------------------------
1 | # Extensions
2 |
3 | `dm_env_rpc` Extensions provide a way for users to send custom messages over the
4 | same bi-directional gRPC stream as the standard messages. This can be useful for
5 | debugging or manipulating the simulation in a way that isn't appropriate through
6 | the typical `dm_env_rpc` protocol.
7 |
8 | Extensions must complement the existing `dm_env_rpc` protocol, so that agents
9 | are able to send conventional `dm_env_rpc` messages, interspersed with extension
10 | requests.
11 |
12 | To send an extension message, you must pack your custom message into an `Any`
13 | proto and assign it to the `extension` field in `EnvironmentRequest`. For example:
14 |
15 | ```python
16 | from dm_env_rpc.v1 import connection as dm_env_rpc_connection
17 | from google.protobuf import any_pb2
18 | from google.protobuf import struct_pb2
19 |
20 | def send_custom_message(connection: dm_env_rpc_connection.ConnectionType):
21 | """Send/receive custom Struct message through a dm_env_rpc connection."""
22 | my_message = struct_pb2.Struct(
23 | fields={'foo': struct_pb2.Value(string_value='bar')})
24 |
25 | packed_message = any_pb2.Any()
26 | packed_message.Pack(my_message)
27 |
28 | response = struct_pb2.Struct()
29 | connection.send(packed_message).Unpack(response)
30 | ```
31 |
32 | If the simulation can respond to such a message, it must send a response using
33 | the corresponding `extension` field in `EnvironmentResponse`. The server must
34 | send an [error](../overview.md#errors) if it doesn't recognise the request.
35 |
36 | ## Common Extensions
37 |
38 | The following commonly used extensions are provided with `dm_env_rpc`:
39 |
40 | * [Properties](properties.md)
41 |
42 | ## Recommendations
43 |
44 | ### Should you use an extension?
45 |
46 | Although extensions can be powerful, if you expect an extension message to be
47 | sent by the client every step, consider making it a proper action or observable,
48 | even if it's intended as metadata. This better ensures that all mutable actions
49 | can be executed at well-ordered times.
50 |
51 | ### Creating one parent request/response for your extension
52 |
53 | If your extension has more than a couple of requests, consider creating a single
54 | parent request/response message that you can add or remove messages from. This
55 | simplifies the server code, which only needs to unpack a single request type,
56 | and makes it easier to compartmentalize each extension. For example:
57 |
58 | ```proto
59 | message AwesomeRequest {
60 | string foo = 1;
61 | }
62 |
63 | message AwesomeResponse {}
64 |
65 | message AnotherAwesomeRequest {
66 | string bar = 1;
67 | }
68 |
69 | message AnotherAwesomeResponse {}
70 |
71 | message MyExtensionRequest {
72 | oneof payload {
73 | AwesomeRequest awesome = 1;
74 | AnotherAwesomeRequest another_awesome = 2;
75 | }
76 | }
77 |
78 | message MyExtensionResponse {
79 | oneof payload {
80 | AwesomeResponse awesome = 1;
81 | AnotherAwesomeResponse another_awesome = 2;
82 | }
83 | }
84 | ```
85 |
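Tying this back to the earlier Python snippet, a client might wrap such a parent request as sketched below; `my_extension_pb2` is a hypothetical module assumed to be generated from the proto above.

```python
from google.protobuf import any_pb2

import my_extension_pb2  # Hypothetical: generated from the proto above.


def send_awesome_request(connection, foo: str):
  """Sends an AwesomeRequest wrapped in the parent extension message."""
  request = my_extension_pb2.MyExtensionRequest(
      awesome=my_extension_pb2.AwesomeRequest(foo=foo))
  packed = any_pb2.Any()
  packed.Pack(request)

  response = my_extension_pb2.MyExtensionResponse()
  connection.send(packed).Unpack(response)
  return response
```
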
86 | ## Alternatives
87 |
88 | An alternative to `dm_env_rpc` extension messages is to register a separate gRPC
89 | service. This has the benefit of being able to use gRPC's other
90 | [service methods](https://grpc.io/docs/guides/concepts/#service-definition).
91 | However, if you need messages to be sent at a particular point (e.g. after a
92 | specific `StepRequest`), synchronizing these disparate services will add
93 | additional complexity to the server.
94 |
--------------------------------------------------------------------------------
/docs/v1/extensions/properties.md:
--------------------------------------------------------------------------------
1 | ## Properties
2 |
3 | Often side-channel data is useful for debugging or manipulating the simulation
4 | in some way which isn’t appropriate for an agent’s interface. The property
5 | system provides this capability, allowing the reading, writing and
6 | discovery of properties. Properties are implemented as an extension to the core
7 | protocol (see [extensions](index.md) for more details).
8 |
9 | Properties queried before a `JoinWorldRequest` will correspond to universal
10 | properties that apply to all worlds. Properties queried after a
11 | `JoinWorldRequest` can add a layer of world-specific properties.
12 |
13 | For writing properties, it’s up to the server to determine if/when any
14 | modification should take place (e.g. changing the world seed might not take
15 | place until the next sequence).
16 |
17 | For reading properties, the exact timing of the observation is up to the world.
18 | It may occur at the previous step's time, or it may occur at some intermediate
19 | time.
20 |
21 | Properties can be laid out in a tree-like structure, as long as each node in the
22 | tree has a unique key, by having parent nodes be listable (the `is_listable`
23 | field in the `PropertySpec` message).
24 |
25 | Although properties can be powerful, if you expect a property to be read or
26 | written to every step by a normally functioning agent, it may be preferable to
27 | make it a proper action or observation, even if it’s intended to be metadata.
28 | For instance, a score which provides hidden information could be done as a
29 | property, but it might be preferable to do it as an observation. This ensures
30 | observations and actions all occur at well-ordered times.
31 |
32 | ### Read Property
33 |
34 | ```proto
35 | package dm_env_rpc.v1.extensions.properties;
36 |
37 | message ReadPropertyRequest {
38 | string key = 1;
39 | }
40 |
41 | message ReadPropertyResponse {
42 | dm_env_rpc.v1.Tensor value = 1;
43 | }
44 | ```
45 |
46 | Returns the current value for the property with the provided `key`.
47 |
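As an illustrative sketch, a client could issue this request by packing the raw message into an `Any` proto, as described in [extensions](index.md); `properties_pb2` is assumed to be the module generated from `extensions/properties.proto`, and the shipped `properties.py` helper may offer a more convenient wrapper.

```python
from google.protobuf import any_pb2

from dm_env_rpc.v1 import tensor_utils
# Assumed to be generated from dm_env_rpc/v1/extensions/properties.proto.
from dm_env_rpc.v1.extensions import properties_pb2


def read_property(connection, key: str):
  """Reads a property and returns its value as a NumPy array or scalar."""
  request = any_pb2.Any()
  request.Pack(properties_pb2.ReadPropertyRequest(key=key))

  response = properties_pb2.ReadPropertyResponse()
  connection.send(request).Unpack(response)
  return tensor_utils.unpack_tensor(response.value)
```
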
48 | ### Write Property
49 |
50 | ```proto
51 | package dm_env_rpc.v1.extensions.properties;
52 |
53 | message WritePropertyRequest {
54 | string key = 1;
55 | dm_env_rpc.v1.Tensor value = 2;
56 | }
57 |
58 | message WritePropertyResponse {}
59 | ```
60 |
61 | Sets the property referenced by the `key` field to the value of the provided
62 | `Tensor`.
63 |
64 | ### List Property
65 |
66 | ```proto
67 | package dm_env_rpc.v1.extensions.properties;
68 |
69 | message ListPropertyRequest {
70 | // Key to list property for. Empty string is root level.
71 | string key = 1;
72 | }
73 |
74 | message ListPropertyResponse {
75 | message PropertySpec {
76 | // Required: TensorSpec name field for key value.
77 | dm_env_rpc.v1.TensorSpec spec = 1;
78 |
79 | bool is_readable = 2;
80 | bool is_writable = 3;
81 | bool is_listable = 4;
82 | }
83 |
84 | repeated PropertySpec values = 1;
85 | }
86 | ```
87 |
88 | Returns an array of `PropertySpec` values for each property residing under the
89 | provided `key`. If the `key` is empty, properties registered at the root level
90 | are returned. If the `key` is not empty and the property is not listable an
91 | error is returned.
92 |
93 | Each `PropertySpec` must include a `TensorSpec` with its name field set to the
94 | property's key. For readable and writable properties, the type and shape of the
95 | property must also be set. Properties which are only listable must have the
96 | default value for type (`dm_env_rpc.v1.DataType.INVALID_DATA_TYPE`) and shape
97 | (an empty array).
98 |
--------------------------------------------------------------------------------
/docs/v1/glossary.md:
--------------------------------------------------------------------------------
1 | ## Glossary
2 |
3 | ### Server
4 |
5 | A server is a process which implements the dm_env_rpc `Environment` gRPC
6 | service. A server can host one or more worlds and allow one or more client
7 | connections.
8 |
9 | ### Client
10 |
11 | A client connects to a server. It can request worlds to be created and join them
12 | to become agents. It can also query properties and reset worlds without being an
13 | agent.
14 |
15 | ### Agent
16 |
17 | A client which has joined a world. In a game sense this is a "player". It has
18 | the ability to send actions and receive observations. This could be a human or a
19 | reinforcement learning agent.
20 |
21 | ### World
22 |
23 | A system or simulation in which one or more agents interact.
24 |
25 | ### Environment
26 |
27 | An agent's view of a world. In limited information situations, the environment
28 | may expose only part of the world state to an agent that corresponds to
29 | information that agent is allowed by the simulation to have. Agents communicate
30 | directly with an environment, and the environment communicates with the world to
31 | synchronize with it.
32 |
33 | ### Sequence
34 |
35 | A series of discrete states, where one state is correlated to previous and
36 | subsequent states, possibly ending in a terminal state, and usually modified by
37 | agent actions. In the simplest case, playing an entire game until one player is
38 | declared the winner is one sequence. Also sometimes called an "episode" in
39 | reinforcement learning contexts.
40 |
--------------------------------------------------------------------------------
/docs/v1/index.md:
--------------------------------------------------------------------------------
1 | # `dm_env_rpc` documentation
2 |
3 | * [Protocol overview](overview.md)
4 | * [Protocol reference](reference.md)
5 | * [Extensions](extensions/index.md)
6 | * [Appendix](appendix.md)
7 | * [Glossary](glossary.md)
8 |
--------------------------------------------------------------------------------
/docs/v1/overview.md:
--------------------------------------------------------------------------------
1 | ## Protocol overview
2 |
3 | `dm_env_rpc` is a protocol for [agents](glossary.md#agent)
4 | ([clients](glossary.md#client)) to communicate with
5 | [environments](glossary.md#environment) ([servers](glossary.md#server)). A
6 | server has a single remote procedure call (RPC) named `Process` for handling
7 | messages from clients.
8 |
9 | ```protobuf
10 | service Environment {
11 | // Process incoming environment requests.
12 | rpc Process(stream EnvironmentRequest) returns (stream EnvironmentResponse) {}
13 | }
14 | ```
15 |
16 | Each call to `Process` creates a bidirectional streaming connection between a
17 | client and the server. It is up to each server to decide if it can support
18 | multiple simultaneous clients, if each client can instantiate its own
19 | [world](glossary.md#world), or if each client is expected to connect to the same
20 | underlying world.
21 |
22 | Each connection accepts a stream of `EnvironmentRequest` messages from the
23 | client. The server sends exactly one `EnvironmentResponse` for each request. The
24 | payload of the response always either corresponds to that of the request (e.g. a
25 | `StepResponse` in response to a `StepRequest`) or is an error `Status` message.
26 |
27 | ### Streaming
28 |
29 | Clients may send multiple requests without waiting for responses. The server
30 | processes these requests in the order that they are sent and returns the
31 | corresponding responses in the same order. It is the client's responsibility to
32 | ensure each request is valid when processed.
33 |
34 | Note: gRPC guarantees messages will be delivered in the order they are sent.
35 |
36 | ### States
37 |
38 | An Environment connection can be in one of two states: joined to a world or not.
39 | When not joined to a world, `StepRequest` and `ResetRequest` calls are
40 | unavailable (the server will send an error upon receiving them).
41 |
42 | A joined connection may be in a variety of sub-states (i.e. RUNNING, TERMINATED,
43 | and INTERRUPTED). Agents transition between these states using `StepRequest`,
44 | `ResetRequest` and `ResetWorldRequest` calls, though the environment controls
45 | which state is transitioned to.
46 |
47 | 
48 |
49 | ### Tensors
50 |
51 | For the purposes of this protocol tensors are loosely based on NumPy arrays:
52 | n-dimensional arrays of data with the same data type. A tensor with "n"
53 | dimensions can be referred to as an n-tensor. A 0-tensor is just a scalar value,
54 | such as a single float. A 1-tensor can be thought of as a one-dimensional
55 | array or vector. A 2-tensor is a two-dimensional array, or a matrix.
56 | In principle there's no limit to the number of dimensions a tensor can have in
57 | the protocol, but in practice we rarely have more than 3 or 4 dimensions.
58 | Tensors are not allowed to be
59 | [ragged](https://en.wikipedia.org/wiki/Jagged_array) (have rows with different
60 | numbers of elements), though they may have a
61 | [variable length](#variable-lengths) along one dimension.
62 |
63 | A tensor's shape represents the number of elements along each dimension. A
64 | 2-tensor with a shape of `[3, 4]` would be a 2 dimensional array with 3 rows and
65 | 4 columns.
66 |
67 | In order to pack these tensors in a way that can be sent over the network they
68 | have to be flattened to a one dimensional array. For multidimensional tensors
69 | it's expected that they will be packed in a row-major format. That is, the
70 | values at indices `[2, 3, 4]` and `[2, 3, 5]` are located next to each other in
71 | the flattened array. This is the default memory layout in C-based languages
72 | such as C/C++, Java, and C#, as well as in NumPy and TensorFlow, but is the
73 | opposite of column-major languages such as Fortran.
74 |
75 | Consult the
76 | [Row- and column-major order](https://en.wikipedia.org/wiki/Row-_and_column-major_order)
77 | article for more information.
78 |
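As a small NumPy illustration of the row-major packing described above (not part of the protocol itself):

```python
import numpy as np

tensor = np.arange(12).reshape(3, 4)  # a 2-tensor with shape [3, 4]
flattened = tensor.ravel(order='C')   # row-major packing: 0, 1, 2, ..., 11
restored = flattened.reshape(3, 4)    # the receiver reverses the flattening
assert np.array_equal(restored, tensor)
```
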
79 | #### Variable lengths
80 |
81 | Normally a tensor has a well defined shape. However, if one of the elements in a
82 | Tensor's shape is negative it represents a variable dimension. Either the client
83 | or the server, upon receiving a Tensor message with a Shape with a negative
84 | element, will attempt to infer the correct value for the shape based on the
85 | number of elements in the Tensor's array part and the rest of the shape.
86 |
87 | Note: even though this dimension has variable length, the tensor itself is still
88 | not ragged. The variable dimension has a definite length that can be inferred.
89 |
90 | For instance, a Tensor with shape `[2, -1]` represents a variable length
91 | 2-tensor with two rows and a variable number of columns. If this Tensor's array
92 | part contains 6 elements `[1, 2, 3, 4, 5, 6]` then the final produced 2-tensor
93 | will look like:
94 |
95 | 
96 |
97 | Variable length tensors are useful for situations where a given observation or
98 | action's length is unknowable from frame to frame. For instance, the number of
99 | words in a sentence or the number of stochastic events in a given time frame.
100 |
101 | Note: At most one dimension is allowed to be variable on a given tensor.
102 |
103 | Note: servers should provide actions and observations with non-variable length
104 | if possible, as it can reduce the complexity of agent implementations.
105 |
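A receiver might infer the variable dimension along the lines of the sketch below, assuming exactly one negative dimension as required above (illustrative only, not the library's implementation):

```python
import numpy as np


def infer_shape(shape, num_elements):
  """Replaces the single negative dimension with its inferred length."""
  known = int(np.prod([dim for dim in shape if dim >= 0]))
  return [num_elements // known if dim < 0 else dim for dim in shape]

assert infer_shape([2, -1], 6) == [2, 3]  # matches the example above
```
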
106 | #### Broadcastable
107 |
108 | If a tensor contains all elements of the same value, it is "broadcastable" and
109 | can be represented with a single value in the array part of the Tensor, even if
110 | the shape requires more elements. Either the client or server, upon receiving a
111 | broadcastable tensor, will unpack it to an appropriately sized multidimensional
112 | array with each element being set to the value from the lone element in the
113 | array part of the Tensor.
114 |
115 | For instance, a Tensor with Shape `[2, 2]` and a single element `1` in its array
116 | part will produce a 2-tensor that looks like:
117 |
118 | 
119 |
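In NumPy terms, unpacking such a broadcastable tensor amounts to something like the following (illustrative only):

```python
import numpy as np

# A Tensor with shape [2, 2] whose array part holds the single element 1
# unpacks to a 2x2 array of ones:
unpacked = np.full((2, 2), 1)
```
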
120 | #### TensorSpec
121 |
122 | A `TensorSpec` provides metadata about a tensor, such as its name, type, and
123 | expected shape. Tensor names must be unique within a given domain (action or
124 | observation) so clients can use them as keys. A period "." character in a Tensor
125 | name indicates a level of nesting. See
126 | [Nested actions/observations](#nested-actions-or-observations) for more details.
127 |
128 | #### Ranges
129 |
130 | Numerical tensors can have min and max ranges on the TensorSpec. These ranges
131 | are inclusive, and can be either:
132 |
133 | * Scalar: Indicates all `Tensor` elements must adhere to this `min/max` value.
134 | * N-dimensional: Must match the `TensorSpec` shape, where each `Tensor`
135 | element has a distinct `min/max` value.
136 |
137 | For a `TensorSpec` with a shape of [variable-length](#variable-lengths), only
138 | scalar ranges are supported.
139 |
140 | Whilst distinct, per-element `min/max` ranges are supported, we encourage
141 | implementers to instead provide separate actions. For more discussion, see the
142 | [per-element min/max range](appendix.md#per-element-ranges) appendix.
143 |
144 | Note: Range is not enforced by the protocol. Servers and clients should be
145 | careful to validate any incoming or outgoing tensors to make sure they are in
146 | range. Servers should return an error for any out of range tensors from the
147 | client.
148 |
149 | ### UIDs
150 |
151 | Unique identifiers (UIDs) are 64-bit numbers used as keys for data that is
152 | sent over the wire frequently, specifically observations and actions. This
153 | reduces the amount of data compared to string keys, since a key is needed for
154 | each action and observation every step. For data that is not intended to be
155 | referenced frequently, such as create and join settings, string keys are used
156 | for clarity.
157 |
158 | For more information on UIDs, see [JoinWorld specs](reference.md#specs).
159 |
160 | ### Errors
161 |
162 | If an `EnvironmentRequest` fails for any reason, the payload will contain an
163 | error `Status` instead of the normal response message. It’s up to the server
164 | implementation to decide what error codes and messages to use. For fatal errors,
165 | the server can close the stream after sending the error. For recoverable errors
166 | the server can treat the failed request as a no-op and clients can retry.
167 |
168 | The client cannot send errors to the server. If the client has an error that it
169 | can’t recover from, it should just close the connection (gracefully, if
170 | possible).
171 |
172 | Since a server can *only* send messages in response to a given
173 | `EnvironmentRequest`, the errors should ideally be focused on problems from a
174 | specific client request. More general issues or warnings from a given server
175 | implementation should be logged through a separate mechanism.
176 |
177 | If a server implementation needs to report an error, it should send as much
178 | detail about the nature of the problem as possible and any likely remedies. A
179 | client may have difficulties debugging a server, perhaps because the server is
180 | running on a different machine, so the server should send enough information to
181 | properly diagnose the problem. Any additional relevant information about the
182 | error that would normally be logged by the server should also be included in the
183 | error sent to the client.
184 |
185 | ### Nested actions or observations
186 |
187 | Nested actions or observations are not directly supported by the protocol,
188 | however there are two ways they can be handled:
189 |
190 | 1. Flattening the hierarchy, using a period "." character to indicate a level
191 | of nesting. Servers can flatten the nested structure to push through the
192 | wire and the client can reconstruct the nested structure on its side.
193 | Servers need to be careful not to use "." as part of the tensor's name,
194 | except to indicate a level of nesting.
195 |
196 | 2. Defining a custom proto message type, or using the proto common type Struct,
197 | and setting it as the payload in a Tensor message’s array field.
198 |
199 | Flattening the hierarchy is easier for clients to consume, but can involve a
200 | great deal of work on the server. A custom proto message is more flexible but
201 | means every client needs to compile the custom protobuf for their desired
202 | language.
203 |
204 | Nested data structures occur commonly with object-oriented codebases, such as
205 | from an existing game, and flattening them can be difficult. For an in-depth
206 | discussion see the
207 | [nested observations example](appendix.md#nested-observations-example).
208 |
209 | ### Reward and discount
210 |
211 | `dm_env_rpc` does not provide explicit channels for reward or discount (common
212 | reinforcement learning signals). For servers where there's a sensible reward or
213 | discount already available they can be provided through a `reward` or `discount`
214 | observation respectively. For `dm_env`, the provided `DmEnvAdaptor` will
215 | properly route the reward and discount for client code if available.
216 |
217 | A server may choose not to provide reward and discount observations, however.
218 | See [reward functions](appendix.md#reward-functions) for a discussion on the
219 | pitfalls of reward design.
220 |
221 | ### Multiagent support
222 |
223 | Some servers may support multiple joined connections on the same world. These
224 | multiagent servers are responsible for coordinating how agents interact through
225 | the world, and ensuring each connection has a separate environment for each
226 | agent.
227 |
--------------------------------------------------------------------------------
/docs/v1/single_agent_connect_and_step.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/dm_env_rpc/dea4326e7e4e15b3abd787a5907fa12b3c6940e2/docs/v1/single_agent_connect_and_step.png
--------------------------------------------------------------------------------
/docs/v1/single_agent_sequence_transitions.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/dm_env_rpc/dea4326e7e4e15b3abd787a5907fa12b3c6940e2/docs/v1/single_agent_sequence_transitions.png
--------------------------------------------------------------------------------
/docs/v1/single_agent_world_destruction.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/dm_env_rpc/dea4326e7e4e15b3abd787a5907fa12b3c6940e2/docs/v1/single_agent_world_destruction.png
--------------------------------------------------------------------------------
/docs/v1/state_transitions.graphviz:
--------------------------------------------------------------------------------
1 | # Source to generate the state_transitions.png diagram using Graphviz.
2 |
3 | digraph {
4 | node [shape=record];
5 | compound=true;
6 | "" [shape=diamond]
7 | "" -> TERMINATED [label = "JoinWorld "]
8 | subgraph cluster_world_states {
9 | rank="same";
10 | TERMINATED -> RUNNING [label = "Step", color=grey, dir=both];
11 | RUNNING -> INTERRUPTED [label = "Step ", color=grey, dir=both];
12 | RUNNING -> INTERRUPTED [label = "Reset", color=grey];
13 | RUNNING -> RUNNING [label = " Step", color=grey];
14 | TERMINATED -> TERMINATED [label = "Reset", color=grey];
15 | INTERRUPTED -> INTERRUPTED [label = "Reset", color=grey];
16 | }
17 |
18 | TERMINATED -> "" [label = "LeaveWorld", ltail=cluster_world_states];
19 | }
20 |
--------------------------------------------------------------------------------
/docs/v1/state_transitions.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/dm_env_rpc/dea4326e7e4e15b3abd787a5907fa12b3c6940e2/docs/v1/state_transitions.png
--------------------------------------------------------------------------------
/examples/catch_human_agent.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Example Catch human agent."""
16 |
17 | from concurrent import futures
18 |
19 | from absl import app
20 | import grpc
21 | import pygame
22 |
23 | import catch_environment
24 | from dm_env_rpc.v1 import connection as dm_env_rpc_connection
25 | from dm_env_rpc.v1 import dm_env_adaptor
26 | from dm_env_rpc.v1 import dm_env_rpc_pb2
27 | from dm_env_rpc.v1 import dm_env_rpc_pb2_grpc
28 |
29 | _FRAMES_PER_SEC = 3
30 | _FRAME_DELAY_MS = int(1000.0 // _FRAMES_PER_SEC)
31 |
32 | _BLACK = (0, 0, 0)
33 | _WHITE = (255, 255, 255)
34 |
35 | _ACTION_LEFT = -1
36 | _ACTION_NOTHING = 0
37 | _ACTION_RIGHT = 1
38 |
39 | _ACTION_PADDLE = 'paddle'
40 | _OBSERVATION_REWARD = 'reward'
41 | _OBSERVATION_BOARD = 'board'
42 |
43 |
44 | def _draw_row(row_str, row_index, standard_font, window_surface):
45 | text = standard_font.render(row_str, True, _WHITE)
46 | text_rect = text.get_rect()
47 | text_rect.left = 50
48 | text_rect.top = 30 + (row_index * 30)
49 | window_surface.blit(text, text_rect)
50 |
51 |
52 | def _render_window(board, window_surface, reward):
53 | """Render the game onto the window surface."""
54 |
55 | standard_font = pygame.font.SysFont('Courier', 24)
56 | instructions_font = pygame.font.SysFont('Courier', 16)
57 |
58 | num_rows = board.shape[0]
59 | num_cols = board.shape[1]
60 |
61 | window_surface.fill(_BLACK)
62 |
63 | # Draw board.
64 | header = '* ' * (num_cols + 2)
65 | _draw_row(header, 0, standard_font, window_surface)
66 | for board_index in range(num_rows):
67 | row = board[board_index]
68 | row_str = '* '
69 | for c in row:
70 | row_str += 'x ' if c == 1. else ' '
71 | row_str += '* '
72 | _draw_row(row_str, board_index + 1, standard_font, window_surface)
73 | _draw_row(header, num_rows + 1, standard_font, window_surface)
74 |
75 | # Draw footer.
76 | reward_str = 'Reward: {}'.format(reward)
77 | _draw_row(reward_str, num_rows + 3, standard_font, window_surface)
78 | instructions = ('Instructions: Left/Right arrow keys to move paddle, Escape '
79 | 'to exit.')
80 | _draw_row(instructions, num_rows + 5, instructions_font, window_surface)
81 |
82 |
83 | def _start_server():
84 | """Starts the Catch gRPC server."""
85 | server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
86 | servicer = catch_environment.CatchEnvironmentService()
87 | dm_env_rpc_pb2_grpc.add_EnvironmentServicer_to_server(servicer, server)
88 |
89 | port = server.add_secure_port('localhost:0', grpc.local_server_credentials())
90 | server.start()
91 | return server, port
92 |
93 |
94 | def main(_):
95 | pygame.init()
96 |
97 | server, port = _start_server()
98 |
99 | with dm_env_rpc_connection.create_secure_channel_and_connect(
100 | f'localhost:{port}') as connection:
101 | env, world_name = dm_env_adaptor.create_and_join_world(
102 | connection, create_world_settings={}, join_world_settings={})
103 | with env:
104 | window_surface = pygame.display.set_mode((800, 600), 0, 32)
105 | pygame.display.set_caption('Catch Human Agent')
106 |
107 | keep_running = True
108 | while keep_running:
109 | requested_action = _ACTION_NOTHING
110 |
111 | for event in pygame.event.get():
112 | if event.type == pygame.QUIT:
113 | keep_running = False
114 | break
115 | elif event.type == pygame.KEYDOWN:
116 | if event.key == pygame.K_LEFT:
117 | requested_action = _ACTION_LEFT
118 | elif event.key == pygame.K_RIGHT:
119 | requested_action = _ACTION_RIGHT
120 | elif event.key == pygame.K_ESCAPE:
121 | keep_running = False
122 | break
123 |
124 | actions = {_ACTION_PADDLE: requested_action}
125 | timestep = env.step(actions)
126 |
127 | board = timestep.observation[_OBSERVATION_BOARD]
128 | reward = timestep.reward
129 |
130 | _render_window(board, window_surface, reward)
131 |
132 | pygame.display.update()
133 |
134 | pygame.time.wait(_FRAME_DELAY_MS)
135 |
136 | connection.send(dm_env_rpc_pb2.DestroyWorldRequest(world_name=world_name))
137 |
138 | server.stop(None)
139 |
140 |
141 | if __name__ == '__main__':
142 | app.run(main)
143 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Requirements necessary to run dm_env_rpc tests.
2 | # Please install other requirements via pip install.
3 |
4 | absl-py==2.1.0
5 | pytest==8.2.2
6 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Install script for setuptools."""
16 |
17 | import importlib.util
18 | import os
19 |
20 | from setuptools import Command
21 | from setuptools import find_packages
22 | from setuptools import setup
23 | from setuptools.command.build_ext import build_ext
24 | from setuptools.command.build_py import build_py
25 |
26 |
27 | _ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
28 | _GOOGLE_COMMON_PROTOS_ROOT_DIR = os.path.join(
29 | _ROOT_DIR, 'third_party/api-common-protos'
30 | )
31 |
32 | # Tuple of proto message definitions to build Python bindings for. Paths must
33 | # be relative to root directory.
34 | _DM_ENV_RPC_PROTOS = (
35 | 'dm_env_rpc/v1/dm_env_rpc.proto',
36 | 'dm_env_rpc/v1/extensions/properties.proto',
37 | )
38 |
39 |
40 | class _GenerateProtoFiles(Command):
41 | """Command to generate protobuf bindings for dm_env_rpc.proto."""
42 |
43 |   description = 'Generates Python protobuf bindings for dm_env_rpc.proto.'
44 | user_options = []
45 |
46 | def initialize_options(self):
47 | pass
48 |
49 | def finalize_options(self):
50 | pass
51 |
52 | def run(self):
53 | # Import grpc_tools and importlib_resources here, after setuptools has
54 | # installed setup_requires dependencies.
55 | from grpc_tools import protoc # pylint: disable=g-import-not-at-top
56 | import importlib_resources # pylint: disable=g-import-not-at-top
57 |
58 | if not os.path.exists(
59 | os.path.join(_GOOGLE_COMMON_PROTOS_ROOT_DIR, 'google/rpc/status.proto')
60 | ):
61 | raise RuntimeError(
62 | 'Cannot find third_party/api-common-protos. '
63 | 'Please run `git submodule init && git submodule update` to install '
64 | 'the api-common-protos submodule.'
65 | )
66 |
67 | with importlib_resources.as_file(
68 | importlib_resources.files('grpc_tools') / '_proto'
69 | ) as grpc_protos_include:
70 | for proto_path in _DM_ENV_RPC_PROTOS:
71 | proto_args = [
72 | 'grpc_tools.protoc',
73 | '--proto_path={}'.format(_GOOGLE_COMMON_PROTOS_ROOT_DIR),
74 | '--proto_path={}'.format(grpc_protos_include),
75 | '--proto_path={}'.format(_ROOT_DIR),
76 | '--python_out={}'.format(_ROOT_DIR),
77 | '--grpc_python_out={}'.format(_ROOT_DIR),
78 | os.path.join(_ROOT_DIR, proto_path),
79 | ]
80 | if protoc.main(proto_args) != 0:
81 | raise RuntimeError('ERROR: {}'.format(proto_args))
82 |
83 |
84 | class _BuildExt(build_ext):
85 | """Generate protobuf bindings in build_ext stage."""
86 |
87 | def run(self):
88 | self.run_command('generate_protos')
89 | build_ext.run(self)
90 |
91 |
92 | class _BuildPy(build_py):
93 | """Generate protobuf bindings in build_py stage."""
94 |
95 | def run(self):
96 | self.run_command('generate_protos')
97 | build_py.run(self)
98 |
99 |
100 | def _load_version():
101 | """Load dm_env_rpc version."""
102 | spec = importlib.util.spec_from_file_location(
103 | '_version', 'dm_env_rpc/_version.py'
104 | )
105 | version_module = importlib.util.module_from_spec(spec)
106 | spec.loader.exec_module(version_module)
107 | return version_module.__version__
108 |
109 |
110 | setup(
111 | name='dm-env-rpc',
112 | version=_load_version(),
113 | description='A networking protocol for agent-environment communication.',
114 | author='DeepMind',
115 | license='Apache License, Version 2.0',
116 | keywords='reinforcement-learning python machine learning',
117 | packages=find_packages(exclude=['examples']),
118 | install_requires=[
119 | 'dm-env>=1.2',
120 | 'immutabledict',
121 | 'googleapis-common-protos',
122 | 'grpcio',
123 | 'numpy<2.0',
124 | 'protobuf>=3.8',
125 | ],
126 | python_requires='>=3.8',
127 | setup_requires=['grpcio-tools', 'importlib_resources'],
128 | extras_require={
129 | 'examples': ['pygame'],
130 | },
131 | cmdclass={
132 | 'build_ext': _BuildExt,
133 | 'build_py': _BuildPy,
134 | 'generate_protos': _GenerateProtoFiles,
135 | },
136 | classifiers=[
137 | 'Development Status :: 5 - Production/Stable',
138 | 'Environment :: Console',
139 | 'Intended Audience :: Science/Research',
140 | 'License :: OSI Approved :: Apache Software License',
141 | 'Operating System :: POSIX :: Linux',
142 | 'Operating System :: Microsoft :: Windows',
143 | 'Operating System :: MacOS :: MacOS X',
144 | 'Programming Language :: Python :: 3.8',
145 | 'Programming Language :: Python :: 3.9',
146 | 'Programming Language :: Python :: 3.10',
147 | 'Programming Language :: Python :: 3.11',
148 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
149 | ],
150 | )
151 |
--------------------------------------------------------------------------------