├── .gitignore ├── CHANGELOG.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── conftest.py ├── dm_env ├── __init__.py ├── _abstract_test_mixin.py ├── _environment.py ├── _environment_test.py ├── _metadata.py ├── auto_reset_environment.py ├── auto_reset_environment_test.py ├── specs.py ├── specs_test.py ├── test_utils.py └── test_utils_test.py ├── docs └── index.md ├── examples ├── __init__.py ├── catch.py └── catch_test.py ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore byte-compiled Python code 2 | *.py[cod] 3 | 4 | # Ignore directories created during the build/installation process 5 | *.egg-info/ 6 | build/ 7 | dist/ 8 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All significant changes to this project will be documented here. 4 | 5 | ## [Unreleased] 6 | 7 | ## [1.6] 8 | 9 | Release date: 2022-12-21 10 | 11 | * Add support for Python 3.10. 12 | 13 | ## [1.5] 14 | 15 | Release date: 2021-07-12 16 | 17 | * Add an `AutoResetEnvironment` that calls `reset` for you if the previous 18 | step was the last step. 19 | 20 | ## [1.4] 21 | 22 | Release date: 2021-02-12 23 | 24 | * Dropped support for Python versions < 3.6. 25 | 26 | ## [1.3] 27 | 28 | Release date: 2020-10-30 29 | 30 | * Added a `StringArray` spec subclass for representing arrays of variable- 31 | length strings. 32 | * Added a check to enforce that `minimum <= maximum` for `BoundedArray`. 33 | 34 | ## [1.2] 35 | 36 | Release date: 2019-11-12 37 | 38 | ### Changed 39 | 40 | * `test_utils.EnvironmentTestMixin` can now be used to validate 41 | implementations of `dm_env.Environment` where actions, observations, rewards 42 | and/or discounts are arbitrary nested structures containing numpy arrays or 43 | scalars. 
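As an illustration of the auto-reset behaviour listed under [1.5] above, here is a minimal, self-contained sketch. It deliberately does not import `dm_env`: `StepType` and `TimeStep` are simplified stand-ins for the library's types, and `CountdownEnv` and `AutoResetSketch` are hypothetical names invented for this example. It shows one way such a wrapper could behave, not the library's actual `AutoResetEnvironment` implementation.

```python
import enum
from typing import Any, NamedTuple


class StepType(enum.IntEnum):
    """Simplified stand-in for dm_env.StepType."""
    FIRST = 0
    MID = 1
    LAST = 2


class TimeStep(NamedTuple):
    """Simplified stand-in for dm_env.TimeStep."""
    step_type: Any
    reward: Any
    discount: Any
    observation: Any

    def last(self) -> bool:
        return self.step_type == StepType.LAST


class CountdownEnv:
    """Hypothetical toy environment: an episode ends after `length` steps."""

    def __init__(self, length=2):
        self._length = length
        self._remaining = length

    def reset(self):
        self._remaining = self._length
        return TimeStep(StepType.FIRST, None, None, self._remaining)

    def step(self, action):
        del action  # The toy dynamics ignore the action.
        self._remaining -= 1
        if self._remaining == 0:
            return TimeStep(StepType.LAST, 1.0, 0.0, self._remaining)
        return TimeStep(StepType.MID, 0.0, 1.0, self._remaining)


class AutoResetSketch:
    """Hypothetical wrapper: resets the wrapped env after a LAST step."""

    def __init__(self, env):
        self._env = env
        self._needs_reset = True  # No episode has been started yet.

    def reset(self):
        self._needs_reset = False
        return self._env.reset()

    def step(self, action):
        if self._needs_reset:
            return self.reset()  # The action is ignored in this case.
        time_step = self._env.step(action)
        self._needs_reset = time_step.last()
        return time_step


env = AutoResetSketch(CountdownEnv(length=2))
step_types = [env.step(0).step_type.name for _ in range(5)]
print(step_types)  # ['FIRST', 'MID', 'LAST', 'FIRST', 'MID']
```

With this wrapper the caller never calls `reset` explicitly. The first `step` and any `step` after a `LAST` time step start a new sequence, and the action passed in that call is ignored.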
44 | 45 | ## [1.1] 46 | 47 | Release date: 2019-08-12 48 | 49 | ### Added 50 | 51 | * Specs now have a `replace` method that can be used to create a new instance 52 | with some of the attributes replaced (similar to `namedtuple._replace`). 53 | 54 | ### Changed 55 | 56 | * The `BoundedArray` constructor now casts `minimum` and `maximum` so that 57 | their dtypes match that of the spec instance. 58 | 59 | ## [1.0] 60 | 61 | Release date: 2019-07-18 62 | 63 | * Initial release. 64 | 65 | [Unreleased]: https://github.com/deepmind/dm_env/compare/v1.6...HEAD 66 | [1.6]: https://github.com/deepmind/dm_env/compare/v1.5...v1.6 67 | [1.5]: https://github.com/deepmind/dm_env/compare/v1.4...v1.5 68 | [1.4]: https://github.com/deepmind/dm_env/compare/v1.3...v1.4 69 | [1.3]: https://github.com/deepmind/dm_env/compare/v1.2...v1.3 70 | [1.2]: https://github.com/deepmind/dm_env/compare/v1.1...v1.2 71 | [1.1]: https://github.com/deepmind/dm_env/compare/v1.0...v1.1 72 | [1.0]: https://github.com/deepmind/dm_env/releases/tag/v1.0 73 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | This library is widely used in our research code, so we are unlikely to be able 4 | to accept changes to the interface. We can more easily accept patches and 5 | contributions related to documentation. There are just a few small guidelines 6 | you need to follow. 7 | 8 | ## Contributor License Agreement 9 | 10 | Contributions to this project must be accompanied by a Contributor License 11 | Agreement. You (or your employer) retain the copyright to your contribution, 12 | this simply gives us permission to use and redistribute your contributions as 13 | part of the project. Head over to to see 14 | your current agreements on file or to sign a new one. 
15 | 16 | You generally only need to submit a CLA once, so if you've already submitted one 17 | (even if it was for a different project), you probably don't need to do it 18 | again. 19 | 20 | ## Code reviews 21 | 22 | All submissions, including submissions by project members, require review. We 23 | use GitHub pull requests for this purpose. Consult 24 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 25 | information on using pull requests. 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 
30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 
62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # `dm_env`: The DeepMind RL Environment API 2 | 3 | ![PyPI Python version](https://img.shields.io/pypi/pyversions/dm-env) 4 | ![PyPI version](https://badge.fury.io/py/dm-env.svg) 5 | 6 | This package describes an interface for Python reinforcement learning (RL) 7 | environments. It consists of the following core components: 8 | 9 | * `dm_env.Environment`: An abstract base class for RL environments. 10 | * `dm_env.TimeStep`: A container class representing the outputs of the 11 | environment on each time step (transition). 12 | * `dm_env.specs`: A module containing primitives that are used to describe the 13 | format of the actions consumed by an environment, as well as the 14 | observations, rewards, and discounts it returns. 
15 | * `dm_env.test_utils`: Tools for testing whether concrete environment 16 | implementations conform to the `dm_env.Environment` interface. 17 | 18 | Please see the documentation [here][api_docs] for more information about the 19 | semantics of the environment interface and how to use it. The [examples] 20 | subdirectory also contains illustrative examples of RL environments implemented 21 | using the `dm_env` interface. 22 | 23 | ## Installation 24 | 25 | `dm_env` can be installed from PyPI using `pip`: 26 | 27 | ```bash 28 | pip install dm-env 29 | ``` 30 | 31 | Note that from version 1.4 onwards, we support Python 3.6+ only. 32 | 33 | You can also install it directly from our GitHub repository using `pip`: 34 | 35 | ```bash 36 | pip install git+https://github.com/deepmind/dm_env.git 37 | ``` 38 | 39 | or alternatively by checking out a local copy of our repository and running: 40 | 41 | ```bash 42 | pip install /path/to/local/dm_env/ 43 | ``` 44 | 45 | [api_docs]: docs/index.md 46 | [examples]: examples/ 47 | 48 | ## Citing 49 | 50 | To cite this repository: 51 | 52 | ```bibtex 53 | @misc{dm_env2019, 54 | author={Alistair Muldal and 55 | Yotam Doron and 56 | John Aslanides and 57 | Tim Harley and 58 | Tom Ward and 59 | Siqi Liu}, 60 | title={dm\_env: A Python interface for reinforcement learning environments}, 61 | year={2019}, 62 | url={http://github.com/deepmind/dm_env} 63 | } 64 | ``` 65 | -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The dm_env Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """pytest configuration for dm_env.""" 16 | 17 | collect_ignore = [ 18 | 'conftest.py', 19 | 'setup.py', 20 | ] 21 | -------------------------------------------------------------------------------- /dm_env/__init__.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 The dm_env Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """A Python interface for reinforcement learning environments.""" 17 | 18 | from dm_env import _environment 19 | from dm_env._metadata import __version__ 20 | 21 | Environment = _environment.Environment 22 | StepType = _environment.StepType 23 | TimeStep = _environment.TimeStep 24 | 25 | # Helper functions for creating TimeStep namedtuples with default settings. 
26 | restart = _environment.restart 27 | termination = _environment.termination 28 | transition = _environment.transition 29 | truncation = _environment.truncation 30 | -------------------------------------------------------------------------------- /dm_env/_abstract_test_mixin.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 The dm_env Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """Base class for TestMixin classes.""" 17 | 18 | 19 | class TestMixin: 20 | """Base class for TestMixins. 21 | 22 | Subclasses must override `make_object_under_test`. 23 | """ 24 | 25 | def setUp(self): 26 | # A call to super is required for cooperative multiple inheritance to work. 27 | super().setUp() 28 | test_method = getattr(self, self._testMethodName) 29 | make_obj_kwargs = getattr(test_method, "_make_obj_kwargs", {}) 30 | self.object_under_test = self.make_object_under_test(**make_obj_kwargs) 31 | 32 | def make_object_under_test(self, **unused_kwargs): 33 | raise NotImplementedError( 34 | "Attempt to run tests from an abstract TestMixin subclass %s. " 35 | "Perhaps you forgot to override make_object_under_test?" 
% type(self)) 36 | -------------------------------------------------------------------------------- /dm_env/_environment.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 The dm_env Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """Python RL Environment API.""" 17 | 18 | import abc 19 | import enum 20 | from typing import Any, NamedTuple 21 | 22 | from dm_env import specs 23 | 24 | 25 | class TimeStep(NamedTuple): 26 | """Returned with every call to `step` and `reset` on an environment. 27 | 28 | A `TimeStep` contains the data emitted by an environment at each step of 29 | interaction. A `TimeStep` holds a `step_type`, an `observation` (typically a 30 | NumPy array or a dict or list of arrays), and an associated `reward` and 31 | `discount`. 32 | 33 | The first `TimeStep` in a sequence will have `StepType.FIRST`. The final 34 | `TimeStep` will have `StepType.LAST`. All other `TimeStep`s in a sequence will 35 | have `StepType.MID`. 36 | 37 | Attributes: 38 | step_type: A `StepType` enum value. 39 | reward: A scalar, NumPy array, nested dict, list or tuple of rewards; or 40 | `None` if `step_type` is `StepType.FIRST`, i.e. at the start of a 41 | sequence.
42 | discount: A scalar, NumPy array, nested dict, list or tuple of discount 43 | values in the range `[0, 1]`, or `None` if `step_type` is 44 | `StepType.FIRST`, i.e. at the start of a sequence. 45 | observation: A NumPy array, or a nested dict, list or tuple of arrays. 46 | Scalar values that can be cast to NumPy arrays (e.g. Python floats) are 47 | also valid in place of a scalar array. 48 | """ 49 | 50 | # TODO(b/143116886): Use generics here when PyType supports them. 51 | step_type: Any 52 | reward: Any 53 | discount: Any 54 | observation: Any 55 | 56 | def first(self) -> bool: 57 | return self.step_type == StepType.FIRST 58 | 59 | def mid(self) -> bool: 60 | return self.step_type == StepType.MID 61 | 62 | def last(self) -> bool: 63 | return self.step_type == StepType.LAST 64 | 65 | 66 | class StepType(enum.IntEnum): 67 | """Defines the status of a `TimeStep` within a sequence.""" 68 | # Denotes the first `TimeStep` in a sequence. 69 | FIRST = 0 70 | # Denotes any `TimeStep` in a sequence that is not FIRST or LAST. 71 | MID = 1 72 | # Denotes the last `TimeStep` in a sequence. 73 | LAST = 2 74 | 75 | def first(self) -> bool: 76 | return self is StepType.FIRST 77 | 78 | def mid(self) -> bool: 79 | return self is StepType.MID 80 | 81 | def last(self) -> bool: 82 | return self is StepType.LAST 83 | 84 | 85 | class Environment(metaclass=abc.ABCMeta): 86 | """Abstract base class for Python RL environments. 87 | 88 | Observations and valid actions are described with `Array` specs, defined in 89 | the `specs` module. 90 | """ 91 | 92 | @abc.abstractmethod 93 | def reset(self) -> TimeStep: 94 | """Starts a new sequence and returns the first `TimeStep` of this sequence. 95 | 96 | Returns: 97 | A `TimeStep` namedtuple containing: 98 | step_type: A `StepType` of `FIRST`. 99 | reward: `None`, indicating the reward is undefined. 100 | discount: `None`, indicating the discount is undefined. 101 | observation: A NumPy array, or a nested dict, list or tuple of arrays. 
102 | Scalar values that can be cast to NumPy arrays (e.g. Python floats) 103 | are also valid in place of a scalar array. Must conform to the 104 | specification returned by `observation_spec()`. 105 | """ 106 | 107 | @abc.abstractmethod 108 | def step(self, action) -> TimeStep: 109 | """Updates the environment according to the action and returns a `TimeStep`. 110 | 111 | If the environment returned a `TimeStep` with `StepType.LAST` at the 112 | previous step, this call to `step` will start a new sequence and `action` 113 | will be ignored. 114 | 115 | This method will also start a new sequence if called after the environment 116 | has been constructed and `reset` has not been called. Again, in this case 117 | `action` will be ignored. 118 | 119 | Args: 120 | action: A NumPy array, or a nested dict, list or tuple of arrays 121 | corresponding to `action_spec()`. 122 | 123 | Returns: 124 | A `TimeStep` namedtuple containing: 125 | step_type: A `StepType` value. 126 | reward: Reward at this timestep, or None if step_type is 127 | `StepType.FIRST`. Must conform to the specification returned by 128 | `reward_spec()`. 129 | discount: A discount in the range [0, 1], or None if step_type is 130 | `StepType.FIRST`. Must conform to the specification returned by 131 | `discount_spec()`. 132 | observation: A NumPy array, or a nested dict, list or tuple of arrays. 133 | Scalar values that can be cast to NumPy arrays (e.g. Python floats) 134 | are also valid in place of a scalar array. Must conform to the 135 | specification returned by `observation_spec()`. 136 | """ 137 | 138 | def reward_spec(self): 139 | """Describes the reward returned by the environment. 140 | 141 | By default this is assumed to be a single float. 142 | 143 | Returns: 144 | An `Array` spec, or a nested dict, list or tuple of `Array` specs. 
145 | """ 146 | return specs.Array(shape=(), dtype=float, name='reward') 147 | 148 | def discount_spec(self): 149 | """Describes the discount returned by the environment. 150 | 151 | By default this is assumed to be a single float between 0 and 1. 152 | 153 | Returns: 154 | An `Array` spec, or a nested dict, list or tuple of `Array` specs. 155 | """ 156 | return specs.BoundedArray( 157 | shape=(), dtype=float, minimum=0., maximum=1., name='discount') 158 | 159 | @abc.abstractmethod 160 | def observation_spec(self): 161 | """Defines the observations provided by the environment. 162 | 163 | May use a subclass of `specs.Array` that specifies additional properties 164 | such as min and max bounds on the values. 165 | 166 | Returns: 167 | An `Array` spec, or a nested dict, list or tuple of `Array` specs. 168 | """ 169 | 170 | @abc.abstractmethod 171 | def action_spec(self): 172 | """Defines the actions that should be provided to `step`. 173 | 174 | May use a subclass of `specs.Array` that specifies additional properties 175 | such as min and max bounds on the values. 176 | 177 | Returns: 178 | An `Array` spec, or a nested dict, list or tuple of `Array` specs. 179 | """ 180 | 181 | def close(self): 182 | """Frees any resources used by the environment. 183 | 184 | Implement this method for an environment backed by an external process. 185 | 186 | This method can be used directly 187 | 188 | ```python 189 | env = Env(...) 190 | # Use env. 191 | env.close() 192 | ``` 193 | 194 | or via a context manager 195 | 196 | ```python 197 | with Env(...) as env: 198 | # Use env. 199 | ``` 200 | """ 201 | pass 202 | 203 | def __enter__(self): 204 | """Allows the environment to be used in a with-statement context.""" 205 | return self 206 | 207 | def __exit__(self, exc_type, exc_value, traceback): 208 | """Allows the environment to be used in a with-statement context.""" 209 | del exc_type, exc_value, traceback # Unused. 
210 | self.close() 211 | 212 | 213 | # Helper functions for creating TimeStep namedtuples with default settings. 214 | 215 | 216 | def restart(observation): 217 | """Returns a `TimeStep` with `step_type` set to `StepType.FIRST`.""" 218 | return TimeStep(StepType.FIRST, None, None, observation) 219 | 220 | 221 | def transition(reward, observation, discount=1.0): 222 | """Returns a `TimeStep` with `step_type` set to `StepType.MID`.""" 223 | return TimeStep(StepType.MID, reward, discount, observation) 224 | 225 | 226 | def termination(reward, observation): 227 | """Returns a `TimeStep` with `step_type` set to `StepType.LAST`.""" 228 | return TimeStep(StepType.LAST, reward, 0.0, observation) 229 | 230 | 231 | def truncation(reward, observation, discount=1.0): 232 | """Returns a `TimeStep` with `step_type` set to `StepType.LAST`.""" 233 | return TimeStep(StepType.LAST, reward, discount, observation) 234 | -------------------------------------------------------------------------------- /dm_env/_environment_test.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 The dm_env Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================
16 | """Tests for dm_env._environment."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | import dm_env
21 |
22 |
23 | class TimeStepHelpersTest(parameterized.TestCase):
24 |
25 |   @parameterized.parameters(dict(observation=-1), dict(observation=[2., 3.]))
26 |   def test_restart(self, observation):
27 |     time_step = dm_env.restart(observation)
28 |     self.assertIs(dm_env.StepType.FIRST, time_step.step_type)
29 |     self.assertEqual(observation, time_step.observation)
30 |     self.assertIsNone(time_step.reward)
31 |     self.assertIsNone(time_step.discount)
32 |
33 |   @parameterized.parameters(
34 |       dict(observation=-1., reward=2.0, discount=1.0),
35 |       dict(observation=(2., 3.), reward=0., discount=0.))
36 |   def test_transition(self, observation, reward, discount):
37 |     time_step = dm_env.transition(
38 |         reward=reward, observation=observation, discount=discount)
39 |     self.assertIs(dm_env.StepType.MID, time_step.step_type)
40 |     self.assertEqual(observation, time_step.observation)
41 |     self.assertEqual(reward, time_step.reward)
42 |     self.assertEqual(discount, time_step.discount)
43 |
44 |   @parameterized.parameters(
45 |       dict(observation=-1., reward=2.0),
46 |       dict(observation=(2., 3.), reward=0.))
47 |   def test_termination(self, observation, reward):
48 |     time_step = dm_env.termination(reward=reward, observation=observation)
49 |     self.assertIs(dm_env.StepType.LAST, time_step.step_type)
50 |     self.assertEqual(observation, time_step.observation)
51 |     self.assertEqual(reward, time_step.reward)
52 |     self.assertEqual(0.0, time_step.discount)
53 |
54 |   @parameterized.parameters(
55 |       dict(observation=-1., reward=2.0, discount=1.0),
56 |       dict(observation=(2., 3.), reward=0., discount=0.))
57 |   def test_truncation(self, reward, observation, discount):
58 |     time_step = dm_env.truncation(reward, observation, discount)
59 |     self.assertIs(dm_env.StepType.LAST,
time_step.step_type)
60 |     self.assertEqual(observation, time_step.observation)
61 |     self.assertEqual(reward, time_step.reward)
62 |     self.assertEqual(discount, time_step.discount)
63 |
64 |   @parameterized.parameters(
65 |       dict(step_type=dm_env.StepType.FIRST,
66 |            is_first=True, is_mid=False, is_last=False),
67 |       dict(step_type=dm_env.StepType.MID,
68 |            is_first=False, is_mid=True, is_last=False),
69 |       dict(step_type=dm_env.StepType.LAST,
70 |            is_first=False, is_mid=False, is_last=True),
71 |   )
72 |   def test_step_type_helpers(self, step_type, is_first, is_mid, is_last):
73 |     time_step = dm_env.TimeStep(
74 |         reward=None, discount=None, observation=None, step_type=step_type)
75 |
76 |     with self.subTest('TimeStep methods'):
77 |       self.assertEqual(is_first, time_step.first())
78 |       self.assertEqual(is_mid, time_step.mid())
79 |       self.assertEqual(is_last, time_step.last())
80 |
81 |     with self.subTest('StepType methods'):
82 |       self.assertEqual(is_first, time_step.step_type.first())
83 |       self.assertEqual(is_mid, time_step.step_type.mid())
84 |       self.assertEqual(is_last, time_step.step_type.last())
85 |
86 |
87 | if __name__ == '__main__':
88 |   absltest.main()
89 |
--------------------------------------------------------------------------------
/dm_env/_metadata.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=g-bad-file-header
2 | # Copyright 2019 The dm_env Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #    http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ============================================================================
16 | """Package metadata for dm_env.
17 |
18 | This is kept in a separate module so that it can be imported from setup.py, at
19 | a time when dm_env's dependencies may not have been installed yet.
20 | """
21 |
22 | __version__ = '1.6'  # https://www.python.org/dev/peps/pep-0440/
23 |
--------------------------------------------------------------------------------
/dm_env/auto_reset_environment.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=g-bad-file-header
2 | # Copyright 2021 The dm_env Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #    http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ============================================================================
16 | """Auto-resetting dm_env.Environment helper class.
17 |
18 | The dm_env API states that stepping an environment after a LAST timestep should
19 | return the FIRST timestep of a new episode. Environment authors sometimes miss
20 | this part or find it awkward to implement. This module contains a class that
21 | helps implement the reset behaviour.
22 | """
23 |
24 | import abc
25 | from dm_env import _environment
26 |
27 |
28 | class AutoResetEnvironment(_environment.Environment):
29 |   """Auto-resetting environment base class.
30 |
31 |   This class implements the required `step()` and `reset()` methods and
32 |   instead requires users to implement `_step()` and `_reset()`. This class
33 |   handles the reset behaviour automatically when it detects a LAST timestep.
34 |   """
35 |
36 |   def __init__(self):
37 |     self._reset_next_step = True
38 |
39 |   @abc.abstractmethod
40 |   def _reset(self) -> _environment.TimeStep:
41 |     """Returns a `timestep` namedtuple as per the regular `reset()` method."""
42 |
43 |   @abc.abstractmethod
44 |   def _step(self, action) -> _environment.TimeStep:
45 |     """Returns a `timestep` namedtuple as per the regular `step()` method."""
46 |
47 |   def reset(self) -> _environment.TimeStep:
48 |     self._reset_next_step = False
49 |     return self._reset()
50 |
51 |   def step(self, action) -> _environment.TimeStep:
52 |     if self._reset_next_step:
53 |       return self.reset()
54 |     timestep = self._step(action)
55 |     self._reset_next_step = timestep.last()
56 |     return timestep
57 |
--------------------------------------------------------------------------------
/dm_env/auto_reset_environment_test.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=g-bad-file-header
2 | # Copyright 2021 The dm_env Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #    http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ============================================================================
16 | """Tests for auto_reset_environment."""
17 |
18 | from absl.testing import absltest
19 | from dm_env import _environment
20 | from dm_env import auto_reset_environment
21 | from dm_env import specs
22 | from dm_env import test_utils
23 | import numpy as np
24 |
25 |
26 | class FakeEnvironment(auto_reset_environment.AutoResetEnvironment):
27 |   """Environment that resets after a given number of steps."""
28 |
29 |   def __init__(self, step_limit):
30 |     super(FakeEnvironment, self).__init__()
31 |     self._step_limit = step_limit
32 |     self._steps_taken = 0
33 |
34 |   def _reset(self):
35 |     self._steps_taken = 0
36 |     return _environment.restart(observation=np.zeros(3))
37 |
38 |   def _step(self, action):
39 |     self._steps_taken += 1
40 |     if self._steps_taken < self._step_limit:
41 |       return _environment.transition(reward=0.0, observation=np.zeros(3))
42 |     else:
43 |       return _environment.termination(reward=0.0, observation=np.zeros(3))
44 |
45 |   def action_spec(self):
46 |     return specs.Array(shape=(), dtype='int')
47 |
48 |   def observation_spec(self):
49 |     return specs.Array(shape=(3,), dtype='float')
50 |
51 |
52 | class AutoResetEnvironmentTest(test_utils.EnvironmentTestMixin,
53 |                                absltest.TestCase):
54 |
55 |   def make_object_under_test(self):
56 |     return FakeEnvironment(step_limit=5)
57 |
58 |   def make_action_sequence(self):
59 |     for _ in range(20):
60 |       yield np.array(0)
61 |
62 |
63 | if __name__ == '__main__':
64 |   absltest.main()
65 |
--------------------------------------------------------------------------------
/dm_env/specs.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=g-bad-file-header
2 | # Copyright 2019 The dm_env Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #    http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ============================================================================
16 | """Classes that describe numpy arrays."""
17 |
18 | import inspect
19 | from typing import Optional
20 |
21 | import numpy as np
22 |
23 | _INVALID_SHAPE = 'Expected shape %r but found %r'
24 | _INVALID_DTYPE = 'Expected dtype %r but found %r'
25 | _OUT_OF_BOUNDS = 'Values were not all within bounds %s <= %s <= %s'
26 | _VAR_ARGS_NOT_ALLOWED = 'Spec subclasses must not accept *args.'
27 | _VAR_KWARGS_NOT_ALLOWED = 'Spec subclasses must not accept **kwargs.'
28 | _MINIMUM_MUST_BE_LESS_THAN_OR_EQUAL_TO_MAXIMUM = (
29 |     'All values in `minimum` must be less than or equal to their corresponding '
30 |     'value in `maximum`, got:\nminimum={minimum!r}\nmaximum={maximum!r}.')
31 | _MINIMUM_INCOMPATIBLE_WITH_SHAPE = '`minimum` is incompatible with `shape`'
32 | _MAXIMUM_INCOMPATIBLE_WITH_SHAPE = '`maximum` is incompatible with `shape`'
33 |
34 |
35 | class Array:
36 |   """Describes a numpy array or scalar shape and dtype.
37 |
38 |   An `Array` spec allows an API to describe the arrays that it accepts or
39 |   returns, before that array exists.
40 |   The equivalent version describing a `tf.Tensor` is `TensorSpec`.
41 |   """
42 |   __slots__ = ('_shape', '_dtype', '_name')
43 |   __hash__ = None
44 |
45 |   def __init__(self, shape, dtype, name: Optional[str] = None):
46 |     """Initializes a new `Array` spec.
47 |
48 |     Args:
49 |       shape: An iterable specifying the array shape.
50 |       dtype: numpy dtype or string specifying the array dtype.
51 |       name: Optional string containing a semantic name for the corresponding
52 |         array. Defaults to `None`.
53 |
54 |     Raises:
55 |       TypeError: If `shape` is not an iterable of elements convertible to int,
56 |       or if `dtype` is not convertible to a numpy dtype.
57 |     """
58 |     self._shape = tuple(int(dim) for dim in shape)
59 |     self._dtype = np.dtype(dtype)
60 |     self._name = name
61 |
62 |   @property
63 |   def shape(self):
64 |     """Returns a `tuple` specifying the array shape."""
65 |     return self._shape
66 |
67 |   @property
68 |   def dtype(self):
69 |     """Returns a numpy dtype specifying the array dtype."""
70 |     return self._dtype
71 |
72 |   @property
73 |   def name(self):
74 |     """Returns the name of the Array."""
75 |     return self._name
76 |
77 |   def __repr__(self):
78 |     return 'Array(shape={}, dtype={}, name={})'.format(self.shape,
79 |                                                        repr(self.dtype),
80 |                                                        repr(self.name))
81 |
82 |   def __eq__(self, other):
83 |     """Checks if the shape and dtype of two specs are equal."""
84 |     if not isinstance(other, Array):
85 |       return False
86 |     return self.shape == other.shape and self.dtype == other.dtype
87 |
88 |   def __ne__(self, other):
89 |     return not self == other
90 |
91 |   def _fail_validation(self, message, *args):
92 |     message %= args
93 |     if self.name:
94 |       message += ' for spec %s' % self.name
95 |     raise ValueError(message)
96 |
97 |   def validate(self, value):
98 |     """Checks if value conforms to this spec.
99 |
100 |     Args:
101 |       value: a numpy array or value convertible to one via `np.asarray`.
102 |
103 |     Returns:
104 |       value, converted if necessary to a numpy array.
105 |
106 |     Raises:
107 |       ValueError: if value doesn't conform to this spec.
108 |     """
109 |     value = np.asarray(value)
110 |     if value.shape != self.shape:
111 |       self._fail_validation(_INVALID_SHAPE, self.shape, value.shape)
112 |     if value.dtype != self.dtype:
113 |       self._fail_validation(_INVALID_DTYPE, self.dtype, value.dtype)
114 |     return value
115 |
116 |   def generate_value(self):
117 |     """Generate a test value which conforms to this spec."""
118 |     return np.zeros(shape=self.shape, dtype=self.dtype)
119 |
120 |   def _get_constructor_kwargs(self):
121 |     """Returns constructor kwargs for instantiating a new copy of this spec."""
122 |     # Get the names and kinds of the constructor parameters.
123 |     params = inspect.signature(type(self)).parameters
124 |     # __init__ must not accept *args or **kwargs, since otherwise we won't be
125 |     # able to infer what the corresponding attribute names are.
126 |     kinds = {value.kind for value in params.values()}
127 |     if inspect.Parameter.VAR_POSITIONAL in kinds:
128 |       raise TypeError(_VAR_ARGS_NOT_ALLOWED)
129 |     elif inspect.Parameter.VAR_KEYWORD in kinds:
130 |       raise TypeError(_VAR_KWARGS_NOT_ALLOWED)
131 |     # Note that we assume direct correspondence between the names of constructor
132 |     # arguments and attributes.
133 |     return {name: getattr(self, name) for name in params.keys()}
134 |
135 |   def replace(self, **kwargs):
136 |     """Returns a new copy of `self` with specified attributes replaced.
137 |
138 |     Args:
139 |       **kwargs: Optional attributes to replace.
140 |
141 |     Returns:
142 |       A new copy of `self`.
143 |     """
144 |     all_kwargs = self._get_constructor_kwargs()
145 |     all_kwargs.update(kwargs)
146 |     return type(self)(**all_kwargs)
147 |
148 |   def __reduce__(self):
149 |     return Array, (self._shape, self._dtype, self._name)
150 |
151 |
152 | class BoundedArray(Array):
153 |   """An `Array` spec that specifies minimum and maximum values.
154 |
155 |   Example usage:
156 |   ```python
157 |   # Specifying the same minimum and maximum for every element.
158 |   spec = BoundedArray((3, 4), np.float64, minimum=0.0, maximum=1.0)
159 |
160 |   # Specifying a different minimum and maximum for each element.
161 |   spec = BoundedArray(
162 |       (2,), np.float64, minimum=[0.1, 0.2], maximum=[0.9, 0.9])
163 |
164 |   # Specifying the same minimum and a different maximum for each element.
165 |   spec = BoundedArray(
166 |       (3,), np.float64, minimum=-10.0, maximum=[4.0, 5.0, 3.0])
167 |   ```
168 |
169 |   Bounds are meant to be inclusive. This is especially important for
170 |   integer types. The following spec will be satisfied by arrays
171 |   with values in the set {0, 1, 2}:
172 |   ```python
173 |   spec = BoundedArray((3, 4), int, minimum=0, maximum=2)
174 |   ```
175 |
176 |   Note that one or both bounds may be infinite. For example, the set of
177 |   non-negative floats can be expressed as:
178 |   ```python
179 |   spec = BoundedArray((), np.float64, minimum=0.0, maximum=np.inf)
180 |   ```
181 |   In this case `np.inf` would be considered valid, since the upper bound is
182 |   inclusive.
183 |   """
184 |   __slots__ = ('_minimum', '_maximum')
185 |   __hash__ = None
186 |
187 |   def __init__(self, shape, dtype, minimum, maximum, name=None):
188 |     """Initializes a new `BoundedArray` spec.
189 |
190 |     Args:
191 |       shape: An iterable specifying the array shape.
192 |       dtype: numpy dtype or string specifying the array dtype.
193 |       minimum: Number or sequence specifying the minimum element bounds
194 |         (inclusive). Must be broadcastable to `shape`.
195 |       maximum: Number or sequence specifying the maximum element bounds
196 |         (inclusive). Must be broadcastable to `shape`.
197 |       name: Optional string containing a semantic name for the corresponding
198 |         array. Defaults to `None`.
199 |
200 |     Raises:
201 |       ValueError: If `minimum` or `maximum` are not broadcastable to `shape`.
202 |       ValueError: If any values in `minimum` are greater than their
203 |         corresponding value in `maximum`.
204 |       TypeError: If the shape is not an iterable or if the `dtype` is an invalid
205 |         numpy dtype.
206 |     """
207 |     super(BoundedArray, self).__init__(shape, dtype, name)
208 |
209 |     try:
210 |       bcast_minimum = np.broadcast_to(minimum, shape=shape)
211 |     except ValueError as numpy_exception:
212 |       raise ValueError(_MINIMUM_INCOMPATIBLE_WITH_SHAPE) from numpy_exception
213 |     try:
214 |       bcast_maximum = np.broadcast_to(maximum, shape=shape)
215 |     except ValueError as numpy_exception:
216 |       raise ValueError(_MAXIMUM_INCOMPATIBLE_WITH_SHAPE) from numpy_exception
217 |
218 |     if np.any(bcast_minimum > bcast_maximum):
219 |       raise ValueError(_MINIMUM_MUST_BE_LESS_THAN_OR_EQUAL_TO_MAXIMUM.format(
220 |           minimum=minimum, maximum=maximum))
221 |
222 |     self._minimum = np.array(minimum, dtype=self.dtype)
223 |     self._minimum.setflags(write=False)
224 |
225 |     self._maximum = np.array(maximum, dtype=self.dtype)
226 |     self._maximum.setflags(write=False)
227 |
228 |   @property
229 |   def minimum(self):
230 |     """Returns a NumPy array specifying the minimum bounds (inclusive)."""
231 |     return self._minimum
232 |
233 |   @property
234 |   def maximum(self):
235 |     """Returns a NumPy array specifying the maximum bounds (inclusive)."""
236 |     return self._maximum
237 |
238 |   def __repr__(self):
239 |     template = ('BoundedArray(shape={}, dtype={}, name={}, '
240 |                 'minimum={}, maximum={})')
241 |     return template.format(self.shape, repr(self.dtype), repr(self.name),
242 |                            self._minimum, self._maximum)
243 |
244 |   def __eq__(self, other):
245 |     if not isinstance(other, BoundedArray):
246 |       return False
247 |     return (super(BoundedArray, self).__eq__(other) and
248 |             (self.minimum == other.minimum).all() and
249 |             (self.maximum == other.maximum).all())
250 |
251 |   def validate(self, value):
252 |     value = np.asarray(value)
253 |     super(BoundedArray, self).validate(value)
254 |     if (value < self.minimum).any() or (value > self.maximum).any():
255 |       self._fail_validation(_OUT_OF_BOUNDS, self.minimum, value,
self.maximum)
256 |     return value
257 |
258 |   def generate_value(self):
259 |     return (np.ones(shape=self.shape, dtype=self.dtype) *
260 |             self.dtype.type(self.minimum))
261 |
262 |   def __reduce__(self):
263 |     return BoundedArray, (self._shape, self._dtype, self._minimum,
264 |                           self._maximum, self._name)
265 |
266 |
267 | _NUM_VALUES_NOT_POSITIVE = '`num_values` must be a positive integer, got {}.'
268 | _DTYPE_NOT_INTEGRAL = '`dtype` must be integral, got {}.'
269 | _DTYPE_OVERFLOW = (
270 |     '`dtype` {} is not big enough to hold `num_values` ({}) without overflow.')
271 |
272 |
273 | class DiscreteArray(BoundedArray):
274 |   """Represents a discrete, scalar, zero-based space.
275 |
276 |   This is a special case of the parent `BoundedArray` class. It represents a
277 |   0-dimensional numpy array containing a single integer value between
278 |   0 and num_values - 1 (inclusive), and exposes a scalar `num_values` property
279 |   in addition to the standard `BoundedArray` interface.
280 |
281 |   For an example use-case, this can be used to define the action space of a
282 |   simple RL environment that accepts discrete actions.
283 |   """
284 |
285 |   _REPR_TEMPLATE = (
286 |       'DiscreteArray(shape={self.shape}, dtype={self.dtype}, name={self.name}, '
287 |       'minimum={self.minimum}, maximum={self.maximum}, '
288 |       'num_values={self.num_values})')
289 |
290 |   __slots__ = ('_num_values',)
291 |
292 |   def __init__(self, num_values, dtype=np.int32, name=None):
293 |     """Initializes a new `DiscreteArray` spec.
294 |
295 |     Args:
296 |       num_values: Integer specifying the number of possible values to represent.
297 |       dtype: The dtype of the array. Must be an integral type large enough to
298 |         hold `num_values` without overflow.
299 |       name: Optional string specifying the name of the array.
300 |
301 |     Raises:
302 |       ValueError: If `num_values` is not positive, if `dtype` is not integral,
303 |         or if `dtype` is not large enough to hold `num_values` without overflow.
304 |     """
305 |     if num_values <= 0 or not np.issubdtype(type(num_values), np.integer):
306 |       raise ValueError(_NUM_VALUES_NOT_POSITIVE.format(num_values))
307 |
308 |     if not np.issubdtype(dtype, np.integer):
309 |       raise ValueError(_DTYPE_NOT_INTEGRAL.format(dtype))
310 |
311 |     num_values = int(num_values)
312 |     maximum = num_values - 1
313 |     dtype = np.dtype(dtype)
314 |
315 |     if np.min_scalar_type(maximum) > dtype:
316 |       raise ValueError(_DTYPE_OVERFLOW.format(dtype, num_values))
317 |
318 |     super(DiscreteArray, self).__init__(
319 |         shape=(),
320 |         dtype=dtype,
321 |         minimum=0,
322 |         maximum=maximum,
323 |         name=name)
324 |     self._num_values = num_values
325 |
326 |   @property
327 |   def num_values(self):
328 |     """Returns the number of items."""
329 |     return self._num_values
330 |
331 |   def __repr__(self):
332 |     return self._REPR_TEMPLATE.format(self=self)  # pytype: disable=duplicate-keyword-argument
333 |
334 |   def __reduce__(self):
335 |     return DiscreteArray, (self._num_values, self._dtype, self._name)
336 |
337 |
338 | _VALID_STRING_TYPES = (str, bytes)
339 | _INVALID_STRING_TYPE = (
340 |     'Expected `string_type` to be one of: {}, got: {{!r}}.'
341 |     .format(_VALID_STRING_TYPES))
342 | _INVALID_ELEMENT_TYPE = (
343 |     'Expected all elements to be of type: %s. Got value: %r of type: %s.')
344 |
345 |
346 | class StringArray(Array):
347 |   """Represents an array of variable-length Python strings."""
348 |   __slots__ = ('_string_type',)
349 |
350 |   _REPR_TEMPLATE = (
351 |       '{self.__class__.__name__}(shape={self.shape}, '
352 |       'string_type={self.string_type}, name={self.name})')
353 |
354 |   def __init__(self, shape, string_type=str, name=None):
355 |     """Initializes a new `StringArray` spec.
356 |
357 |     Args:
358 |       shape: An iterable specifying the array shape.
359 |       string_type: The native Python string type for each element; either
360 |         unicode or ASCII. Defaults to unicode.
361 |       name: Optional string containing a semantic name for the corresponding
362 |         array.
Defaults to `None`.
363 |     """
364 |     if string_type not in _VALID_STRING_TYPES:
365 |       raise ValueError(_INVALID_STRING_TYPE.format(string_type))
366 |     self._string_type = string_type
367 |     super(StringArray, self).__init__(shape=shape, dtype=object, name=name)
368 |
369 |   @property
370 |   def string_type(self):
371 |     """Returns the Python string type for each element."""
372 |     return self._string_type
373 |
374 |   def validate(self, value):
375 |     """Checks if value conforms to this spec.
376 |
377 |     Args:
378 |       value: a numpy array or value convertible to one via `np.asarray`.
379 |
380 |     Returns:
381 |       value, converted if necessary to a numpy array.
382 |
383 |     Raises:
384 |       ValueError: if value doesn't conform to this spec.
385 |     """
386 |     value = np.asarray(value, dtype=object)
387 |     if value.shape != self.shape:
388 |       self._fail_validation(_INVALID_SHAPE, self.shape, value.shape)
389 |     for item in value.flat:
390 |       if not isinstance(item, self.string_type):
391 |         self._fail_validation(
392 |             _INVALID_ELEMENT_TYPE, self.string_type, item, type(item))
393 |     return value
394 |
395 |   def generate_value(self):
396 |     """Generate a test value which conforms to this spec."""
397 |     empty_string = self.string_type()  # pylint: disable=not-callable
398 |     return np.full(shape=self.shape, dtype=self.dtype, fill_value=empty_string)
399 |
400 |   def __repr__(self):
401 |     return self._REPR_TEMPLATE.format(self=self)  # pytype: disable=duplicate-keyword-argument
402 |
403 |   def __reduce__(self):
404 |     return type(self), (self.shape, self.string_type, self.name)
405 |
--------------------------------------------------------------------------------
/dm_env/specs_test.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=g-bad-file-header
2 | # Copyright 2019 The dm_env Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #    http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ============================================================================
16 | """Tests for dm_env.specs."""
17 |
18 | import pickle
19 |
20 | from absl.testing import absltest
21 | from absl.testing import parameterized
22 | from dm_env import specs
23 | import numpy as np
24 |
25 |
26 | class ArrayTest(parameterized.TestCase):
27 |
28 |   def testShapeTypeError(self):
29 |     with self.assertRaises(TypeError):
30 |       specs.Array(32, np.int32)
31 |
32 |   def testShapeElementTypeError(self):
33 |     with self.assertRaises(TypeError):
34 |       specs.Array([None], np.int32)
35 |
36 |   def testDtypeTypeError(self):
37 |     with self.assertRaises(TypeError):
38 |       specs.Array((1, 2, 3), "32")
39 |
40 |   def testScalarShape(self):
41 |     specs.Array((), np.int32)
42 |
43 |   def testStringDtype(self):
44 |     specs.Array((1, 2, 3), "int32")
45 |
46 |   def testNumpyDtype(self):
47 |     specs.Array((1, 2, 3), np.int32)
48 |
49 |   def testDtype(self):
50 |     spec = specs.Array((1, 2, 3), np.int32)
51 |     self.assertEqual(np.int32, spec.dtype)
52 |
53 |   def testShape(self):
54 |     spec = specs.Array([1, 2, 3], np.int32)
55 |     self.assertEqual((1, 2, 3), spec.shape)
56 |
57 |   def testEqual(self):
58 |     spec_1 = specs.Array((1, 2, 3), np.int32)
59 |     spec_2 = specs.Array((1, 2, 3), np.int32)
60 |     self.assertEqual(spec_1, spec_2)
61 |
62 |   def testNotEqualDifferentShape(self):
63 |     spec_1 = specs.Array((1, 2, 3), np.int32)
64 |     spec_2 = specs.Array((1, 3, 3), np.int32)
65 |     self.assertNotEqual(spec_1, spec_2)
66 |
67 |   def testNotEqualDifferentDtype(self):
68 |     spec_1 = specs.Array((1, 2, 3), np.int64)
69 |     spec_2 = specs.Array((1, 2, 3), np.int32)
70 |     self.assertNotEqual(spec_1, spec_2)
71 |
72 |   def testNotEqualOtherClass(self):
73 |     spec_1 = specs.Array((1, 2, 3), np.int32)
74 |     spec_2 = None
75 |     self.assertNotEqual(spec_1, spec_2)
76 |     self.assertNotEqual(spec_2, spec_1)
77 |
78 |     spec_2 = ()
79 |     self.assertNotEqual(spec_1, spec_2)
80 |     self.assertNotEqual(spec_2, spec_1)
81 |
82 |   def testIsUnhashable(self):
83 |     spec = specs.Array(shape=(1, 2, 3), dtype=np.int32)
84 |     with self.assertRaisesRegex(TypeError, "unhashable type"):
85 |       hash(spec)
86 |
87 |   @parameterized.parameters(
88 |       dict(value=np.zeros((1, 2), dtype=np.int32), is_valid=True),
89 |       dict(value=np.zeros((1, 2), dtype=np.float32), is_valid=False),
90 |   )
91 |   def testValidateDtype(self, value, is_valid):
92 |     spec = specs.Array((1, 2), np.int32)
93 |     if is_valid:  # Should not raise any exception.
94 |       spec.validate(value)
95 |     else:
96 |       with self.assertRaisesWithLiteralMatch(
97 |           ValueError,
98 |           specs._INVALID_DTYPE % (spec.dtype, value.dtype)):
99 |         spec.validate(value)
100 |
101 |   @parameterized.parameters(
102 |       dict(value=np.zeros((1, 2), dtype=np.int32), is_valid=True),
103 |       dict(value=np.zeros((1, 2, 3), dtype=np.int32), is_valid=False),
104 |   )
105 |   def testValidateShape(self, value, is_valid):
106 |     spec = specs.Array((1, 2), np.int32)
107 |     if is_valid:  # Should not raise any exception.
108 |       spec.validate(value)
109 |     else:
110 |       with self.assertRaisesWithLiteralMatch(
111 |           ValueError,
112 |           specs._INVALID_SHAPE % (spec.shape, value.shape)):
113 |         spec.validate(value)
114 |
115 |   def testGenerateValue(self):
116 |     spec = specs.Array((1, 2), np.int32)
117 |     test_value = spec.generate_value()
118 |     spec.validate(test_value)
119 |
120 |   def testSerialization(self):
121 |     desc = specs.Array([1, 5], np.float32, "test")
122 |     self.assertEqual(pickle.loads(pickle.dumps(desc)), desc)
123 |
124 |   @parameterized.parameters(
125 |       {"arg_name": "shape", "new_value": (2, 3)},
126 |       {"arg_name": "dtype", "new_value": np.int32},
127 |       {"arg_name": "name", "new_value": "something_else"})
128 |   def testReplace(self, arg_name, new_value):
129 |     old_spec = specs.Array([1, 5], np.float32, "test")
130 |     new_spec = old_spec.replace(**{arg_name: new_value})
131 |     self.assertIsNot(old_spec, new_spec)
132 |     self.assertEqual(getattr(new_spec, arg_name), new_value)
133 |     for attr_name in set(["shape", "dtype", "name"]).difference([arg_name]):
134 |       self.assertEqual(getattr(new_spec, attr_name),
135 |                        getattr(old_spec, attr_name))
136 |
137 |   def testReplaceRaisesTypeErrorIfSubclassAcceptsVarArgs(self):
138 |
139 |     class InvalidSpecSubclass(specs.Array):
140 |
141 |       def __init__(self, *args):  # pylint: disable=useless-super-delegation
142 |         super(InvalidSpecSubclass, self).__init__(*args)
143 |
144 |     spec = InvalidSpecSubclass([1, 5], np.float32, "test")
145 |
146 |     with self.assertRaisesWithLiteralMatch(
147 |         TypeError, specs._VAR_ARGS_NOT_ALLOWED):
148 |       spec.replace(name="something_else")
149 |
150 |   def testReplaceRaisesTypeErrorIfSubclassAcceptsVarKwargs(self):
151 |
152 |     class InvalidSpecSubclass(specs.Array):
153 |
154 |       def __init__(self, **kwargs):  # pylint: disable=useless-super-delegation
155 |         super(InvalidSpecSubclass, self).__init__(**kwargs)
156 |
157 |     spec = InvalidSpecSubclass(shape=[1, 5], dtype=np.float32, name="test")
158 |
159 |     with
self.assertRaisesWithLiteralMatch(
160 |         TypeError, specs._VAR_KWARGS_NOT_ALLOWED):
161 |       spec.replace(name="something_else")
162 |
163 |
164 | class BoundedArrayTest(parameterized.TestCase):
165 |
166 |   def testInvalidMinimum(self):
167 |     with self.assertRaisesWithLiteralMatch(
168 |         ValueError, specs._MINIMUM_INCOMPATIBLE_WITH_SHAPE):
169 |       specs.BoundedArray((3, 5), np.uint8, (0, 0, 0), (1, 1))
170 |
171 |   def testInvalidMaximum(self):
172 |     with self.assertRaisesWithLiteralMatch(
173 |         ValueError, specs._MAXIMUM_INCOMPATIBLE_WITH_SHAPE):
174 |       specs.BoundedArray((3, 5), np.uint8, 0, (1, 1, 1))
175 |
176 |   def testMinMaxAttributes(self):
177 |     spec = specs.BoundedArray((1, 2, 3), np.float32, 0, (5, 5, 5))
178 |     self.assertEqual(type(spec.minimum), np.ndarray)
179 |     self.assertEqual(type(spec.maximum), np.ndarray)
180 |
181 |   @parameterized.parameters(
182 |       dict(spec_dtype=np.float32, min_dtype=np.float64, max_dtype=np.int32),
183 |       dict(spec_dtype=np.uint64, min_dtype=np.uint8, max_dtype=float))
184 |   def testMinMaxCasting(self, spec_dtype, min_dtype, max_dtype):
185 |     minimum = np.array(0., dtype=min_dtype)
186 |     maximum = np.array((3.14, 15.9, 265.4), dtype=max_dtype)
187 |     spec = specs.BoundedArray(
188 |         shape=(1, 2, 3), dtype=spec_dtype, minimum=minimum, maximum=maximum)
189 |     self.assertEqual(spec.minimum.dtype, spec_dtype)
190 |     self.assertEqual(spec.maximum.dtype, spec_dtype)
191 |
192 |   def testReadOnly(self):
193 |     spec = specs.BoundedArray((1, 2, 3), np.float32, 0, (5, 5, 5))
194 |     with self.assertRaisesRegex(ValueError, "read-only"):
195 |       spec.minimum[0] = -1
196 |     with self.assertRaisesRegex(ValueError, "read-only"):
197 |       spec.maximum[0] = 100
198 |
199 |   def testEqualBroadcastingBounds(self):
200 |     spec_1 = specs.BoundedArray(
201 |         (1, 2), np.float32, minimum=0.0, maximum=1.0)
202 |     spec_2 = specs.BoundedArray(
203 |         (1, 2), np.float32, minimum=[0.0, 0.0], maximum=[1.0, 1.0])
204 |     self.assertEqual(spec_1, spec_2)
205 |
206 |   def
testNotEqualDifferentMinimum(self): 207 | spec_1 = specs.BoundedArray( 208 | (1, 2), np.float32, minimum=[0.0, -0.6], maximum=[1.0, 1.0]) 209 | spec_2 = specs.BoundedArray( 210 | (1, 2), np.float32, minimum=[0.0, 0.0], maximum=[1.0, 1.0]) 211 | self.assertNotEqual(spec_1, spec_2) 212 | 213 | def testNotEqualOtherClass(self): 214 | spec_1 = specs.BoundedArray( 215 | (1, 2), np.float32, minimum=[0.0, -0.6], maximum=[1.0, 1.0]) 216 | spec_2 = specs.Array((1, 2), np.float32) 217 | self.assertNotEqual(spec_1, spec_2) 218 | self.assertNotEqual(spec_2, spec_1) 219 | 220 | spec_2 = None 221 | self.assertNotEqual(spec_1, spec_2) 222 | self.assertNotEqual(spec_2, spec_1) 223 | 224 | spec_2 = () 225 | self.assertNotEqual(spec_1, spec_2) 226 | self.assertNotEqual(spec_2, spec_1) 227 | 228 | def testNotEqualDifferentMaximum(self): 229 | spec_1 = specs.BoundedArray( 230 | (1, 2), np.int32, minimum=0.0, maximum=2.0) 231 | spec_2 = specs.BoundedArray( 232 | (1, 2), np.int32, minimum=[0.0, 0.0], maximum=[1.0, 1.0]) 233 | self.assertNotEqual(spec_1, spec_2) 234 | 235 | def testIsUnhashable(self): 236 | spec = specs.BoundedArray( 237 | shape=(1, 2), dtype=np.int32, minimum=0.0, maximum=2.0) 238 | with self.assertRaisesRegex(TypeError, "unhashable type"): 239 | hash(spec) 240 | 241 | def testRepr(self): 242 | as_string = repr(specs.BoundedArray( 243 | (1, 2), np.int32, minimum=73.0, maximum=101.0)) 244 | self.assertIn("73", as_string) 245 | self.assertIn("101", as_string) 246 | 247 | @parameterized.parameters( 248 | dict(value=np.array([[5, 6], [8, 10]], dtype=np.int32), is_valid=True), 249 | dict(value=np.array([[5, 6], [8, 11]], dtype=np.int32), is_valid=False), 250 | dict(value=np.array([[4, 6], [8, 10]], dtype=np.int32), is_valid=False), 251 | ) 252 | def testValidateBounds(self, value, is_valid): 253 | spec = specs.BoundedArray((2, 2), np.int32, minimum=5, maximum=10) 254 | if is_valid: # Should not raise any exception. 
255 | spec.validate(value) 256 | else: 257 | with self.assertRaisesWithLiteralMatch( 258 | ValueError, 259 | specs._OUT_OF_BOUNDS % (spec.minimum, value, spec.maximum)): 260 | spec.validate(value) 261 | 262 | @parameterized.parameters( 263 | # Semi-infinite intervals. 264 | dict(minimum=0., maximum=np.inf, value=0., is_valid=True), 265 | dict(minimum=0., maximum=np.inf, value=1., is_valid=True), 266 | dict(minimum=0., maximum=np.inf, value=np.inf, is_valid=True), 267 | dict(minimum=0., maximum=np.inf, value=-1., is_valid=False), 268 | dict(minimum=0., maximum=np.inf, value=-np.inf, is_valid=False), 269 | dict(minimum=-np.inf, maximum=0., value=0., is_valid=True), 270 | dict(minimum=-np.inf, maximum=0., value=-1., is_valid=True), 271 | dict(minimum=-np.inf, maximum=0., value=-np.inf, is_valid=True), 272 | dict(minimum=-np.inf, maximum=0., value=1., is_valid=False), 273 | # Infinite interval. 274 | dict(minimum=-np.inf, maximum=np.inf, value=1., is_valid=True), 275 | dict(minimum=-np.inf, maximum=np.inf, value=-1., is_valid=True), 276 | dict(minimum=-np.inf, maximum=np.inf, value=-np.inf, is_valid=True), 277 | dict(minimum=-np.inf, maximum=np.inf, value=np.inf, is_valid=True), 278 | # Special case where minimum == maximum. 279 | dict(minimum=0., maximum=0., value=0., is_valid=True), 280 | dict(minimum=0., maximum=0., value=np.finfo(float).eps, is_valid=False), 281 | ) 282 | def testValidateBoundsFloat(self, minimum, maximum, value, is_valid): 283 | spec = specs.BoundedArray((), float, minimum=minimum, maximum=maximum) 284 | if is_valid: # Should not raise any exception. 
285 | spec.validate(value) 286 | else: 287 | with self.assertRaisesWithLiteralMatch( 288 | ValueError, 289 | specs._OUT_OF_BOUNDS % (spec.minimum, value, spec.maximum)): 290 | spec.validate(value) 291 | 292 | def testValidateReturnsValue(self): 293 | spec = specs.BoundedArray([1], np.int32, minimum=0, maximum=1) 294 | validated_value = spec.validate(np.array([0], dtype=np.int32)) 295 | self.assertIsNotNone(validated_value) 296 | 297 | def testGenerateValue(self): 298 | spec = specs.BoundedArray((2, 2), np.int32, minimum=5, maximum=10) 299 | test_value = spec.generate_value() 300 | spec.validate(test_value) 301 | 302 | def testScalarBounds(self): 303 | spec = specs.BoundedArray((), float, minimum=0.0, maximum=1.0) 304 | 305 | self.assertIsInstance(spec.minimum, np.ndarray) 306 | self.assertIsInstance(spec.maximum, np.ndarray) 307 | 308 | # Sanity check that numpy compares correctly to a scalar for an empty shape. 309 | self.assertEqual(0.0, spec.minimum) 310 | self.assertEqual(1.0, spec.maximum) 311 | 312 | # Check that the spec doesn't fail its own input validation. 
313 | _ = specs.BoundedArray( 314 | spec.shape, spec.dtype, spec.minimum, spec.maximum) 315 | 316 | def testSerialization(self): 317 | desc = specs.BoundedArray([1, 5], np.float32, -1, 1, "test") 318 | self.assertEqual(pickle.loads(pickle.dumps(desc)), desc) 319 | 320 | @parameterized.parameters( 321 | {"arg_name": "shape", "new_value": (2, 3)}, 322 | {"arg_name": "dtype", "new_value": np.int32}, 323 | {"arg_name": "name", "new_value": "something_else"}, 324 | {"arg_name": "minimum", "new_value": -2}, 325 | {"arg_name": "maximum", "new_value": 2}, 326 | ) 327 | def testReplace(self, arg_name, new_value): 328 | old_spec = specs.BoundedArray([1, 5], np.float32, -1, 1, "test") 329 | new_spec = old_spec.replace(**{arg_name: new_value}) 330 | self.assertIsNot(old_spec, new_spec) 331 | self.assertEqual(getattr(new_spec, arg_name), new_value) 332 | for attr_name in set(["shape", "dtype", "name", "minimum", "maximum"] 333 | ).difference([arg_name]): 334 | self.assertEqual(getattr(new_spec, attr_name), 335 | getattr(old_spec, attr_name)) 336 | 337 | @parameterized.parameters([ 338 | dict(minimum=1., maximum=0.), 339 | dict(minimum=[0., 1.], maximum=0.), 340 | dict(minimum=1., maximum=[0., 0.]), 341 | dict(minimum=[0., 1.], maximum=[0., 0.]), 342 | ]) 343 | def testErrorIfMinimumGreaterThanMaximum(self, minimum, maximum): 344 | with self.assertRaisesWithLiteralMatch( 345 | ValueError, 346 | specs._MINIMUM_MUST_BE_LESS_THAN_OR_EQUAL_TO_MAXIMUM.format( 347 | minimum=minimum, maximum=maximum)): 348 | specs.BoundedArray((2,), np.float32, minimum, maximum, "test") 349 | 350 | 351 | class DiscreteArrayTest(parameterized.TestCase): 352 | 353 | @parameterized.parameters(0, -3) 354 | def testInvalidNumActions(self, num_values): 355 | with self.assertRaisesWithLiteralMatch( 356 | ValueError, specs._NUM_VALUES_NOT_POSITIVE.format(num_values)): 357 | specs.DiscreteArray(num_values=num_values) 358 | 359 | @parameterized.parameters(np.float32, object) 360 | def testDtypeNotIntegral(self, 
dtype): 361 | with self.assertRaisesWithLiteralMatch( 362 | ValueError, specs._DTYPE_NOT_INTEGRAL.format(dtype)): 363 | specs.DiscreteArray(num_values=5, dtype=dtype) 364 | 365 | @parameterized.parameters( 366 | dict(dtype=np.uint8, num_values=2 ** 8 + 1), 367 | dict(dtype=np.uint64, num_values=2 ** 64 + 1)) 368 | def testDtypeOverflow(self, num_values, dtype): 369 | with self.assertRaisesWithLiteralMatch( 370 | ValueError, specs._DTYPE_OVERFLOW.format(np.dtype(dtype), num_values)): 371 | specs.DiscreteArray(num_values=num_values, dtype=dtype) 372 | 373 | def testRepr(self): 374 | as_string = repr(specs.DiscreteArray(num_values=5)) 375 | self.assertIn("num_values=5", as_string) 376 | 377 | def testProperties(self): 378 | num_values = 5 379 | spec = specs.DiscreteArray(num_values=5) 380 | self.assertEqual(spec.minimum, 0) 381 | self.assertEqual(spec.maximum, num_values - 1) 382 | self.assertEqual(spec.dtype, np.int32) 383 | self.assertEqual(spec.num_values, num_values) 384 | 385 | def testSerialization(self): 386 | desc = specs.DiscreteArray(2, np.int32, "test") 387 | self.assertEqual(pickle.loads(pickle.dumps(desc)), desc) 388 | 389 | @parameterized.parameters( 390 | {"arg_name": "num_values", "new_value": 4}, 391 | {"arg_name": "dtype", "new_value": np.int64}, 392 | {"arg_name": "name", "new_value": "something_else"}) 393 | def testReplace(self, arg_name, new_value): 394 | old_spec = specs.DiscreteArray(2, np.int32, "test") 395 | new_spec = old_spec.replace(**{arg_name: new_value}) 396 | self.assertIsNot(old_spec, new_spec) 397 | self.assertEqual(getattr(new_spec, arg_name), new_value) 398 | for attr_name in set( 399 | ["num_values", "dtype", "name"]).difference([arg_name]): 400 | self.assertEqual(getattr(new_spec, attr_name), 401 | getattr(old_spec, attr_name)) 402 | 403 | 404 | class StringArrayTest(parameterized.TestCase): 405 | 406 | @parameterized.parameters(int, bool) 407 | def testInvalidStringType(self, string_type): 408 | with 
self.assertRaisesWithLiteralMatch( 409 | ValueError, specs._INVALID_STRING_TYPE.format(string_type)): 410 | specs.StringArray(shape=(), string_type=string_type) 411 | 412 | @parameterized.parameters( 413 | dict(value=[u"foo", u"bar"], spec_string_type=str), 414 | dict(value=(u"foo", u"bar"), spec_string_type=str), 415 | dict(value=np.array([u"foo", u"bar"]), spec_string_type=str), 416 | dict(value=[b"foo", b"bar"], spec_string_type=bytes), 417 | dict(value=(b"foo", b"bar"), spec_string_type=bytes), 418 | dict(value=np.array([b"foo", b"bar"]), spec_string_type=bytes), 419 | ) 420 | def testValidateCorrectInput(self, value, spec_string_type): 421 | spec = specs.StringArray(shape=(2,), string_type=spec_string_type) 422 | validated = spec.validate(value) 423 | self.assertIsInstance(validated, np.ndarray) 424 | 425 | @parameterized.parameters( 426 | dict(value=np.array(u"foo"), spec_shape=(1,)), 427 | dict(value=np.array([u"foo"]), spec_shape=()), 428 | dict(value=np.array([u"foo", u"bar", u"baz"]), spec_shape=(2,)), 429 | ) 430 | def testInvalidShape(self, value, spec_shape): 431 | spec = specs.StringArray(shape=spec_shape, string_type=str) 432 | with self.assertRaisesWithLiteralMatch( 433 | ValueError, 434 | specs._INVALID_SHAPE % (spec_shape, value.shape)): 435 | spec.validate(value) 436 | 437 | @parameterized.parameters( 438 | dict(bad_element=42, spec_string_type=str), 439 | dict(bad_element=False, spec_string_type=str), 440 | dict(bad_element=[u"foo"], spec_string_type=str), 441 | dict(bad_element=b"foo", spec_string_type=str), 442 | dict(bad_element=u"foo", spec_string_type=bytes), 443 | ) 444 | def testInvalidItemType(self, bad_element, spec_string_type): 445 | spec = specs.StringArray(shape=(3,), string_type=spec_string_type) 446 | good_element = spec_string_type() 447 | value = [good_element, bad_element, good_element] 448 | message = specs._INVALID_ELEMENT_TYPE % ( 449 | spec_string_type, bad_element, type(bad_element)) 450 | with 
self.assertRaisesWithLiteralMatch(ValueError, message): 451 | spec.validate(value) 452 | 453 | @parameterized.parameters( 454 | dict( 455 | shape=(), 456 | string_type=str, 457 | expected=np.array(u"", dtype=object)), 458 | dict( 459 | shape=(1, 2), 460 | string_type=bytes, 461 | expected=np.array([[b"", b""]], dtype=object)), 462 | ) 463 | def testGenerateValue(self, shape, string_type, expected): 464 | spec = specs.StringArray(shape=shape, string_type=string_type) 465 | value = spec.generate_value() 466 | spec.validate(value) # Should be valid. 467 | np.testing.assert_array_equal(expected, value) 468 | 469 | @parameterized.parameters( 470 | dict(shape=(), string_type=str, name=None), 471 | dict(shape=(2, 3), string_type=bytes, name="foobar"), 472 | ) 473 | def testRepr(self, shape, string_type, name): 474 | spec = specs.StringArray(shape=shape, string_type=string_type, name=name) 475 | spec_repr = repr(spec) 476 | self.assertIn("StringArray", spec_repr) 477 | self.assertIn("shape={}".format(shape), spec_repr) 478 | self.assertIn("string_type={}".format(string_type), spec_repr) 479 | self.assertIn("name={}".format(name), spec_repr) 480 | 481 | @parameterized.parameters( 482 | dict(shape=(), string_type=str, name=None), 483 | dict(shape=(2, 3), string_type=bytes, name="foobar"), 484 | ) 485 | def testSerialization(self, shape, string_type, name): 486 | spec = specs.StringArray(shape=shape, string_type=string_type, name=name) 487 | self.assertEqual(pickle.loads(pickle.dumps(spec)), spec) 488 | 489 | if __name__ == "__main__": 490 | absltest.main() 491 | -------------------------------------------------------------------------------- /dm_env/test_utils.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 The dm_env Authors. All Rights Reserved. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """Reusable fixtures for testing implementations of `dm_env.Environment`. 17 | 18 | This generally kicks the tyres on an environment, and checks that it complies 19 | with the interface contract for `dm_env.Environment`. 20 | 21 | To test your own environment, all that's required is to inherit from 22 | `EnvironmentTestMixin` and `absltest.TestCase` (in this order), overriding 23 | `make_object_under_test`: 24 | 25 | ```python 26 | from absl.testing import absltest 27 | from dm_env import test_utils 28 | 29 | class MyEnvImplementationTest(test_utils.EnvironmentTestMixin, 30 | absltest.TestCase): 31 | 32 | def make_object_under_test(self): 33 | return my_env.MyEnvImplementation() 34 | ``` 35 | 36 | We recommend that you also override `make_action_sequence` in order to generate 37 | a sequence of actions that covers any 'interesting' behaviour in your 38 | environment. For episodic environments in particular, we recommend returning an 39 | action sequence that allows the environment to reach the end of an episode, 40 | otherwise the contract around end-of-episode behaviour will not be checked. The 41 | default implementation of `make_action_sequence` simply generates a dummy action 42 | conforming to the `action_spec` and repeats it 20 times. 
43 | 44 | You can also add your own tests alongside the defaults if you want to test some 45 | behaviour that's specific to your environment. There are some assertions and 46 | helpers here which may be useful to you in writing these tests. 47 | 48 | Note that we disable the pytype: attribute-error static check for the mixin as 49 | absltest.TestCase methods aren't statically available here, only once mixed in. 50 | """ 51 | 52 | from absl import logging 53 | import dm_env 54 | import tree 55 | from dm_env import _abstract_test_mixin 56 | _STEP_NEW_ENV_MUST_RETURN_FIRST = ( 57 | "calling step() on a fresh environment must produce a step with " 58 | "step_type FIRST, got {}") 59 | _RESET_MUST_RETURN_FIRST = ( 60 | "reset() must produce a step with step_type FIRST, got {}.") 61 | _FIRST_MUST_NOT_HAVE_REWARD = "a FIRST step must not have a reward." 62 | _FIRST_MUST_NOT_HAVE_DISCOUNT = "a FIRST step must not have a discount." 63 | _STEP_AFTER_FIRST_MUST_NOT_RETURN_FIRST = ( 64 | "calling step() after a FIRST step must not produce another FIRST.") 65 | _FIRST_MUST_COME_AFTER_LAST = ( 66 | "step() must produce a FIRST step after a LAST step.") 67 | _FIRST_MUST_ONLY_COME_AFTER_LAST = ( 68 | "step() must only produce a FIRST step after a LAST step " 69 | "or on a fresh environment.") 70 | 71 | 72 | class EnvironmentTestMixin(_abstract_test_mixin.TestMixin): 73 | """Mixin to help test implementations of `dm_env.Environment`. 74 | 75 | Subclasses must override `make_object_under_test` to return an instance of the 76 | `Environment` to be tested. 77 | """ 78 | 79 | @property 80 | def environment(self): 81 | """An alias of `self.object_under_test`, for readability.""" 82 | return self.object_under_test 83 | 84 | def tearDown(self): 85 | self.environment.close() 86 | # A call to super is required for cooperative multiple inheritance to work. 
87 | super().tearDown() # pytype: disable=attribute-error 88 | 89 | def make_action_sequence(self): 90 | """Generates a sequence of actions for a longer test. 91 | 92 | Yields: 93 | A sequence of actions compatible with environment's action_spec(). 94 | 95 | Ideally you should override this to generate an action sequence that will 96 | trigger an end of episode, in order to ensure this behaviour is tested. 97 | Otherwise it will just repeat a test value conforming to the action spec 98 | 20 times. 99 | """ 100 | for _ in range(20): 101 | yield self.make_action() 102 | 103 | def make_action(self): 104 | """Returns a single action conforming to the environment's action_spec().""" 105 | spec = self.environment.action_spec() 106 | return tree.map_structure(lambda s: s.generate_value(), spec) 107 | 108 | def reset_environment(self): 109 | """Resets the environment and checks that the returned TimeStep is valid. 110 | 111 | Returns: 112 | The TimeStep instance returned by reset(). 113 | """ 114 | step = self.environment.reset() 115 | self.assertValidStep(step) 116 | self.assertIs(dm_env.StepType.FIRST, step.step_type, # pytype: disable=attribute-error 117 | _RESET_MUST_RETURN_FIRST.format(step.step_type)) 118 | return step 119 | 120 | def step_environment(self, action=None): 121 | """Steps the environment and checks that the returned TimeStep is valid. 122 | 123 | Args: 124 | action: Optional action conforming to the environment's action_spec(). If 125 | None then a valid action will be generated. 126 | 127 | Returns: 128 | The TimeStep instance returned by step(action). 129 | """ 130 | if action is None: 131 | action = self.make_action() 132 | step = self.environment.step(action) 133 | self.assertValidStep(step) 134 | return step 135 | 136 | # Custom assertions 137 | # ---------------------------------------------------------------------------- 138 | 139 | def assertValidStep(self, step): 140 | """Checks that a TimeStep conforms to the environment's specs. 
141 | 142 | Args: 143 | step: An instance of TimeStep. 144 | """ 145 | # pytype: disable=attribute-error 146 | self.assertIsInstance(step, dm_env.TimeStep) 147 | self.assertIsInstance(step.step_type, dm_env.StepType) 148 | if step.step_type is dm_env.StepType.FIRST: 149 | self.assertIsNone(step.reward, _FIRST_MUST_NOT_HAVE_REWARD) 150 | self.assertIsNone(step.discount, _FIRST_MUST_NOT_HAVE_DISCOUNT) 151 | else: 152 | self.assertValidReward(step.reward) 153 | self.assertValidDiscount(step.discount) 154 | self.assertValidObservation(step.observation) 155 | # pytype: enable=attribute-error 156 | 157 | def assertConformsToSpec(self, value, spec): 158 | """Checks that `value` conforms to `spec`. 159 | 160 | Args: 161 | value: A potentially nested structure of numpy arrays or scalars. 162 | spec: A potentially nested structure of `specs.Array` instances. 163 | """ 164 | try: 165 | tree.assert_same_structure(value, spec) 166 | except (TypeError, ValueError) as e: 167 | self.fail("`spec` and `value` have mismatching structures: {}".format(e)) # pytype: disable=attribute-error 168 | def validate(path, item, array_spec): 169 | try: 170 | return array_spec.validate(item) 171 | except ValueError as e: 172 | raise ValueError("Value at path {!r} failed validation: {}." 
173 | .format("/".join(map(str, path)), e)) 174 | tree.map_structure_with_path(validate, value, spec) 175 | 176 | def assertValidObservation(self, observation): 177 | """Checks that `observation` conforms to the `observation_spec()`.""" 178 | self.assertConformsToSpec(observation, self.environment.observation_spec()) 179 | 180 | def assertValidReward(self, reward): 181 | """Checks that `reward` conforms to the `reward_spec()`.""" 182 | self.assertConformsToSpec(reward, self.environment.reward_spec()) 183 | 184 | def assertValidDiscount(self, discount): 185 | """Checks that `discount` conforms to the `discount_spec()`.""" 186 | self.assertConformsToSpec(discount, self.environment.discount_spec()) 187 | 188 | # Test cases 189 | # ---------------------------------------------------------------------------- 190 | 191 | def test_reset(self): 192 | # Won't hurt to check this works twice in a row: 193 | for _ in range(2): 194 | self.reset_environment() 195 | 196 | def test_step_on_fresh_environment(self): 197 | # Calling `step()` on a fresh environment should be equivalent to `reset()`. 198 | # Note that the action should be ignored. 199 | step = self.step_environment() 200 | self.assertIs(dm_env.StepType.FIRST, step.step_type, # pytype: disable=attribute-error 201 | _STEP_NEW_ENV_MUST_RETURN_FIRST.format(step.step_type)) 202 | step = self.step_environment() 203 | self.assertIsNot(dm_env.StepType.FIRST, step.step_type, # pytype: disable=attribute-error 204 | _STEP_AFTER_FIRST_MUST_NOT_RETURN_FIRST) 205 | 206 | def test_step_after_reset(self): 207 | for _ in range(2): 208 | self.reset_environment() 209 | step = self.step_environment() 210 | self.assertIsNot(dm_env.StepType.FIRST, step.step_type, # pytype: disable=attribute-error 211 | _STEP_AFTER_FIRST_MUST_NOT_RETURN_FIRST) 212 | 213 | def test_longer_action_sequence(self): 214 | """Steps the environment using actions generated by `make_action_sequence`. 
215 | 216 | The sequence of TimeSteps returned are checked for validity. 217 | """ 218 | encountered_last_step = False 219 | for _ in range(2): 220 | step = self.reset_environment() 221 | prev_step_type = step.step_type 222 | for action in self.make_action_sequence(): 223 | step = self.step_environment(action) 224 | if prev_step_type is dm_env.StepType.LAST: 225 | self.assertIs(dm_env.StepType.FIRST, step.step_type, # pytype: disable=attribute-error 226 | _FIRST_MUST_COME_AFTER_LAST) 227 | else: 228 | self.assertIsNot(dm_env.StepType.FIRST, step.step_type, # pytype: disable=attribute-error 229 | _FIRST_MUST_ONLY_COME_AFTER_LAST) 230 | if step.last(): 231 | encountered_last_step = True 232 | prev_step_type = step.step_type 233 | if not encountered_last_step: 234 | logging.info( 235 | "Could not test the contract around end-of-episode behaviour. " 236 | "Consider implementing `make_action_sequence` so that an end of " 237 | "episode is reached.") 238 | else: 239 | logging.info("Successfully checked end of episode.") 240 | -------------------------------------------------------------------------------- /dm_env/test_utils_test.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 The dm_env Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================ 16 | """Tests for dm_env.test_utils.""" 17 | 18 | import itertools 19 | 20 | from absl.testing import absltest 21 | import dm_env 22 | from dm_env import specs 23 | from dm_env import test_utils 24 | import numpy as np 25 | 26 | REWARD_SPEC = specs.Array(shape=(), dtype=float) 27 | DISCOUNT_SPEC = specs.BoundedArray(shape=(), dtype=float, minimum=0, maximum=1) 28 | OBSERVATION_SPEC = specs.Array(shape=(2, 3), dtype=float) 29 | ACTION_SPEC = specs.BoundedArray(shape=(), dtype=int, minimum=0, maximum=2) 30 | 31 | REWARD = REWARD_SPEC.generate_value() 32 | DISCOUNT = DISCOUNT_SPEC.generate_value() 33 | OBSERVATION = OBSERVATION_SPEC.generate_value() 34 | 35 | FIRST = dm_env.restart(observation=OBSERVATION) 36 | MID = dm_env.transition( 37 | reward=REWARD, observation=OBSERVATION, discount=DISCOUNT) 38 | LAST = dm_env.truncation( 39 | reward=REWARD, observation=OBSERVATION, discount=DISCOUNT) 40 | 41 | 42 | class MockEnvironment(dm_env.Environment): 43 | 44 | def __init__(self, timesteps): 45 | self._timesteps = timesteps 46 | self._iter_timesteps = itertools.cycle(self._timesteps) 47 | 48 | def reset(self): 49 | self._iter_timesteps = itertools.cycle(self._timesteps) 50 | return next(self._iter_timesteps) 51 | 52 | def step(self, action): 53 | return next(self._iter_timesteps) 54 | 55 | def reward_spec(self): 56 | return REWARD_SPEC 57 | 58 | def discount_spec(self): 59 | return DISCOUNT_SPEC 60 | 61 | def action_spec(self): 62 | return ACTION_SPEC 63 | 64 | def observation_spec(self): 65 | return OBSERVATION_SPEC 66 | 67 | 68 | def _make_test_case_with_expected_failures( 69 | name, 70 | timestep_sequence, 71 | expected_failures): 72 | 73 | class NewTestCase(test_utils.EnvironmentTestMixin, absltest.TestCase): 74 | 75 | def make_object_under_test(self): 76 | return MockEnvironment(timestep_sequence) 77 | 78 | for method_name, exception_type in expected_failures: 79 | def 
wrapped_method( 80 | self, method_name=method_name, exception_type=exception_type): 81 | super_method = getattr(super(NewTestCase, self), method_name) 82 | with self.assertRaises(exception_type): 83 | return super_method() 84 | setattr(NewTestCase, method_name, wrapped_method) 85 | 86 | NewTestCase.__name__ = name 87 | return NewTestCase 88 | 89 | 90 | TestValidTimestepSequence = _make_test_case_with_expected_failures( 91 | name='TestValidTimestepSequence', 92 | timestep_sequence=[FIRST, MID, MID, LAST], 93 | expected_failures=[], 94 | ) 95 | 96 | 97 | # Sequences where the ordering of StepTypes is invalid. 98 | 99 | 100 | TestTwoFirstStepsInARow = _make_test_case_with_expected_failures( 101 | name='TestTwoFirstStepsInARow', 102 | timestep_sequence=[FIRST, FIRST, MID, MID, LAST], 103 | expected_failures=[ 104 | ('test_longer_action_sequence', AssertionError), 105 | ('test_step_after_reset', AssertionError), 106 | ('test_step_on_fresh_environment', AssertionError), 107 | ], 108 | ) 109 | 110 | 111 | TestStartsWithMid = _make_test_case_with_expected_failures( 112 | name='TestStartsWithMid', 113 | timestep_sequence=[MID, MID, LAST], 114 | expected_failures=[ 115 | ('test_longer_action_sequence', AssertionError), 116 | ('test_reset', AssertionError), 117 | ('test_step_after_reset', AssertionError), 118 | ('test_step_on_fresh_environment', AssertionError), 119 | ], 120 | ) 121 | 122 | TestMidAfterLast = _make_test_case_with_expected_failures( 123 | name='TestMidAfterLast', 124 | timestep_sequence=[FIRST, MID, LAST, MID], 125 | expected_failures=[ 126 | ('test_longer_action_sequence', AssertionError), 127 | ], 128 | ) 129 | 130 | TestFirstAfterMid = _make_test_case_with_expected_failures( 131 | name='TestFirstAfterMid', 132 | timestep_sequence=[FIRST, MID, FIRST], 133 | expected_failures=[ 134 | ('test_longer_action_sequence', AssertionError), 135 | ], 136 | ) 137 | 138 | # Sequences where one or more TimeSteps have invalid contents. 
139 | 140 | 141 | TestFirstStepHasReward = _make_test_case_with_expected_failures( 142 | name='TestFirstStepHasReward', 143 | timestep_sequence=[ 144 | FIRST._replace(reward=1.0), # Should be None. 145 | MID, 146 | MID, 147 | LAST, 148 | ], 149 | expected_failures=[ 150 | ('test_reset', AssertionError), 151 | ('test_step_after_reset', AssertionError), 152 | ('test_step_on_fresh_environment', AssertionError), 153 | ('test_longer_action_sequence', AssertionError), 154 | ] 155 | ) 156 | 157 | TestFirstStepHasDiscount = _make_test_case_with_expected_failures( 158 | name='TestFirstStepHasDiscount', 159 | timestep_sequence=[ 160 | FIRST._replace(discount=1.0), # Should be None. 161 | MID, 162 | MID, 163 | LAST, 164 | ], 165 | expected_failures=[ 166 | ('test_reset', AssertionError), 167 | ('test_step_after_reset', AssertionError), 168 | ('test_step_on_fresh_environment', AssertionError), 169 | ('test_longer_action_sequence', AssertionError), 170 | ] 171 | ) 172 | 173 | TestInvalidReward = _make_test_case_with_expected_failures( 174 | name='TestInvalidReward', 175 | timestep_sequence=[ 176 | FIRST, 177 | MID._replace(reward=False), # Should be a float. 178 | MID, 179 | LAST, 180 | ], 181 | expected_failures=[ 182 | ('test_step_after_reset', ValueError), 183 | ('test_step_on_fresh_environment', ValueError), 184 | ('test_longer_action_sequence', ValueError), 185 | ] 186 | ) 187 | 188 | TestInvalidDiscount = _make_test_case_with_expected_failures( 189 | name='TestInvalidDiscount', 190 | timestep_sequence=[ 191 | FIRST, 192 | MID._replace(discount=1.5), # Should be between 0 and 1. 
        MID,
        LAST,
    ],
    expected_failures=[
        ('test_step_after_reset', ValueError),
        ('test_step_on_fresh_environment', ValueError),
        ('test_longer_action_sequence', ValueError),
    ]
)

TestInvalidObservation = _make_test_case_with_expected_failures(
    name='TestInvalidObservation',
    timestep_sequence=[
        FIRST,
        MID._replace(observation=np.zeros((3, 4))),  # Wrong shape.
        MID,
        LAST,
    ],
    expected_failures=[
        ('test_step_after_reset', ValueError),
        ('test_step_on_fresh_environment', ValueError),
        ('test_longer_action_sequence', ValueError),
    ]
)

TestMismatchingObservationStructure = _make_test_case_with_expected_failures(
    name='TestMismatchingObservationStructure',
    timestep_sequence=[
        FIRST,
        MID._replace(observation=[OBSERVATION]),  # Wrong structure.
        MID,
        LAST,
    ],
    expected_failures=[
        ('test_step_after_reset', AssertionError),
        ('test_step_on_fresh_environment', AssertionError),
        ('test_longer_action_sequence', AssertionError),
    ]
)


if __name__ == '__main__':
  absltest.main()

--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
# Environment API and Semantics

This text describes the Python-based Environment API defined by `dm_env`.

## Overview

The main interaction with an environment is via the `step()` method.

Each call to an environment's `step()` method takes an `action` parameter
and returns a `TimeStep` namedtuple with fields

```none
step_type, reward, discount, observation
```

A **sequence** consists of a series of `TimeStep`s returned by consecutive calls
to `step()`. In many settings we refer to each sequence as an *episode*.
Each sequence starts with a `step_type` of `FIRST`, ends with a `step_type` of
`LAST`, and has a `step_type` of `MID` for all intermediate `TimeStep`s.

As well as `step()`, each environment implements a `reset()` method. This takes
no arguments, forces the start of a new sequence and returns the first
`TimeStep`. See the [run loop samples](#run-loop-samples) below for more
details.

Calling `step()` on a new environment instance, or immediately after a
`TimeStep` with a `step_type` of `LAST`, is equivalent to calling `reset()`. In
other words, the `action` argument will be ignored and a new sequence will
begin, starting with a `step_type` of `FIRST`.

NOTE: The `discount` does *not* determine when a sequence ends. The `discount`
may be 0 in the middle of a sequence and ≥0 at the end of a sequence.

### Example sequences

We show two examples of sequences below, along with the first `TimeStep` of the
next sequence.

Each row corresponds to the tuple returned by an environment's `step()` method.
We use `r`, `ɣ` and `obs` to denote the reward, discount and observation
respectively, `x` to denote a `None` or optional value at a timestep, and `✓`
to denote a value that exists at a timestep.

Example: A sequence where the end of the *prediction* (the discounted sum of
future rewards that we wish to predict) coincides with the end of the sequence,
i.e. the sequence ends with a discount of 0. Such a sequence could represent a
single episode of a *finite-horizon* RL task.

```none
(r, ɣ, obs) |  (x, x, ✓) → (✓, ✓, ✓) → (✓, 0, ✓)  ⇢  (x, x, ✓)
step_type   |    FIRST        MID        LAST          FIRST
```

Example: Here the prediction does not terminate at the end of the sequence,
which ends with a nonzero discount. This type of termination is sometimes used
in *infinite-horizon* RL settings.
57 | 58 | ```none 59 | (r, ɣ, obs) | (x, x, ✓) → (✓, ✓, ✓) → (✓, > 0, ✓) ⇢ (x, x, ✓) 60 | step_type | FIRST MID LAST FIRST 61 | ``` 62 | 63 | In general, a discount of `0` does not need to coincide with the end of a 64 | sequence. An environment may return `ɣ = 0` in the middle of a sequence, and 65 | may do this multiple times within a sequence. We do not (typically) call these 66 | sub-sequences episodes. 67 | 68 | The `step_type` can potentially be used by an agent. For instance, some agents 69 | may reset their short-term memory when `step_type` is `LAST`, but not when the 70 | `step_type` is `MID`, even if the discount is `0`. This is up to the 71 | creator of the agent, but it does mean that the aforementioned two ways to 72 | model a termination of the prediction do not necessarily correspond to the same 73 | agent behaviour. 74 | 75 | ## Run loop samples 76 | 77 | Here we show some sample run loops for using an environment with an agent class 78 | that implements a `step(timestep)` method. 79 | 80 | NOTE: Environments do not make any assumptions about the structure of 81 | algorithmic code or agent classes. These examples are illustrative only. 82 | 83 | ### Continuing 84 | 85 | We may call `step()` repeatedly. 86 | 87 | ```python 88 | timestep = env.reset() 89 | while True: 90 | action = agent.step(timestep) 91 | timestep = env.step(action) 92 | 93 | ``` 94 | NOTE: An environment will ignore `action` after a `LAST` step, and return the 95 | `FIRST` step of a new sequence. An agent or algorithm may use the `step_type`, 96 | for example to decide when to reset short-term memory. 97 | 98 | ### Set number of sequences 99 | 100 | We can choose to run a specific number of sequences. Here we use the syntactic 101 | sugar method `.last()` to check whether we are at the end of a sequence. 
102 | 103 | ```python 104 | for _ in range(num_sequences): 105 | 106 | timestep = env.reset() 107 | while True: 108 | action = agent.step(timestep) 109 | timestep = env.step(action) 110 | if timestep.last(): 111 | _ = agent.step(timestep) 112 | break 113 | ``` 114 | 115 | A `TimeStep` also has `.first()` and `.mid()` methods. 116 | 117 | ### Manual truncation 118 | 119 | We can truncate a sequence manually at some `step_limit`. 120 | 121 | ```python 122 | step_limit = 100 123 | for _ in range(num_sequences): 124 | 125 | timestep = env.reset() 126 | 127 | step_counter = 1 128 | while True: 129 | action = agent.step(timestep) 130 | timestep = env.step(action) 131 | if step_counter == step_limit: 132 | timestep = timestep._replace(step_type=dm_env.StepType.LAST) 133 | 134 | if timestep.last(): 135 | _ = agent.step(timestep) 136 | break 137 | 138 | step_counter += 1 139 | ``` 140 | 141 | In this example we've set the `step_type` field directly, using `_replace` 142 | because a `TimeStep` is a namedtuple. 143 | ## The format of observations and actions 144 | 145 | Environments should return observations and accept actions in the form of 146 | [NumPy arrays][numpy_array]. 147 | 148 | An environment may return observations made up of multiple arrays, for example a 149 | list where the first item is an array containing an RGB image and the second 150 | item is an array containing velocities. The arrays may also be values in a 151 | `dict`, or any other structure made up of basic Python containers. Note: A 152 | single array is a perfectly valid format. 153 | 154 | Similarly, actions may be specified as multiple arrays, for example control 155 | signals for distinct parts of a simulated robot. 156 | 157 | Each environment also implements an `observation_spec()` and an `action_spec()` 158 | method. Each method should return a structure of [`Array` specs][specs], 159 | where the structure should correspond exactly to the 160 | format of the actions/observations.
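
To make "correspond exactly" concrete, here is a minimal sketch of a spec structure that mirrors a nested observation, with a check that walks both in parallel. It needs only NumPy. `ArraySpec` and `conforms` are hypothetical names invented for this illustration; they stand in for the library's `Array` spec and its validation and are not part of `dm_env`.

```python
from typing import Any, NamedTuple, Tuple

import numpy as np


class ArraySpec(NamedTuple):
  """Hypothetical stand-in for an `Array` spec: shape, dtype and name only."""
  shape: Tuple[int, ...]
  dtype: Any
  name: str


def conforms(spec: Any, value: Any) -> bool:
  """Returns True if `value` has the same nesting, shapes and dtypes as `spec`."""
  # Leaf: the value must be an array with the declared shape and dtype.
  if isinstance(spec, ArraySpec):
    return (isinstance(value, np.ndarray)
            and value.shape == spec.shape
            and value.dtype == np.dtype(spec.dtype))
  # Dict: same keys, and every entry must conform.
  if isinstance(spec, dict):
    return (isinstance(value, dict)
            and spec.keys() == value.keys()
            and all(conforms(spec[k], value[k]) for k in spec))
  # List or tuple: same container type and length, element-wise conformance.
  if isinstance(spec, (list, tuple)):
    return (isinstance(value, type(spec))
            and len(spec) == len(value)
            and all(conforms(s, v) for s, v in zip(spec, value)))
  return False


# The spec is a dict of specs because the observation is a dict of arrays.
observation_spec = {
    'camera': ArraySpec(shape=(64, 64, 3), dtype=np.uint8, name='camera'),
    'velocity': ArraySpec(shape=(3,), dtype=np.float32, name='velocity'),
}
observation = {
    'camera': np.zeros((64, 64, 3), dtype=np.uint8),
    'velocity': np.zeros((3,), dtype=np.float32),
}

matches = conforms(observation_spec, observation)
# Wrapping the same observation in a list changes its structure.
wrapped_matches = conforms(observation_spec, [observation])
```

In this sketch `matches` is `True` and `wrapped_matches` is `False`: the arrays are unchanged, but the list wrapper no longer mirrors the spec.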
161 | 162 | Each `Array` spec should define the `dtype`, `shape` and, where possible, the 163 | bounds and name of the corresponding action or observation array. 164 | 165 | Note: Actions should almost always specify bounds, e.g. they should use the 166 | [`BoundedArray` spec][specs] subclass. 167 | 168 | [numpy_array]: https://docs.scipy.org/doc/numpy/reference/generated/numpy.array.html 169 | [specs]: ../dm_env/specs.py 170 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The dm_env Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | -------------------------------------------------------------------------------- /examples/catch.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 The dm_env Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """Catch reinforcement learning environment.""" 17 | 18 | import dm_env 19 | from dm_env import specs 20 | import numpy as np 21 | 22 | _ACTIONS = (-1, 0, 1) # Left, no-op, right. 23 | 24 | 25 | class Catch(dm_env.Environment): 26 | """A Catch environment built on the `dm_env.Environment` class. 27 | 28 | The agent must move a paddle to intercept falling balls. Falling balls only 29 | move downwards on the column they are in. 30 | 31 | The observation is an array of shape (rows, columns), with binary values: 32 | 0 if a space is empty; 1 if it contains the paddle or a ball. 33 | 34 | The actions are discrete, and there are three available, in index order: 35 | move left, stay, and move right. 36 | 37 | The episode terminates when the ball reaches the bottom of the screen. 38 | """ 39 | 40 | def __init__(self, rows: int = 10, columns: int = 5, seed: int = 1): 41 | """Initializes a new Catch environment. 42 | 43 | Args: 44 | rows: number of rows. 45 | columns: number of columns. 46 | seed: random seed for the RNG.
47 | """ 48 | self._rows = rows 49 | self._columns = columns 50 | self._rng = np.random.RandomState(seed) 51 | self._board = np.zeros((rows, columns), dtype=np.float32) 52 | self._ball_x = None 53 | self._ball_y = None 54 | self._paddle_x = None 55 | self._paddle_y = self._rows - 1 56 | self._reset_next_step = True 57 | 58 | def reset(self) -> dm_env.TimeStep: 59 | """Returns the first `TimeStep` of a new episode.""" 60 | self._reset_next_step = False 61 | self._ball_x = self._rng.randint(self._columns) 62 | self._ball_y = 0 63 | self._paddle_x = self._columns // 2 64 | return dm_env.restart(self._observation()) 65 | 66 | def step(self, action: int) -> dm_env.TimeStep: 67 | """Updates the environment according to the action.""" 68 | if self._reset_next_step: 69 | return self.reset() 70 | 71 | # Move the paddle. 72 | dx = _ACTIONS[action] 73 | self._paddle_x = np.clip(self._paddle_x + dx, 0, self._columns - 1) 74 | 75 | # Drop the ball. 76 | self._ball_y += 1 77 | 78 | # Check for termination. 79 | if self._ball_y == self._paddle_y: 80 | reward = 1. if self._paddle_x == self._ball_x else -1. 81 | self._reset_next_step = True 82 | return dm_env.termination(reward=reward, observation=self._observation()) 83 | else: 84 | return dm_env.transition(reward=0., observation=self._observation()) 85 | 86 | def observation_spec(self) -> specs.BoundedArray: 87 | """Returns the observation spec.""" 88 | return specs.BoundedArray( 89 | shape=self._board.shape, 90 | dtype=self._board.dtype, 91 | name="board", 92 | minimum=0, 93 | maximum=1, 94 | ) 95 | 96 | def action_spec(self) -> specs.DiscreteArray: 97 | """Returns the action spec.""" 98 | return specs.DiscreteArray( 99 | dtype=int, num_values=len(_ACTIONS), name="action") 100 | 101 | def _observation(self) -> np.ndarray: 102 | self._board.fill(0.) 103 | self._board[self._ball_y, self._ball_x] = 1. 104 | self._board[self._paddle_y, self._paddle_x] = 1. 
105 | return self._board.copy() 106 | -------------------------------------------------------------------------------- /examples/catch_test.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 The dm_env Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """Tests for dm_env.examples.catch.""" 17 | 18 | from absl.testing import absltest 19 | from dm_env import test_utils 20 | from examples import catch 21 | 22 | 23 | class CatchTest(test_utils.EnvironmentTestMixin, absltest.TestCase): 24 | 25 | def make_object_under_test(self): 26 | return catch.Catch() 27 | 28 | 29 | if __name__ == '__main__': 30 | absltest.main() 31 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | dm-tree==0.1.6 3 | numpy==1.21.6; python_version < '3.8' 4 | numpy==1.22.0; python_version >= '3.8' 5 | pytest==6.2.5 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 The dm_env Authors. All Rights Reserved. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | 17 | """Install script for setuptools.""" 18 | 19 | from importlib import util 20 | from setuptools import find_packages 21 | from setuptools import setup 22 | 23 | 24 | def get_version(): 25 | spec = util.spec_from_file_location('_metadata', 'dm_env/_metadata.py') 26 | mod = util.module_from_spec(spec) 27 | spec.loader.exec_module(mod) 28 | return mod.__version__ 29 | 30 | 31 | setup( 32 | name='dm-env', 33 | version=get_version(), 34 | description='A Python interface for Reinforcement Learning environments.', 35 | author='DeepMind', 36 | license='Apache License, Version 2.0', 37 | keywords='reinforcement-learning python machine learning', 38 | packages=find_packages(exclude=['examples']), 39 | python_requires='>=3.7', 40 | install_requires=[ 41 | 'absl-py', 42 | 'dm-tree', 43 | 'numpy', 44 | ], 45 | tests_require=[ 46 | 'pytest', 47 | ], 48 | classifiers=[ 49 | 'Development Status :: 5 - Production/Stable', 50 | 'Environment :: Console', 51 | 'Intended Audience :: Science/Research', 52 | 'License :: OSI Approved :: Apache Software License', 53 | 'Operating System :: POSIX :: Linux', 54 | 'Operating System :: Microsoft :: Windows', 55 | 'Operating System :: MacOS :: MacOS X', 56 | 'Programming Language :: Python :: 3.7', 57 | 'Programming Language :: Python :: 3.8', 58 | 'Programming Language 
:: Python :: 3.9', 59 | 'Programming Language :: Python :: 3.10', 60 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 61 | ], 62 | ) 63 | --------------------------------------------------------------------------------