├── .github └── workflows │ └── build-and-test.yml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── csuite ├── __init__.py ├── csuite_test.py ├── environments │ ├── access_control.py │ ├── access_control_test.py │ ├── base.py │ ├── catch.py │ ├── catch_test.py │ ├── common.py │ ├── dancing_catch.py │ ├── dancing_catch_test.py │ ├── experimental │ │ ├── pendulum_poke.py │ │ └── pendulum_poke_test.py │ ├── pendulum.py │ ├── pendulum_test.py │ ├── taxi.py │ ├── taxi_test.py │ ├── windy_catch.py │ └── windy_catch_test.py └── utils │ ├── dm_env_wrapper.py │ ├── dm_env_wrapper_test.py │ ├── gym_wrapper.py │ └── gym_wrapper_test.py ├── docs ├── Makefile ├── _static │ └── img │ │ ├── taxi_grid.png │ │ └── taxi_pickup.png ├── api.md ├── conf.py ├── environments │ ├── access_control.md │ ├── catch.md │ ├── pendulum.md │ └── taxi.md ├── index.rst └── requirements.txt ├── readthedocs.yml ├── requirements-test.txt ├── requirements.txt └── setup.py /.github/workflows/build-and-test.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Build and Test 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | timeout-minutes: 10 17 | strategy: 18 | fail-fast: false 19 | matrix: 20 | python-version: ["3.9", "3.10"] 21 | 22 | steps: 23 | - uses: actions/checkout@v3 24 | - name: Set up Python ${{ matrix.python-version }} 25 | uses: actions/setup-python@v3 26 | with: 27 | python-version: ${{ matrix.python-version }} 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | python -m pip install flake8 pytest 32 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 33 | - name: Lint with flake8 34 | run: | 35 | # stop the build if there are Python syntax errors or undefined names 36 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 37 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 38 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 39 | - name: Test with pytest 40 | run: | 41 | pytest 42 | 43 | concurrency: 44 | group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} 45 | cancel-in-progress: true 46 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution, 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 
17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows [Google's Open Source Community 28 | Guidelines](https://opensource.google/conduct/). -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Continuing Environments for Reinforcement Learning (`csuite`) 2 | 3 | 4 | CSuite is a collection of carefully-curated synthetic environments for 5 | research in the continuing setting: the agent-environment interaction goes on 6 | forever without limit, with no natural episode boundaries. 
7 | 
8 | ## Installation
9 | 
10 | Clone the source code into a local directory and install using pip:
11 | 
12 | ```sh
13 | git clone https://github.com/deepmind/csuite.git /path/to/local/csuite/
14 | pip install /path/to/local/csuite/
15 | ```
16 | 
17 | `csuite` is not yet available from PyPI.
18 | 
19 | ## Environment Interface
20 | 
21 | CSuite environments adhere to the Python interface defined in `csuite/environments/base.py`.
22 | Find the interface [documentation here](https://rl-csuite.readthedocs.io/en/latest/api.html).
23 | 
24 | ```python
25 | import csuite
26 | 
27 | env = csuite.load("catch")
28 | action_spec = env.action_spec()
29 | observation = env.start()
30 | print("First observation:\n", observation)
31 | 
32 | total_reward = 0
33 | for _ in range(100):
34 |   observation, reward = env.step(action_spec.generate_value())
35 |   total_reward += reward
36 | 
37 | print("Total reward:", total_reward)
38 | ```
39 | 
40 | ### Using `csuite` with the dm_env interface
41 | 
42 | For a codebase that uses the [`dm_env`](https://github.com/deepmind/dm_env) interface, use the `DMEnvFromCSuite` wrapper class:
43 | 
44 | ```python
45 | import csuite
46 | 
47 | env = csuite.dm_env_wrapper.DMEnvFromCSuite(csuite.load("catch"))
48 | action_spec = env.action_spec()
49 | 
50 | timestep = env.reset()
51 | print("First observation:\n", timestep.observation)
52 | 
53 | total_reward = 0
54 | for _ in range(100):
55 |   timestep = env.step(action_spec.generate_value())
56 |   total_reward += timestep.reward
57 | 
58 | print("Total reward:", total_reward)
59 | ```
60 | 
61 | 
--------------------------------------------------------------------------------
/csuite/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================== 15 | 16 | """Helper(s) to load csuite environments.""" 17 | 18 | import enum 19 | from typing import Dict, Optional, Union 20 | 21 | from csuite.environments import access_control 22 | from csuite.environments import catch 23 | from csuite.environments import dancing_catch 24 | from csuite.environments import pendulum 25 | from csuite.environments import taxi 26 | from csuite.environments import windy_catch 27 | from csuite.environments.base import Environment 28 | from csuite.environments.experimental import pendulum_poke 29 | from csuite.utils import dm_env_wrapper 30 | from csuite.utils import gym_wrapper 31 | 32 | 33 | class EnvName(enum.Enum): 34 | ACCESS_CONTROL = 'access_control' 35 | CATCH = 'catch' 36 | DANCING_CATCH = 'dancing_catch' 37 | PENDULUM = 'pendulum' 38 | PENDULUM_POKE = 'pendulum_poke' 39 | TAXI = 'taxi' 40 | WINDY_CATCH = 'windy_catch' 41 | 42 | 43 | _ENVS = { 44 | EnvName.ACCESS_CONTROL: access_control.AccessControl, 45 | EnvName.CATCH: catch.Catch, 46 | EnvName.DANCING_CATCH: dancing_catch.DancingCatch, 47 | EnvName.WINDY_CATCH: windy_catch.WindyCatch, 48 | EnvName.TAXI: taxi.Taxi, 49 | EnvName.PENDULUM: pendulum.Pendulum, 50 | EnvName.PENDULUM_POKE: pendulum_poke.PendulumPoke, 51 | } 52 | 53 | 54 | def load(name: Union[EnvName, str], 55 | settings: Optional[Dict[str, Union[float, int, bool]]] = None): 56 | """Loads a csuite environment. 57 | 58 | Args: 59 | name: The enum or string specifying the environment name. 60 | settings: Optional `dict` of keyword arguments for the environment. 61 | 62 | Returns: 63 | An instance of the requested environment. 64 | """ 65 | name = EnvName(name) 66 | settings = settings or {} 67 | return _ENVS[name](**settings) 68 | -------------------------------------------------------------------------------- /csuite/csuite_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Tests all environments through the csuite interface.""" 17 | 18 | import typing 19 | 20 | from absl.testing import absltest 21 | from absl.testing import parameterized 22 | 23 | import csuite 24 | from dm_env import specs 25 | import numpy as np 26 | 27 | 28 | class CSuiteTest(parameterized.TestCase): 29 | 30 | @parameterized.parameters([e.value for e in csuite.EnvName]) 31 | def test_envs(self, env_name): 32 | """Tests that we can use the environment in typical ways.""" 33 | env = csuite.load(env_name) 34 | action_spec = env.action_spec() 35 | observation_spec = env.observation_spec() 36 | 37 | obs = env.start() 38 | env.render() 39 | init_state = env.get_state() 40 | 41 | for i in range(2): 42 | with self.subTest(name="steps-render-successful", step=i): 43 | image = env.render() 44 | with self.subTest(name="steps-render-compliant", step=i): 45 | self._assert_image_compliant(image) 46 | with self.subTest(name="steps-observation_spec", step=i): 47 | observation_spec.validate(obs) 48 | with self.subTest(name="steps-step", step=i): 49 | obs, unused_reward = env.step(action_spec.generate_value()) 50 | 51 | env.set_state(init_state) 52 | 53 | @parameterized.parameters([e.value for e in csuite.EnvName]) 54 | def test_env_state_resets(self, env_name): 55 | """Tests that `get`ing and `set`ing state results in reproducibility.""" 56 | # Since each environment is different, we employ a generic strategy that 57 | # should 58 | # 59 | # a) get us to a variety of states to query the state on, 60 | # b) take a number of steps from that state to check reproducibility. 61 | # 62 | # See the implementation for the specific strategy taken. 63 | num_steps_to_check = 4 64 | env = csuite.load(env_name) 65 | env.start() 66 | action_spec = env.action_spec() 67 | 68 | if not isinstance(action_spec, specs.DiscreteArray): 69 | raise NotImplementedError( 70 | "This test only supports environments with discrete action " 71 | "spaces for now. Please raise an issue if you want to work with a " 72 | "a non-discrete action space.") 73 | 74 | action_spec = typing.cast(specs.DiscreteArray, action_spec) 75 | for action in range(action_spec.num_values): 76 | env.step(action) 77 | orig_state = env.get_state() 78 | outputs_1 = [env.step(action) for _ in range(num_steps_to_check)] 79 | observations_1, rewards_1 = zip(*outputs_1) 80 | env.set_state(orig_state) 81 | outputs_2 = [env.step(action) for _ in range(num_steps_to_check)] 82 | observations_2, rewards_2 = zip(*outputs_2) 83 | with self.subTest("observations", action=action): 84 | self.assertSameObsSequence(observations_1, observations_2) 85 | with self.subTest("rewards", action=action): 86 | self.assertSequenceEqual(rewards_1, rewards_2) 87 | 88 | def assertSameObsSequence(self, seq1, seq2): 89 | """The observations are expected to be numpy objects.""" 90 | # self.assertSameStructure(seq1, seq2) 91 | problems = [] # (idx, problem str) 92 | for idx, (el1, el2) in enumerate(zip(seq1, seq2)): 93 | try: 94 | np.testing.assert_array_equal(el1, el2) 95 | except AssertionError as e: 96 | problems.append((idx, str(e))) 97 | 98 | if problems: 99 | self.fail( 100 | f"The observation sequences (of length {len(seq1)}) are not the " 101 | "same. 
The differences are:\n" + 102 | "\n".join([f"at idx={idx}: {msg}" for idx, msg in problems])) 103 | 104 | def _assert_image_compliant(self, image: np.ndarray): 105 | if not (len(image.shape) == 3 and image.shape[-1] == 3 and 106 | image.dtype == np.uint8): 107 | self.fail( 108 | "The render() method is expected to return an uint8 rgb image array. " 109 | f"Got an array of shape {image.shape}, dtype {image.dtype}.") 110 | 111 | 112 | if __name__ == "__main__": 113 | absltest.main() 114 | -------------------------------------------------------------------------------- /csuite/environments/access_control.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Implementation of the tabular Access-Control environment. 17 | 18 | Environment description and details can be found in the `AccessControl` 19 | environment class. 20 | """ 21 | 22 | import copy 23 | import dataclasses 24 | import enum 25 | import itertools 26 | from typing import Optional 27 | 28 | from csuite.environments import base 29 | from csuite.environments import common 30 | from dm_env import specs 31 | 32 | import numpy as np 33 | 34 | # Default environment variables from Sutton&Barto Example 10.2. 35 | _NUM_SERVERS = 10 36 | _FREE_PROBABILITY = 0.06 37 | _PRIORITIES = (1, 2, 4, 8) 38 | 39 | # Error messages. 40 | _INVALID_ACTION = "Invalid action: expected 0 or 1 but received {action}." 41 | _INVALID_BUSY_SERVERS = ("Invalid state: num_busy_servers not in expected" 42 | "range [0, {}).") 43 | _INVALID_PRIORITY = ("Invalid state: incoming_priority not in expected" 44 | "range {}.") 45 | 46 | 47 | class Action(enum.IntEnum): 48 | REJECT = 0 49 | ACCEPT = 1 50 | 51 | 52 | @dataclasses.dataclass 53 | class Params: 54 | """Parameters of an Access-Control instance. 55 | 56 | Attributes: 57 | num_servers: A positive integer, denoting the total number of available 58 | servers. 59 | free_probability: A positive float, denoting the probability a busy server 60 | becomes free at each timestep. 61 | priorities: A list of floats, giving the possible priorities of incoming 62 | customers. 63 | """ 64 | num_servers: int 65 | free_probability: float 66 | priorities: list[float] 67 | 68 | 69 | @dataclasses.dataclass 70 | class State: 71 | """State of an Access-Control continuing environment. 72 | 73 | Let N be the number of servers. 74 | 75 | Attributes: 76 | num_busy_servers: An integer in the range [0, N] representing the number of 77 | busy servers. 78 | incoming_priority: An integer giving the priority of the incoming customer. 79 | rng: Internal NumPy pseudo-random number generator, included here for 80 | reproducibility purposes. 
81 | """ 82 | num_busy_servers: int 83 | incoming_priority: int 84 | rng: np.random.RandomState 85 | 86 | 87 | class AccessControl(base.Environment): 88 | """An Access-Control continuing environment. 89 | 90 | Given access to a set of servers and an infinite queue of customers 91 | with different priorities, the agent must decide whether to accept 92 | or reject the next customer in line based on their priority and the 93 | number of free servers. 94 | 95 | There are two actions: accept or decline the incoming customer. Note that 96 | if no servers are available, the customer is declined regardless of the 97 | action selected. 98 | 99 | The observation is a single state index, enumerating the possible states 100 | (num_busy_servers, incoming_priority). 101 | 102 | The default environment follows that described in Sutton and Barto's book 103 | (Example 10.2 in the second edition). 104 | """ 105 | 106 | def __init__(self, 107 | num_servers=_NUM_SERVERS, 108 | free_probability=_FREE_PROBABILITY, 109 | priorities=_PRIORITIES, 110 | seed=None): 111 | """Initialize Access-Control environment. 112 | 113 | Args: 114 | num_servers: A positive integer, denoting the total number of available 115 | servers. 116 | free_probability: A positive float, denoting the probability a busy server 117 | becomes free at each timestep. 118 | priorities: A list of floats, giving the possible priorities of incoming 119 | customers. 120 | seed: Seed for the internal random number generator. 121 | """ 122 | self._seed = seed 123 | self._params = Params( 124 | num_servers=num_servers, 125 | free_probability=free_probability, 126 | priorities=priorities) 127 | self.num_states = ((self._params.num_servers + 1) * 128 | len(self._params.priorities)) 129 | 130 | # Populate lookup table for observations. 131 | self.lookup_table = {} 132 | for idx, state in enumerate( 133 | itertools.product( 134 | range(self._params.num_servers + 1), self._params.priorities)): 135 | self.lookup_table[state] = idx 136 | 137 | self._state = None 138 | self._last_action = -1 # Only used for visualization. 139 | 140 | def start(self, seed: Optional[int] = None): 141 | """Initializes the environment and returns an initial observation.""" 142 | rng = np.random.RandomState(self._seed if seed is None else seed) 143 | self._state = State( 144 | num_busy_servers=0, 145 | incoming_priority=rng.choice(self._params.priorities), 146 | rng=rng) 147 | return self._get_observation() 148 | 149 | @property 150 | def started(self): 151 | """True if the environment has been started, False otherwise.""" 152 | # An unspecified state implies that the environment needs to be started. 153 | return self._state is not None 154 | 155 | def step(self, action): 156 | """Updates the environment state and returns an observation and reward. 157 | 158 | Args: 159 | action: An integer equalling 0 or 1 to reject or accept the customer. 160 | 161 | Returns: 162 | A tuple of type (int, float) giving the next observation and the reward. 163 | 164 | Raises: 165 | RuntimeError: If state has not yet been initialized by `start`. 166 | ValueError: If input action has an invalid value. 167 | """ 168 | # Check if state has been initialized. 169 | if not self.started: 170 | raise RuntimeError(base.STEP_WITHOUT_START_ERR) 171 | 172 | # Check if input action is valid. 
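    # `Action` is an `enum.IntEnum`, so plain ints compare equal to its
    # members: both `env.step(1)` and `env.step(Action.ACCEPT)` pass the
    # membership check below.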
173 | if action not in [Action.REJECT, Action.ACCEPT]: 174 | raise ValueError(_INVALID_ACTION.format(action=action)) 175 | 176 | self._last_action = action 177 | 178 | reward = 0 179 | # If customer is accepted, ensure there are enough free servers. 180 | if (action == Action.ACCEPT and 181 | self._state.num_busy_servers < self._params.num_servers): 182 | reward = self._state.incoming_priority 183 | self._state.num_busy_servers += 1 184 | 185 | new_priority = self._state.rng.choice(self._params.priorities) 186 | 187 | # Update internal state by freeing busy servers with a given probability. 188 | num_busy_servers = self._state.num_busy_servers 189 | num_new_free_servers = self._state.rng.binomial( 190 | num_busy_servers, p=self._params.free_probability) 191 | self._state.num_busy_servers = num_busy_servers - num_new_free_servers 192 | self._state.incoming_priority = new_priority 193 | 194 | return self._get_observation(), reward 195 | 196 | def _get_observation(self): 197 | """Converts internal state to an index uniquely identifying the state. 198 | 199 | Returns: 200 | An integer denoting the current state's index according to the 201 | enumeration of the state space stored by the environment's lookup table. 202 | """ 203 | state_key = (self._state.num_busy_servers, self._state.incoming_priority) 204 | return self.lookup_table[state_key] 205 | 206 | def observation_spec(self): 207 | """Describes the observation specs of the environment.""" 208 | return specs.DiscreteArray(self.num_states, dtype=int, name="observation") 209 | 210 | def action_spec(self): 211 | """Describes the action specs of the environment.""" 212 | return specs.DiscreteArray(2, dtype=int, name="action") 213 | 214 | def get_state(self): 215 | """Returns a copy of the current environment state.""" 216 | return copy.deepcopy(self._state) if self._state is not None else None 217 | 218 | def set_state(self, state): 219 | """Sets environment state to state provided. 220 | 221 | Args: 222 | state: A State object which overrides the current state. 223 | """ 224 | # Check that input state values are valid. 225 | if state.num_busy_servers not in range(self._params.num_servers + 1): 226 | raise ValueError(_INVALID_BUSY_SERVERS.format(self._params.num_servers)) 227 | elif state.incoming_priority not in self._params.priorities: 228 | raise ValueError(_INVALID_PRIORITY.format(self._params.priorities)) 229 | 230 | self._state = copy.deepcopy(state) 231 | 232 | def get_config(self): 233 | """Returns a copy of the environment configuration.""" 234 | return copy.deepcopy(self._params) 235 | 236 | def render(self): 237 | board = np.ones((len(_PRIORITIES), _NUM_SERVERS + 1), dtype=np.uint8) 238 | priority = _PRIORITIES.index(self._state.incoming_priority) 239 | busy_num = self._state.num_busy_servers 240 | board[priority, busy_num] = 0 241 | rgb_array = common.binary_board_to_rgb(board) 242 | 243 | if self._last_action == Action.ACCEPT: 244 | rgb_array[priority, busy_num, 1] = 1 # Green. 245 | elif self._last_action == Action.REJECT: 246 | rgb_array[priority, busy_num, 0] = 1 # Red. 247 | else: 248 | # Will remain black. 249 | assert self._last_action == -1, "Only other possible value." 250 | 251 | return rgb_array 252 | -------------------------------------------------------------------------------- /csuite/environments/access_control_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """Tests for access_control."""
17 | 
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from csuite.environments import access_control
21 | 
22 | 
23 | class AccessControlTest(parameterized.TestCase):
24 | 
25 |   def test_environment_setup(self):
26 |     """Tests environment initialization."""
27 |     access_control.AccessControl()
28 | 
29 |   def test_start(self):
30 |     """Tests environment start."""
31 |     env = access_control.AccessControl()
32 |     params = env.get_config()
33 | 
34 |     with self.subTest(name='step_without_start'):
35 |       # Calling step before start should raise an error.
36 |       with self.assertRaises(RuntimeError):
37 |         env.step(access_control.Action.REJECT)
38 | 
39 |     with self.subTest(name='start_state'):
40 |       _ = env.start()
41 |       state = env.get_state()
42 |       self.assertEqual(state.num_busy_servers, 0)
43 |       self.assertIn(state.incoming_priority, params.priorities)
44 | 
45 |   def test_invalid_state(self):
46 |     """Tests setting environment state with invalid fields."""
47 |     env = access_control.AccessControl()
48 |     _ = env.start()
49 |     cur_state = env.get_state()
50 | 
51 |     with self.subTest(name='invalid_priority'):
52 |       cur_state.num_busy_servers = 5
53 |       cur_state.incoming_priority = -1
54 |       with self.assertRaises(ValueError):
55 |         env.set_state(cur_state)
56 | 
57 |     with self.subTest(name='invalid_num_busy_servers'):
58 |       cur_state.num_busy_servers = -1
59 |       cur_state.incoming_priority = 8
60 |       with self.assertRaises(ValueError):
61 |         env.set_state(cur_state)
62 | 
63 |   @parameterized.parameters(0, 1, 9, 10)
64 |   def test_one_step(self, new_num_busy_servers):
65 |     """Tests environment step."""
66 |     env = access_control.AccessControl()
67 |     _ = env.start()
68 |     params = env.get_config()
69 | 
70 |     with self.subTest(name='invalid_action'):
71 |       with self.assertRaises(ValueError):
72 |         env.step(5)
73 | 
74 |     with self.subTest(name='reject_step'):
75 |       # Change the number of busy servers in the environment state.
76 |       current_state = env.get_state()
77 |       current_state.num_busy_servers = new_num_busy_servers
78 |       env.set_state(current_state)
79 | 
80 |       next_obs, reward = env.step(access_control.Action.REJECT)
81 |       state = env.get_state()
82 |       # Next observation should give a valid state index, reward is zero, and
83 |       # number of busy servers has not increased.
84 |       self.assertIn(next_obs, range(env.num_states))
85 |       self.assertEqual(reward, 0)
86 |       self.assertLessEqual(state.num_busy_servers, new_num_busy_servers)
87 | 
88 |     with self.subTest(name='accept_step'):
89 |       # Change current state to new number of busy servers.
90 |       current_state = env.get_state()
91 |       current_state.num_busy_servers = new_num_busy_servers
92 |       env.set_state(current_state)
93 | 
94 |       next_obs, reward = env.step(access_control.Action.ACCEPT)
95 |       state = env.get_state()
96 | 
97 |       if new_num_busy_servers == params.num_servers:  # all servers busy.
98 |         # Reward is zero even if agent tries accepting, and number of busy
99 |         # servers does not increase over the total available.
100 |         self.assertEqual(reward, 0)
101 |         self.assertLessEqual(state.num_busy_servers, new_num_busy_servers)
102 |       else:
103 |         # Reward is incoming priority, and the number of busy servers can
104 |         # increase by one.
105 |         self.assertIn(reward, params.priorities)
106 |         self.assertLessEqual(state.num_busy_servers, new_num_busy_servers + 1)
107 | 
108 |   def test_runs_from_start(self):
109 |     """Creates an environment and runs for 10 steps."""
110 |     env = access_control.AccessControl()
111 |     _ = env.start()
112 | 
113 |     for _ in range(10):
114 |       _, _ = env.step(access_control.Action.ACCEPT)
115 | 
116 | 
117 | if __name__ == '__main__':
118 |   absltest.main()
--------------------------------------------------------------------------------
/csuite/environments/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """Abstract base class for csuite environments."""
17 | 
18 | import abc
19 | from typing import Any, Optional, Tuple
20 | 
21 | from dm_env import specs
22 | import numpy as np
23 | 
24 | # TODO(b/243715530): The base environment should implement this check.
25 | STEP_WITHOUT_START_ERR = ("Environment state has not been initialized. `start`"
26 |                           " must be called before calling `step`.")
27 | 
28 | 
29 | class Environment(abc.ABC):
30 |   """Base class for continuing environments.
31 | 
32 |   Observations and valid actions are described by the `specs` module in dm_env.
33 |   Environment implementations should return specs as specific as possible.
34 | 
35 |   Each environment will specify its own environment State, Configuration, and
36 |   internal random number generator.
37 |   """
38 | 
39 |   @abc.abstractmethod
40 |   def start(self, seed: Optional[int] = None) -> Any:
41 |     """Starts (or restarts) the environment and returns an observation."""
42 | 
43 |   @abc.abstractmethod
44 |   def step(self, action: Any) -> Tuple[Any, Any]:
45 |     """Takes a step in the environment, returning an observation and reward."""
46 | 
47 |   @abc.abstractmethod
48 |   def observation_spec(self) -> specs.Array:
49 |     """Describes the observation space of the environment.
50 | 
51 |     May use a subclass of `specs.Array` that specifies additional properties
52 |     such as min and max bounds on the values.
53 | """ 54 | 55 | @abc.abstractmethod 56 | def action_spec(self) -> specs.Array: 57 | """Describes the valid action space of the environment. 58 | 59 | May use a subclass of `specs.Array` that specifies additional properties 60 | such as min and max bounds on the values. 61 | """ 62 | 63 | @abc.abstractmethod 64 | def get_state(self) -> Any: 65 | """Returns the environment state.""" 66 | 67 | @abc.abstractmethod 68 | def set_state(self, state: Any): 69 | """Sets the environment state.""" 70 | 71 | @abc.abstractmethod 72 | def render(self) -> np.ndarray: 73 | """Returns an rgb (uint8) numpy array to facilitate visualization. 74 | 75 | The shape of this array should be (width, height, 3), where the last 76 | dimension is for red, green, and blue. The values are in [0, 255]. 77 | """ 78 | -------------------------------------------------------------------------------- /csuite/environments/catch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Implementation of a continuing Catch environment. 17 | 18 | Environment description can be found in the `Catch` 19 | environment class. 20 | """ 21 | import copy 22 | import dataclasses 23 | import enum 24 | from typing import Optional 25 | 26 | from csuite.environments import base 27 | from csuite.environments import common 28 | from dm_env import specs 29 | 30 | import numpy as np 31 | 32 | # Error messages. 33 | _INVALID_ACTION = "Invalid action: expected 0, 1, or 2 but received {action}." 34 | _INVALID_PADDLE_POS = ("Invalid state: paddle should be positioned at the" 35 | " bottom of the board.") 36 | _INVALID_BALLS_RANGE = ( 37 | "Invalid state: positions of balls and paddle not in expected" 38 | " row range [0, {rows}) and column range [0, {columns}).") 39 | 40 | # Default environment variables. 41 | _ROWS = 10 42 | _COLUMNS = 5 43 | _SPAWN_PROBABILITY = 0.1 44 | 45 | 46 | class Action(enum.IntEnum): 47 | LEFT = 0 48 | STAY = 1 49 | RIGHT = 2 50 | 51 | @property 52 | def dx(self): 53 | """Maps LEFT to -1, STAY to 0 and RIGHT to 1.""" 54 | return self.value - 1 55 | 56 | 57 | @dataclasses.dataclass 58 | class Params: 59 | """Parameters of a continuing Catch instance. 60 | 61 | Attributes: 62 | rows: Integer number of rows. 63 | columns: Integer number of columns. 64 | spawn_probability: Probability of a new ball spawning. 65 | """ 66 | rows: int 67 | columns: int 68 | spawn_probability: float 69 | 70 | 71 | @dataclasses.dataclass 72 | class State: 73 | """State of a continuing Catch instance. 74 | 75 | Attributes: 76 | paddle_x: An integer denoting the x-coordinate of the paddle. 77 | paddle_y: An integer denoting the y-coordinate of the paddle 78 | balls: A list of (x, y) coordinates representing the present balls. 
79 | rng: Internal NumPy pseudo-random number generator, included here for 80 | reproducibility purposes. 81 | """ 82 | paddle_x: int 83 | paddle_y: int 84 | balls: list[tuple[int, int]] 85 | rng: np.random.Generator 86 | 87 | 88 | class Catch(base.Environment): 89 | """A continuing Catch environment. 90 | 91 | The agent must control a breakout-like paddle to catch as many falling balls 92 | as possible. Falling balls move strictly down in their column. In this 93 | continuing version, a new ball can spawn at the top with a low probability 94 | at each timestep. A new ball will always spawn when a ball falls to the 95 | bottom of the board. At most one ball is added at each timestep. A reward of 96 | +1 is given when the paddle successfully catches a ball and a reward of -1 is 97 | given when the paddle fails to catch a ball. The reward is 0 otherwise. 98 | 99 | There are three discrete actions: move left, move right, and stay. 100 | 101 | The observation is a binary array with shape (rows, columns) with entry one 102 | if it contains the paddle or a ball, and zero otherwise. 103 | """ 104 | 105 | def __init__(self, 106 | rows=_ROWS, 107 | columns=_COLUMNS, 108 | spawn_probability=_SPAWN_PROBABILITY, 109 | seed=None): 110 | """Initializes a continuing Catch environment. 111 | 112 | Args: 113 | rows: A positive integer denoting the number of rows. 114 | columns: A positive integer denoting the number of columns. 115 | spawn_probability: Float giving the probability of a new ball appearing. 116 | seed: Seed for the internal random number generator. 117 | """ 118 | self._seed = seed 119 | self._params = Params( 120 | rows=rows, columns=columns, spawn_probability=spawn_probability) 121 | self._state = None 122 | 123 | def start(self, seed: Optional[int] = None): 124 | """Initializes the environment and returns an initial observation.""" 125 | 126 | # The initial state has one ball appearing in a random column at the top, 127 | # and the paddle centered at the bottom. 128 | rng = np.random.default_rng(self._seed if seed is None else seed) 129 | self._state = State( 130 | paddle_x=self._params.columns // 2, 131 | paddle_y=self._params.rows - 1, 132 | balls=[(rng.integers(self._params.columns), 0)], 133 | rng=rng, 134 | ) 135 | return self._get_observation() 136 | 137 | @property 138 | def started(self): 139 | """True if the environment has been started, False otherwise.""" 140 | # An unspecified state implies that the environment needs to be started. 141 | return self._state is not None 142 | 143 | def step(self, action): 144 | """Updates the environment state and returns an observation and reward. 145 | 146 | Args: 147 | action: An integer equalling 0, 1, or 2 indicating whether to move the 148 | paddle left, stay, or move the paddle right respectively. 149 | 150 | Returns: 151 | A tuple of type (int, float) giving the next observation and the reward. 152 | 153 | Raises: 154 | RuntimeError: If state has not yet been initialized by `start`. 155 | """ 156 | # Check if state has been initialized. 157 | if not self.started: 158 | raise RuntimeError(base.STEP_WITHOUT_START_ERR) 159 | 160 | # Check if input action is valid. 161 | if action not in [Action.LEFT, Action.STAY, Action.RIGHT]: 162 | raise ValueError(_INVALID_ACTION.format(action=action)) 163 | 164 | # Move the paddle. 165 | self._state.paddle_x = np.clip( 166 | self._state.paddle_x + Action(action).dx, 0, self._params.columns - 1) 167 | 168 | # Move all balls down by one unit. 
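    # (The y-axis points down: y == 0 is the top row and the paddle sits at
    # y == rows - 1, so incrementing y moves each ball toward the paddle.)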
169 |     self._state.balls = [(x, y + 1) for x, y in self._state.balls]
170 | 
171 |     # Since at most one ball is added at each timestep, at most one ball
172 |     # can be at the bottom of the board, and must be the 'oldest' ball.
173 |     reward = 0.
174 |     if self._state.balls and self._state.balls[0][1] == self._state.paddle_y:
175 |       if self._state.balls[0][0] == self._state.paddle_x:
176 |         reward = 1.
177 |       else:
178 |         reward = -1.
179 |       # Remove ball from list.
180 |       self._state.balls = self._state.balls[1:]
181 | 
182 |     # Add new ball with given probability.
183 |     if self._state.rng.random() < self._params.spawn_probability:
184 |       self._state.balls.append(
185 |           (self._state.rng.integers(self._params.columns), 0))
186 | 
187 |     return self._get_observation(), reward
188 | 
189 |   def _get_observation(self) -> np.ndarray:
190 |     """Converts internal environment state to an array observation.
191 | 
192 |     Returns:
193 |       A binary array of size (rows, columns) with entry 1 if it contains either
194 |       a ball or a paddle, and entry 0 if the cell is empty.
195 |     """
196 |     board = np.zeros((self._params.rows, self._params.columns), dtype=int)
197 |     board[self._state.paddle_y, self._state.paddle_x] = 1
198 |     for x, y in self._state.balls:
199 |       board[y, x] = 1
200 |     return board
201 | 
202 |   def observation_spec(self):
203 |     """Describes the observation specs of the environment."""
204 |     return specs.BoundedArray(
205 |         shape=(self._params.rows, self._params.columns),
206 |         dtype=int,
207 |         minimum=0,
208 |         maximum=1,
209 |         name="board")
210 | 
211 |   def action_spec(self):
212 |     """Describes the action specs of the environment."""
213 |     return specs.DiscreteArray(num_values=3, dtype=int, name="action")
214 | 
215 |   def get_state(self):
216 |     """Returns a copy of the current environment state."""
217 |     return copy.deepcopy(self._state) if self._state is not None else None
218 | 
219 |   def set_state(self, state):
220 |     """Sets environment state to state provided.
221 | 
222 |     Args:
223 |       state: A State object which overrides the current state.
224 |     """
225 |     # Check that input state values are valid.
226 |     if not (0 <= state.paddle_x < self._params.columns and
227 |             state.paddle_y == self._params.rows - 1):
228 |       raise ValueError(_INVALID_PADDLE_POS)
229 | 
230 |     for x, y in state.balls:
231 |       if not (0 <= x < self._params.columns and 0 <= y < self._params.rows):
232 |         raise ValueError(
233 |             _INVALID_BALLS_RANGE.format(
234 |                 rows=self._params.rows, columns=self._params.columns))
235 | 
236 |     self._state = copy.deepcopy(state)
237 | 
238 |   def get_config(self):
239 |     """Returns a copy of the environment configuration."""
240 |     return copy.deepcopy(self._params)
241 | 
242 |   def render(self) -> np.ndarray:
243 |     return common.binary_board_to_rgb(self._get_observation())
--------------------------------------------------------------------------------
/csuite/environments/catch_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for catch.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from csuite.environments import catch 21 | 22 | 23 | class CatchTest(parameterized.TestCase): 24 | 25 | def test_environment_setup(self): 26 | """Tests environment initialization.""" 27 | env = catch.Catch() 28 | self.assertIsNotNone(env) 29 | 30 | def test_start(self): 31 | """Tests environment start.""" 32 | env = catch.Catch() 33 | params = env.get_config() 34 | 35 | with self.subTest(name='step_without_start'): 36 | # Calling step before start should raise an error. 37 | with self.assertRaises(RuntimeError): 38 | env.step(catch.Action.LEFT) 39 | 40 | with self.subTest(name='start_state'): 41 | start_obs = env.start() 42 | state = env.get_state() 43 | # Paddle should be positioned at the bottom of the board. 44 | self.assertEqual(state.paddle_y, params.rows - 1) 45 | self.assertEqual(start_obs[state.paddle_y, state.paddle_x], 1) 46 | 47 | # First ball should be positioned at the top of the board. 48 | ball_x = state.balls[0][0] 49 | ball_y = state.balls[0][1] 50 | self.assertEqual(ball_y, 0) 51 | self.assertEqual(start_obs[ball_y, ball_x], 1) 52 | 53 | def test_invalid_state(self): 54 | """Tests setting environment state with invalid fields.""" 55 | env = catch.Catch() 56 | env.start() 57 | 58 | with self.subTest(name='paddle_out_of_range'): 59 | new_state = env.get_state() 60 | new_state.paddle_x = 5 61 | with self.assertRaises(ValueError): 62 | env.set_state(new_state) 63 | 64 | with self.subTest(name='balls_out_of_range'): 65 | new_state = env.get_state() 66 | new_state.balls = [(0, -1)] 67 | with self.assertRaises(ValueError): 68 | env.set_state(new_state) 69 | 70 | @parameterized.parameters((0, 0, 1), (2, 1, 3), (4, 3, 4)) 71 | def test_one_step(self, paddle_x, expected_left_x, expected_right_x): 72 | """Tests one environment step given the x-position of the paddle.""" 73 | env = catch.Catch() 74 | env.start() 75 | 76 | with self.subTest(name='invalid_action'): 77 | with self.assertRaises(ValueError): 78 | env.step(3) 79 | 80 | with self.subTest(name='move_left_step'): 81 | current_state = env.get_state() 82 | current_state.paddle_x = paddle_x 83 | env.set_state(current_state) 84 | 85 | env.step(catch.Action.LEFT) 86 | state = env.get_state() 87 | 88 | # Paddle x-position should have moved left by 1 unless at the edge. 89 | self.assertEqual(state.paddle_x, expected_left_x) 90 | 91 | with self.subTest(name='move_right_step'): 92 | current_state = env.get_state() 93 | current_state.paddle_x = paddle_x 94 | env.set_state(current_state) 95 | 96 | env.step(catch.Action.RIGHT) 97 | state = env.get_state() 98 | 99 | # Paddle x-position should have moved right by 1 unless at the edge. 
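      # (With the default 5 columns, x == 4 is the right edge, which is why the
      # parameterized case (4, 3, 4) expects no movement here.)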
100 | self.assertEqual(state.paddle_x, expected_right_x) 101 | 102 | with self.subTest(name='stay_step'): 103 | current_state = env.get_state() 104 | current_state.paddle_x = paddle_x 105 | env.set_state(current_state) 106 | 107 | env.step(catch.Action.STAY) 108 | state = env.get_state() 109 | self.assertEqual(state.paddle_x, paddle_x) 110 | 111 | def test_ball_hitting_bottom(self): 112 | """Tests environment updates when a ball hits the bottom of board.""" 113 | env = catch.Catch() 114 | env.start() 115 | params = env.get_config() 116 | cur_state = env.get_state() 117 | 118 | with self.subTest(name='no_collision_with_paddle'): 119 | # Set environment state to immediately before ball falls to the bottom. 120 | cur_state.paddle_x = 0 121 | cur_state.paddle_y = params.rows - 1 122 | cur_state.balls = [(2, params.rows - 2)] 123 | env.set_state(cur_state) 124 | _, reward = env.step(catch.Action.STAY) 125 | 126 | # Reward returned should equal -1. 127 | self.assertEqual(reward, -1) 128 | 129 | with self.subTest(name='collision_with_paddle'): 130 | # Set environment state to immediately before ball falls to the bottom. 131 | cur_state.paddle_x = 2 132 | cur_state.paddle_y = params.rows - 1 133 | cur_state.balls = [(2, params.rows - 2)] 134 | env.set_state(cur_state) 135 | _, reward = env.step(catch.Action.STAY) 136 | 137 | # Reward returned should equal 1. 138 | self.assertEqual(reward, 1) 139 | 140 | def test_catching_one_ball_from_start(self): 141 | """Test running from environment start for the duration of one ball falling.""" 142 | env = catch.Catch() 143 | env.start() 144 | params = env.get_config() 145 | cur_state = env.get_state() 146 | 147 | # Set environment state such that ball and paddle are horizontally centered 148 | # and the ball is at the top of the board. 149 | cur_state.paddle_x = 2 150 | cur_state.paddle_y = params.rows - 1 151 | cur_state.balls = [(2, 0)] 152 | env.set_state(cur_state) 153 | 154 | # For eight steps, alternate between moving left and right. 155 | for _ in range(4): 156 | # Here reward should equal 0. 157 | _, reward = env.step(catch.Action.RIGHT) 158 | self.assertEqual(reward, 0) 159 | _, reward = env.step(catch.Action.LEFT) 160 | self.assertEqual(reward, 0) 161 | 162 | # For the last step, choose to stay - ball should fall on paddle 163 | # and reward should equal 1. 164 | _, reward = env.step(catch.Action.STAY) 165 | self.assertEqual(reward, 1) 166 | 167 | def test_catching_two_balls_from_start(self): 168 | """Test running environment for the duration of two balls falling.""" 169 | env = catch.Catch() 170 | env.start() 171 | params = env.get_config() 172 | cur_state = env.get_state() 173 | 174 | # Set environment state such that there are two balls at the top and the 175 | # second row of the board, and paddle is horizontally centered. 176 | cur_state.paddle_x = 2 177 | cur_state.paddle_y = params.rows - 1 178 | cur_state.balls = [(0, 1), (2, 0)] 179 | env.set_state(cur_state) 180 | 181 | # For eight steps, repeatedly move left - ball in second row should fall on 182 | # paddle. 183 | for _ in range(7): 184 | # Here reward should equal 0. 185 | _, reward = env.step(catch.Action.LEFT) 186 | self.assertEqual(reward, 0) 187 | 188 | _, reward = env.step(catch.Action.LEFT) 189 | self.assertEqual(reward, 1) 190 | # Now move right - the second ball should reach the bottom of the board 191 | # and the paddle should not catch it. 
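    # (Trace: the second ball spawned at (2, 0) and has fallen to row 8, while
    # the paddle has been pushed to column 0; stepping right leaves the paddle
    # in column 1 as the ball lands in column 2.)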
192 | _, reward = env.step(catch.Action.RIGHT) 193 | self.assertEqual(reward, -1) 194 | 195 | 196 | if __name__ == '__main__': 197 | absltest.main() 198 | -------------------------------------------------------------------------------- /csuite/environments/common.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Common utility functions.""" 17 | 18 | import numpy as np 19 | 20 | 21 | def binary_board_to_rgb(board: np.ndarray) -> np.ndarray: 22 | """Converts a binary 2D array to an rgb array.""" 23 | board = board.astype(np.uint8) * 255 24 | board = np.expand_dims(board, -1) 25 | board = np.tile(board, (1, 1, 3)) 26 | return board 27 | -------------------------------------------------------------------------------- /csuite/environments/dancing_catch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Continuing catch environment with non-stationary permutations of the observation features. 17 | 18 | Environment description can be found in the `DancingCatch` environment class. 19 | """ 20 | import copy 21 | import dataclasses 22 | import enum 23 | from typing import Optional 24 | 25 | from csuite.environments import base 26 | from csuite.environments import common 27 | from dm_env import specs 28 | 29 | import numpy as np 30 | 31 | # Error messages. 32 | _INVALID_ACTION = "Invalid action: expected 0, 1, or 2 but received {action}." 33 | _INVALID_PADDLE_POS = ("Invalid state: paddle should be positioned at the" 34 | " bottom of the board.") 35 | _INVALID_BALLS_RANGE = ( 36 | "Invalid state: positions of balls and paddle not in expected" 37 | " row range [0, {rows}) and column range [0, {columns}).") 38 | 39 | # Default environment variables. 
40 | _ROWS = 10 41 | _COLUMNS = 5 42 | _SPAWN_PROBABILITY = 0.1 43 | _SWAP_EVERY = 10000 44 | 45 | 46 | class Action(enum.IntEnum): 47 | LEFT = 0 48 | STAY = 1 49 | RIGHT = 2 50 | 51 | @property 52 | def dx(self): 53 | """Maps LEFT to -1, STAY to 0 and RIGHT to 1.""" 54 | return self.value - 1 55 | 56 | 57 | @dataclasses.dataclass 58 | class Params: 59 | """Parameters of a continuing Catch instance. 60 | 61 | Attributes: 62 | rows: Integer number of rows. 63 | columns: Integer number of columns. 64 | observation_dim: Integer dimension of the observation features. 65 | spawn_probability: Probability of a new ball spawning. 66 | swap_every: Integer giving the interval at which a swap occurs. 67 | """ 68 | rows: int 69 | columns: int 70 | observation_dim: int 71 | spawn_probability: float 72 | swap_every: int 73 | 74 | 75 | @dataclasses.dataclass 76 | class State: 77 | """State of a continuing Catch instance. 78 | 79 | Attributes: 80 | paddle_x: An integer denoting the x-coordinate of the paddle. 81 | paddle_y: An integer denoting the y-coordinate of the paddle 82 | balls: A list of (x, y) coordinates representing the present balls. 83 | shuffle_idx: Indices for performing the observation shuffle as a result of 84 | the random swaps. 85 | time_since_swap: An integer denoting how many timesteps have elapsed since 86 | the last swap. 87 | rng: Internal NumPy pseudo-random number generator, included here for 88 | reproducibility purposes. 89 | """ 90 | paddle_x: int 91 | paddle_y: int 92 | balls: list[tuple[int, int]] 93 | shuffle_idx: np.ndarray 94 | time_since_swap: int 95 | rng: np.random.Generator 96 | 97 | 98 | class DancingCatch(base.Environment): 99 | """A continuing Catch environment with random swaps. 100 | 101 | This environment is the same as the continuing Catch environment, but 102 | at each timestep, there is a low probability that two entries in the 103 | observation are swapped. 104 | 105 | The observations are flattened, with entry one if the corresponding unshuffled 106 | position contains the paddle or a ball, and zero otherwise. 107 | """ 108 | 109 | def __init__(self, 110 | rows=_ROWS, 111 | columns=_COLUMNS, 112 | spawn_probability=_SPAWN_PROBABILITY, 113 | seed=None, 114 | swap_every=_SWAP_EVERY): 115 | """Initializes a continuing Catch environment. 116 | 117 | Args: 118 | rows: A positive integer denoting the number of rows. 119 | columns: A positive integer denoting the number of columns. 120 | spawn_probability: Float giving the probability of a new ball appearing. 121 | seed: Seed for the internal random number generator. 122 | swap_every: A positive integer denoting the interval at which a swap in 123 | the observation occurs. 124 | """ 125 | self._seed = seed 126 | self._params = Params( 127 | rows=rows, 128 | columns=columns, 129 | observation_dim=rows * columns, 130 | spawn_probability=spawn_probability, 131 | swap_every=swap_every) 132 | self._state = None 133 | 134 | def start(self, seed: Optional[int] = None): 135 | """Initializes the environment and returns an initial observation.""" 136 | 137 | # The initial state has one ball appearing in a random column at the top, 138 | # and the paddle centered at the bottom. 
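    # (With the default 5 columns and 10 rows, for example, the paddle starts
    # at paddle_x = 5 // 2 = 2 on the bottom row, paddle_y = 9.)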
139 | 
140 |     rng = np.random.default_rng(self._seed if seed is None else seed)
141 |     self._state = State(
142 |         paddle_x=self._params.columns // 2,
143 |         paddle_y=self._params.rows - 1,
144 |         balls=[(rng.integers(self._params.columns), 0)],
145 |         shuffle_idx=np.arange(self._params.observation_dim),
146 |         time_since_swap=0,
147 |         rng=rng,
148 |     )
149 |     return self._get_observation()
150 | 
151 |   @property
152 |   def started(self):
153 |     """True if the environment has been started, False otherwise."""
154 |     # An unspecified state implies that the environment needs to be started.
155 |     return self._state is not None
156 | 
157 |   def step(self, action):
158 |     """Updates the environment state and returns an observation and reward.
159 | 
160 |     Args:
161 |       action: An integer equalling 0, 1, or 2 indicating whether to move the
162 |         paddle left, stay, or move the paddle right respectively.
163 | 
164 |     Returns:
165 |       A tuple of type (np.ndarray, float) giving the next observation and reward.
166 | 
167 |     Raises:
168 |       RuntimeError: If state has not yet been initialized by `start`.
169 |     """
170 |     # Check if state has been initialized.
171 |     if not self.started:
172 |       raise RuntimeError(base.STEP_WITHOUT_START_ERR)
173 | 
174 |     # Check if input action is valid.
175 |     if action not in [Action.LEFT, Action.STAY, Action.RIGHT]:
176 |       raise ValueError(_INVALID_ACTION.format(action=action))
177 | 
178 |     # Move the paddle.
179 |     self._state.paddle_x = np.clip(self._state.paddle_x + Action(action).dx, 0,
180 |                                    self._params.columns - 1)
181 | 
182 |     # Move all balls down by one unit.
183 |     self._state.balls = [(x, y + 1) for x, y in self._state.balls]
184 | 
185 |     # Since at most one ball is added at each timestep, at most one ball
186 |     # can be at the bottom of the board, and it must be the 'oldest' ball.
187 |     reward = 0.
188 |     if self._state.balls and self._state.balls[0][1] == self._state.paddle_y:
189 |       reward = 1. if self._state.balls[0][0] == self._state.paddle_x else -1.
190 |       # Remove ball from list.
191 |       self._state.balls = self._state.balls[1:]
192 | 
193 |     # Add new ball with given probability.
194 |     if self._state.rng.random() < self._params.spawn_probability:
195 |       self._state.balls.append(
196 |           (self._state.rng.integers(self._params.columns), 0))
197 | 
198 |     # Update time since last swap.
199 |     self._state.time_since_swap += 1
200 | 
201 |     # Update the observation permutation indices by swapping two indices,
202 |     # at the given interval.
203 |     if self._state.time_since_swap % self._params.swap_every == 0:
204 |       idx_1, idx_2 = self._state.rng.integers(
205 |           self._params.observation_dim, size=2).T
206 |       self._state.shuffle_idx[[idx_1, idx_2]] = (
207 |           self._state.shuffle_idx[[idx_2, idx_1]])
208 |       self._state.time_since_swap = 0
209 | 
210 |     return self._get_observation(), reward
211 | 
212 |   def _get_observation(self) -> np.ndarray:
213 |     """Converts internal environment state to an array observation.
214 | 
215 |     Returns:
216 |       A flattened binary array of length rows * columns, permuted by the
217 |       current shuffle indices: entry 1 marks a ball or the paddle, 0 is empty.
218 |     """
219 |     board = np.zeros((self._params.rows, self._params.columns), dtype=int)
220 |     # Mark the paddle and any balls on the initially empty board.
221 |     board[self._state.paddle_y, self._state.paddle_x] = 1
222 |     for x, y in self._state.balls:
223 |       board[y, x] = 1
224 |     board = board.flatten()
225 |     return board[self._state.shuffle_idx]
226 | 
227 |   def observation_spec(self):
228 |     """Describes the observation specs of the environment."""
229 |     return specs.BoundedArray(
230 |         shape=(self._params.observation_dim,),
231 |         dtype=int,
232 |         minimum=0,
233 |         maximum=1,
234 |         name="board")
235 | 
236 |   def action_spec(self):
237 |     """Describes the action specs of the environment."""
238 |     return specs.DiscreteArray(num_values=3, dtype=int, name="action")
239 | 
240 |   def get_state(self):
241 |     """Returns a copy of the current environment state."""
242 |     return copy.deepcopy(self._state) if self._state is not None else None
243 | 
244 |   def set_state(self, state):
245 |     """Sets environment state to state provided.
246 | 
247 |     Args:
248 |       state: A State object which overrides the current state.
249 |     """
250 |     # Check that input state values are valid.
251 |     if not (0 <= state.paddle_x < self._params.columns and
252 |             state.paddle_y == self._params.rows - 1):
253 |       raise ValueError(_INVALID_PADDLE_POS)
254 | 
255 |     for x, y in state.balls:
256 |       if not (0 <= x < self._params.columns and 0 <= y < self._params.rows):
257 |         raise ValueError(
258 |             _INVALID_BALLS_RANGE.format(
259 |                 rows=self._params.rows, columns=self._params.columns))
260 | 
261 |     self._state = copy.deepcopy(state)
262 | 
263 |   def get_config(self):
264 |     """Returns a copy of the environment configuration."""
265 |     return copy.deepcopy(self._params)
266 | 
267 |   def render(self) -> np.ndarray:
268 |     board = self._get_observation().reshape(self._params.rows, self._params.columns)
269 |     return common.binary_board_to_rgb(board)
270 | 
--------------------------------------------------------------------------------
/csuite/environments/dancing_catch_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """Tests for DancingCatch."""
17 | 
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from csuite.environments import dancing_catch
21 | 
22 | 
23 | class DancingCatchTest(parameterized.TestCase):
24 | 
25 |   def test_environment_setup(self):
26 |     """Tests environment initialization."""
27 |     env = dancing_catch.DancingCatch()
28 |     self.assertIsNotNone(env)
29 | 
30 |   def test_start(self):
31 |     """Tests environment start."""
32 |     env = dancing_catch.DancingCatch()
33 |     params = env.get_config()
34 | 
35 |     with self.subTest(name='step_without_start'):
36 |       # Calling step before start should raise an error.
37 | with self.assertRaises(RuntimeError): 38 | env.step(dancing_catch.Action.LEFT) 39 | 40 | with self.subTest(name='start_state'): 41 | start_obs = env.start() 42 | state = env.get_state() 43 | # Paddle should be positioned at the bottom of the board. 44 | self.assertEqual(state.paddle_y, params.rows - 1) 45 | paddle_idx = state.paddle_y * params.columns + state.paddle_x 46 | self.assertEqual(start_obs[paddle_idx], 1) 47 | 48 | # First ball should be positioned at the top of the board. 49 | ball_x = state.balls[0][0] 50 | ball_y = state.balls[0][1] 51 | self.assertEqual(ball_y, 0) 52 | ball_idx = ball_y * params.columns + ball_x 53 | self.assertEqual(start_obs[ball_idx], 1) 54 | 55 | def test_invalid_state(self): 56 | """Tests setting environment state with invalid fields.""" 57 | env = dancing_catch.DancingCatch() 58 | env.start() 59 | 60 | with self.subTest(name='paddle_out_of_range'): 61 | new_state = env.get_state() 62 | new_state.paddle_x = 5 63 | with self.assertRaises(ValueError): 64 | env.set_state(new_state) 65 | 66 | with self.subTest(name='balls_out_of_range'): 67 | new_state = env.get_state() 68 | new_state.balls = [(0, -1)] 69 | with self.assertRaises(ValueError): 70 | env.set_state(new_state) 71 | 72 | @parameterized.parameters((0, 0, 1), (2, 1, 3), (4, 3, 4)) 73 | def test_one_step(self, paddle_x, expected_left_x, expected_right_x): 74 | """Tests one environment step given the x-position of the paddle.""" 75 | env = dancing_catch.DancingCatch() 76 | env.start() 77 | 78 | with self.subTest(name='invalid_action'): 79 | with self.assertRaises(ValueError): 80 | env.step(3) 81 | 82 | with self.subTest(name='move_left_step'): 83 | current_state = env.get_state() 84 | current_state.paddle_x = paddle_x 85 | env.set_state(current_state) 86 | 87 | env.step(dancing_catch.Action.LEFT) 88 | state = env.get_state() 89 | 90 | # Paddle x-position should have moved left by 1 unless at the edge. 91 | self.assertEqual(state.paddle_x, expected_left_x) 92 | 93 | with self.subTest(name='move_right_step'): 94 | current_state = env.get_state() 95 | current_state.paddle_x = paddle_x 96 | env.set_state(current_state) 97 | 98 | env.step(dancing_catch.Action.RIGHT) 99 | state = env.get_state() 100 | 101 | # Paddle x-position should have moved right by 1 unless at the edge. 102 | self.assertEqual(state.paddle_x, expected_right_x) 103 | 104 | with self.subTest(name='stay_step'): 105 | current_state = env.get_state() 106 | current_state.paddle_x = paddle_x 107 | env.set_state(current_state) 108 | 109 | env.step(dancing_catch.Action.STAY) 110 | state = env.get_state() 111 | self.assertEqual(state.paddle_x, paddle_x) 112 | 113 | 114 | if __name__ == '__main__': 115 | absltest.main() 116 | -------------------------------------------------------------------------------- /csuite/environments/experimental/pendulum_poke.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Continuing pendulum with random perturbations in the reward region. 17 | 18 | Environment description can be found in the `PendulumPoke' Environment 19 | class. 20 | """ 21 | 22 | import copy 23 | import dataclasses 24 | import enum 25 | from typing import Any, Callable 26 | from csuite.environments import base 27 | from dm_env import specs 28 | 29 | import numpy as np 30 | from PIL import Image 31 | from PIL import ImageDraw 32 | 33 | # Default environment variables. 34 | _NUM_ACTIONS = 3 # Size of action space discretization. 35 | _FRICTION = 0.1 36 | _GRAVITY = 9.81 37 | _SIMULATION_STEP_SIZE = 0.05 38 | _ACT_STEP_PERIOD = 4 39 | _MAX_SPEED = np.inf 40 | _REWARD_ANGLE = 30 41 | # Default environment variables for adding perturbations. 42 | _PERTURB_PROB = 0.01 43 | # TODO(b/243969989): Decide appropriate amount of perturbation torque. 44 | _PERTURB_TORQUE = 10. 45 | 46 | # Converter for degrees to radians. 47 | _RADIAN_MULTIPLIER = np.pi / 180 48 | 49 | # Error messages. 50 | _INVALID_ANGLE = ("Invalid state: expected angle to be in range [0, 2pi].") 51 | 52 | # Variables for pixel visualization of the environment. 53 | _IMAGE_SIZE = 256 54 | _CENTER_IMAGE = _IMAGE_SIZE // 2 - 1 55 | _SCALE_FACTOR = 0.75 56 | _PENDULUM_WIDTH = _IMAGE_SIZE // 64 57 | _TIP_RADIUS = _IMAGE_SIZE // 24 58 | _LIGHT_GREEN = "#d4ffd6" # For shading the reward region. 59 | _ARROW_WIDTH = _IMAGE_SIZE // 44 60 | _TORQUE_ANGLE = 20 61 | 62 | 63 | class Action(enum.IntEnum): 64 | """Actions for the PendulumPoke environment. 65 | 66 | There are three actions: 67 | 0: Apply -1 torque. 68 | 1: Do nothing. 69 | 2: Apply +1 torque. 70 | """ 71 | NEGATIVE, STAY, POSITIVE = range(3) 72 | 73 | @property 74 | def tau(self): 75 | """Maps NEGATIVE to -1, STAY to 0, and POSITIVE to 1.""" 76 | return self.value - 1 77 | 78 | 79 | @dataclasses.dataclass 80 | class State: 81 | """State of a PendulumPoke environment. 82 | 83 | Attributes: 84 | angle: a float in [0, 2*pi] giving the angle in radians of the pendulum. An 85 | angle of 0 indicates that the pendulum is hanging downwards. 86 | velocity: a float in [-max_speed, max_speed] giving the angular velocity. 87 | rng: Internal NumPy pseudo-random number generator, included here for 88 | reproducibility purposes. 89 | """ 90 | angle: float 91 | velocity: float 92 | rng: np.random.Generator 93 | 94 | 95 | @dataclasses.dataclass 96 | class Params: 97 | """Parameters of a PendulumPoke environment.""" 98 | friction: float 99 | gravity: float 100 | simulation_step_size: float 101 | act_step_period: int 102 | max_speed: float 103 | reward_fn: Callable[..., float] 104 | perturb_prob: float 105 | perturb_torque: float 106 | 107 | 108 | def sparse_reward(state: State, 109 | unused_torque: Any, 110 | unused_step_size: Any, 111 | reward_angle: int = _REWARD_ANGLE) -> float: 112 | """Returns a sparse reward for the continuing pendulum problem. 113 | 114 | Args: 115 | state: A State object containing the current angle and velocity. 116 | reward_angle: An integer denoting the angle from vertical, in degrees, where 117 | the pendulum is rewarding. 118 | 119 | Returns: 120 | A reward of 1 if the angle of the pendulum is within the range, and 121 | a reward of 0 otherwise. 
122 |   """
123 |   reward_angle_radians = reward_angle * _RADIAN_MULTIPLIER
124 |   if (np.pi - reward_angle_radians < state.angle <
125 |       np.pi + reward_angle_radians):
126 |     return 1.
127 |   else:
128 |     return 0.
129 | 
130 | 
131 | def _alias_angle(angle: float) -> float:
132 |   """Returns an angle between 0 and 2*pi."""
133 |   return angle % (2 * np.pi)
134 | 
135 | 
136 | class PendulumPoke(base.Environment):
137 |   """A pendulum environment with a random addition of force in the reward region.
138 | 
139 |   This environment has the same parameters as the `Pendulum` environment. In
140 |   addition, while the pendulum is in the rewarding region, there is a small
141 |   probability that a force is applied to knock the pendulum over. The
142 |   magnitude of the force is constant, and the direction in which it is
143 |   applied is chosen uniformly at random.
144 |   """
145 | 
146 |   def __init__(self,
147 |                friction=_FRICTION,
148 |                gravity=_GRAVITY,
149 |                simulation_step_size=_SIMULATION_STEP_SIZE,
150 |                act_step_period=_ACT_STEP_PERIOD,
151 |                max_speed=_MAX_SPEED,
152 |                reward_fn=sparse_reward,
153 |                perturb_prob=_PERTURB_PROB,
154 |                perturb_torque=_PERTURB_TORQUE,
155 |                seed=None):
156 |     """Initializes a new pendulum environment with random perturbations in the rewarding region.
157 | 
158 |     Args:
159 |       friction: A positive float giving the coefficient of friction.
160 |       gravity: A float giving the acceleration due to gravity.
161 |       simulation_step_size: The step size (in seconds) of the simulation.
162 |       act_step_period: An integer giving the number of simulation steps for each
163 |         action input.
164 |       max_speed: A float giving the maximum speed (in radians/second) allowed in
165 |         the simulation.
166 |       reward_fn: A callable which returns a float reward given current state.
167 |       perturb_prob: A float giving the probability that a random force is
168 |         applied to the pendulum if it is in the rewarding region.
169 |       perturb_torque: A float giving the magnitude of the random force.
170 |       seed: Seed for the internal random number generator.
171 |     """
172 |     self._params = Params(
173 |         friction=friction,
174 |         gravity=gravity,
175 |         simulation_step_size=simulation_step_size,
176 |         act_step_period=act_step_period,
177 |         max_speed=max_speed,
178 |         reward_fn=reward_fn,
179 |         perturb_prob=perturb_prob,
180 |         perturb_torque=perturb_torque)
181 |     self._seed = seed
182 |     self._state = None
183 |     self._torque = 0
184 |     self._perturb_direction = 0  # For visualization purposes.
185 | 
186 |   def start(self, seed=None):
187 |     """Initializes the environment and returns an initial observation."""
188 |     self._state = State(
189 |         angle=0.,
190 |         velocity=0.,
191 |         rng=np.random.default_rng(self._seed if seed is None else seed))
192 |     return np.array((np.cos(self._state.angle), np.sin(
193 |         self._state.angle), self._state.velocity),
194 |                     dtype=np.float32)
195 | 
196 |   @property
197 |   def started(self):
198 |     """True if the environment has been started, False otherwise."""
199 |     # An unspecified state implies that the environment needs to be started.
200 |     return self._state is not None
201 | 
202 |   def step(self, action):
203 |     """Updates the environment state and returns an observation and reward.
204 | 
205 |     Args:
206 |       action: An integer in {0, 1, 2} indicating whether to subtract one unit of
207 |         torque, do nothing, or add one unit of torque.
208 | 
209 |     Returns:
210 |       A tuple giving the next observation in the form of a NumPy array
211 |       and the reward as a float.
212 | 213 | Raises: 214 | RuntimeError: If state has not yet been initialized by `start`. 215 | """ 216 | # Check if state has been initialized. 217 | if not self.started: 218 | raise RuntimeError(base.STEP_WITHOUT_START_ERR) 219 | 220 | self._torque = Action(action).tau 221 | 222 | # Integrate over time steps to get new angle and velocity. 223 | new_angle = self._state.angle 224 | new_velocity = self._state.velocity 225 | 226 | # If the pendulum is in the rewarding region, with the given probability 227 | # add a force in a direction chosen uniformly at random. 228 | reward_angle_rad = _REWARD_ANGLE * _RADIAN_MULTIPLIER 229 | if ((np.pi - reward_angle_rad < new_angle < np.pi + reward_angle_rad) and 230 | self._state.rng.uniform() < self._params.perturb_prob): 231 | if self._state.rng.uniform() < 0.5: 232 | applied_torque = self._torque - self._params.perturb_torque 233 | self._perturb_direction = -1 234 | else: 235 | applied_torque = self._torque + self._params.perturb_torque 236 | self._perturb_direction = 1 237 | else: 238 | applied_torque = self._torque 239 | self._perturb_direction = 0 240 | 241 | for _ in range(self._params.act_step_period): 242 | new_velocity += ((applied_torque - self._params.friction * new_velocity - 243 | self._params.gravity * np.sin(new_angle)) * 244 | self._params.simulation_step_size) 245 | new_angle += new_velocity * self._params.simulation_step_size 246 | 247 | # Ensure the angle is between 0 and 2*pi. 248 | new_angle = _alias_angle(new_angle) 249 | 250 | # Clip velocity to max_speed. 251 | new_velocity = np.clip(new_velocity, -self._params.max_speed, 252 | self._params.max_speed) 253 | 254 | self._state = State( 255 | angle=new_angle, velocity=new_velocity, rng=self._state.rng) 256 | return (np.array((np.cos(self._state.angle), np.sin( 257 | self._state.angle), self._state.velocity), 258 | dtype=np.float32), 259 | self._params.reward_fn(self._state, self._torque, 260 | self._params.simulation_step_size)) 261 | 262 | def observation_spec(self): 263 | """Describes the observation specs of the environment.""" 264 | return specs.BoundedArray((3,), 265 | dtype=np.float32, 266 | minimum=[-1, -1, -self._params.max_speed], 267 | maximum=[1, 1, self._params.max_speed]) 268 | 269 | def action_spec(self): 270 | """Describes the action specs of the environment.""" 271 | return specs.DiscreteArray(_NUM_ACTIONS, dtype=int, name="action") 272 | 273 | def get_state(self): 274 | """Returns a copy of the current environment state.""" 275 | return copy.deepcopy(self._state) if self._state is not None else None 276 | 277 | def set_state(self, state): 278 | """Sets environment state to state provided. 279 | 280 | Args: 281 | state: A State object which overrides the current state. 282 | 283 | Returns: 284 | A NumPy array for the observation including the angle and velocity. 285 | """ 286 | # Check that input state values are valid. 287 | if not 0 <= state.angle <= 2 * np.pi: 288 | raise ValueError(_INVALID_ANGLE) 289 | 290 | self._state = copy.deepcopy(state) 291 | 292 | return np.array((np.cos(self._state.angle), np.sin( 293 | self._state.angle), self._state.velocity), 294 | dtype=np.float32) 295 | 296 | def render(self): 297 | image = Image.new("RGB", (_IMAGE_SIZE, _IMAGE_SIZE), "white") 298 | dct = ImageDraw.Draw(image) 299 | # Get x and y positions of the pendulum tip relative to the center. 
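    # (An angle of 0 means hanging straight down, and image y-coordinates grow
    # downward, so the tip sits at (sin(angle), cos(angle)) relative to the
    # pivot: at angle 0 it is directly below the pivot.)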
300 | x_pos = np.sin(self._state.angle) 301 | y_pos = np.cos(self._state.angle) 302 | 303 | def abs_coordinates(x, y): 304 | """Return absolute coordinates given coordinates relative to center.""" 305 | return (x * _SCALE_FACTOR * _CENTER_IMAGE + _CENTER_IMAGE, 306 | y * _SCALE_FACTOR * _CENTER_IMAGE + _CENTER_IMAGE) 307 | 308 | # Draw reward range region. 309 | boundary_x = _CENTER_IMAGE * (1 - _SCALE_FACTOR) 310 | pendulum_bounding_box = [(boundary_x, boundary_x), 311 | (_IMAGE_SIZE - boundary_x, 312 | _IMAGE_SIZE - boundary_x)] 313 | dct.pieslice( 314 | pendulum_bounding_box, 315 | start=(270 - _REWARD_ANGLE), 316 | end=(270 + _REWARD_ANGLE), 317 | fill=_LIGHT_GREEN) 318 | 319 | # Get absolute coordinates of the pendulum tip. 320 | tip_coords = abs_coordinates(x_pos, y_pos) 321 | # Draw pendulum line. 322 | dct.line([(_CENTER_IMAGE, _CENTER_IMAGE), tip_coords], 323 | fill="black", 324 | width=_PENDULUM_WIDTH) 325 | # Draw circular pendulum tip. 326 | x, y = tip_coords 327 | tip_bounding_box = [(x - _TIP_RADIUS, y - _TIP_RADIUS), 328 | (x + _TIP_RADIUS, y + _TIP_RADIUS)] 329 | dct.ellipse(tip_bounding_box, fill="red") 330 | 331 | # Draw torque arrow. 332 | if self._torque > 0: 333 | dct.arc( 334 | pendulum_bounding_box, 335 | start=360 - _TORQUE_ANGLE, 336 | end=_TORQUE_ANGLE, 337 | fill="blue", 338 | width=_ARROW_WIDTH) 339 | # Draw arrow heads. 340 | arrow_x, arrow_y = abs_coordinates( 341 | np.cos(_TORQUE_ANGLE * _RADIAN_MULTIPLIER), 342 | -np.sin(_TORQUE_ANGLE * _RADIAN_MULTIPLIER)) 343 | dct.regular_polygon((arrow_x, arrow_y, _ARROW_WIDTH * 1.5), 344 | n_sides=3, 345 | rotation=_TORQUE_ANGLE, 346 | fill="blue") 347 | 348 | elif self._torque < 0: 349 | dct.arc( 350 | pendulum_bounding_box, 351 | start=180 - _TORQUE_ANGLE, 352 | end=180 + _TORQUE_ANGLE, 353 | fill="blue", 354 | width=_ARROW_WIDTH) 355 | # Draw arrow heads. 356 | arrow_x, arrow_y = abs_coordinates( 357 | -np.cos(_TORQUE_ANGLE * _RADIAN_MULTIPLIER), 358 | -np.sin(_TORQUE_ANGLE * _RADIAN_MULTIPLIER)) 359 | dct.regular_polygon((arrow_x, arrow_y, _ARROW_WIDTH * 1.5), 360 | n_sides=3, 361 | rotation=-_TORQUE_ANGLE, 362 | fill="blue") 363 | 364 | # Drawing perturbation arrow. 365 | if self._perturb_direction == 1: 366 | tip_coords = (_IMAGE_SIZE // 4 - _TIP_RADIUS, _IMAGE_SIZE // 8) 367 | arrow_coords = [ 368 | tip_coords, (_IMAGE_SIZE // 4 + _TIP_RADIUS, _IMAGE_SIZE // 8) 369 | ] 370 | dct.line(arrow_coords, fill="red", width=_PENDULUM_WIDTH) 371 | dct.regular_polygon((tip_coords, _ARROW_WIDTH * 1.2), 372 | rotation=90, 373 | n_sides=3, 374 | fill="red") 375 | elif self._perturb_direction == -1: 376 | tip_coords = (_IMAGE_SIZE * 3 // 4 + _TIP_RADIUS, _IMAGE_SIZE // 8) 377 | arrow_coords = [(_IMAGE_SIZE * 3 // 4 - _TIP_RADIUS, _IMAGE_SIZE // 8), 378 | tip_coords] 379 | dct.line(arrow_coords, fill="red", width=_PENDULUM_WIDTH) 380 | dct.regular_polygon((tip_coords, _ARROW_WIDTH * 1.2), 381 | rotation=270, 382 | n_sides=3, 383 | fill="red") 384 | return np.asarray(image, dtype=np.uint8) 385 | -------------------------------------------------------------------------------- /csuite/environments/experimental/pendulum_poke_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for pendulum_poke.""" 17 | 18 | from absl.testing import absltest 19 | from csuite.environments.experimental import pendulum_poke 20 | 21 | 22 | class PendulumTest(absltest.TestCase): 23 | 24 | def test_environment_setup(self): 25 | """Tests environment initialization.""" 26 | env = pendulum_poke.PendulumPoke() 27 | self.assertIsNotNone(env) 28 | 29 | def test_start(self): 30 | """Tests environment start.""" 31 | env = pendulum_poke.PendulumPoke() 32 | 33 | with self.subTest(name='step_without_start'): 34 | # Calling step before start should raise an error. 35 | with self.assertRaises(RuntimeError): 36 | env.step(pendulum_poke.Action.NEGATIVE) 37 | 38 | with self.subTest(name='start_state'): 39 | start_obs = env.start() 40 | # Initial cosine of the angle should be 1. 41 | # Initial sine of the angle and initial velocity should be 0. 42 | self.assertEqual(start_obs[0], 1.) 43 | self.assertEqual(start_obs[1], 0.) 44 | self.assertEqual(start_obs[2], 0.) 45 | 46 | def test_one_step(self): 47 | """Tests one environment step.""" 48 | env = pendulum_poke.PendulumPoke() 49 | env.start() 50 | _, reward = env.step(pendulum_poke.Action.NEGATIVE) 51 | self.assertEqual(reward, 0.) 52 | _, reward = env.step(pendulum_poke.Action.POSITIVE) 53 | self.assertEqual(reward, 0.) 54 | _, reward = env.step(pendulum_poke.Action.STAY) 55 | self.assertEqual(reward, 0.) 56 | 57 | def test_setting_state(self): 58 | """Tests setting environment state and solver.""" 59 | env = pendulum_poke.PendulumPoke() 60 | old_obs = env.start() 61 | prev_state = env.get_state() 62 | # Take two steps adding +1 torque, then set state to downwards position. 63 | for _ in range(2): 64 | old_obs, _ = env.step(pendulum_poke.Action.POSITIVE) 65 | new_state = pendulum_poke.State(angle=0., velocity=0., rng=prev_state.rng) 66 | new_obs = env.set_state(new_state) 67 | for _ in range(2): 68 | new_obs, _ = env.step(pendulum_poke.Action.POSITIVE) 69 | 70 | # If the solver was properly updated, the two observations are the same. 71 | self.assertLessEqual(abs(new_obs[0]), old_obs[0]) 72 | self.assertLessEqual(abs(new_obs[1]), old_obs[1]) 73 | self.assertLessEqual(abs(new_obs[2]), old_obs[2]) 74 | 75 | 76 | if __name__ == '__main__': 77 | absltest.main() 78 | -------------------------------------------------------------------------------- /csuite/environments/pendulum.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Implementation of a continuing Pendulum environment with discrete actions. 17 | 18 | Environment description can be found in the `Pendulum' Environment 19 | class. 20 | """ 21 | 22 | import copy 23 | import dataclasses 24 | import enum 25 | from typing import Any, Callable, Optional 26 | 27 | from csuite.environments import base 28 | from dm_env import specs 29 | 30 | import numpy as np 31 | from PIL import Image 32 | from PIL import ImageDraw 33 | 34 | # Default environment variables. 35 | _NUM_ACTIONS = 3 # Size of action space discretization. 36 | _FRICTION = 0.1 37 | _GRAVITY = 9.81 38 | _SIMULATION_STEP_SIZE = 0.05 39 | _ACT_STEP_PERIOD = 4 40 | _MAX_SPEED = np.inf 41 | _REWARD_ANGLE = 30 42 | 43 | # Converter for degrees to radians. 44 | _RADIAN_MULTIPLIER = np.pi / 180 45 | 46 | # Error messages. 47 | _INVALID_ANGLE = ("Invalid state: expected angle to be in range [0, 2pi].") 48 | 49 | # Variables for pixel visualization of the environment. 50 | _IMAGE_SIZE = 256 51 | _CENTER_IMAGE = _IMAGE_SIZE // 2 - 1 52 | _SCALE_FACTOR = 0.75 53 | _PENDULUM_WIDTH = _IMAGE_SIZE // 64 54 | _TIP_RADIUS = _IMAGE_SIZE // 24 55 | _LIGHT_GREEN = "#d4ffd6" # For shading the reward region. 56 | _ARROW_WIDTH = _IMAGE_SIZE // 44 57 | _TORQUE_ANGLE = 20 58 | 59 | 60 | class Action(enum.IntEnum): 61 | """Actions for the Pendulum environment. 62 | 63 | There are three actions: 64 | 0: Apply -1 torque. 65 | 1: Do nothing. 66 | 2: Apply +1 torque. 67 | """ 68 | NEGATIVE, STAY, POSITIVE = range(3) 69 | 70 | @property 71 | def tau(self): 72 | """Maps NEGATIVE to -1, STAY to 0, and POSITIVE to 1.""" 73 | return self.value - 1 74 | 75 | 76 | @dataclasses.dataclass 77 | class State: 78 | """State of a continuing pendulum environment. 79 | 80 | Attributes: 81 | angle: a float in [0, 2*pi] giving the angle in radians of the pendulum. An 82 | angle of 0 indicates that the pendulum is hanging downwards. 83 | velocity: a float in [-max_speed, max_speed] giving the angular velocity. 84 | """ 85 | angle: float 86 | velocity: float 87 | 88 | 89 | @dataclasses.dataclass 90 | class Params: 91 | """Parameters of a continuing pendulum environment.""" 92 | start_state_fn: Callable[[], State] 93 | friction: float 94 | gravity: float 95 | simulation_step_size: float 96 | act_step_period: int 97 | max_speed: float 98 | reward_fn: Callable[..., float] 99 | 100 | 101 | # Default start state and reward function. 102 | def start_from_bottom(): 103 | """Returns default start state with pendulum hanging vertically downwards.""" 104 | return State(angle=0., velocity=0.) 105 | 106 | 107 | def sparse_reward(state: State, 108 | unused_torque: Any, 109 | unused_step_size: Any, 110 | reward_angle: int = _REWARD_ANGLE) -> float: 111 | """Returns a sparse reward for the continuing pendulum problem. 112 | 113 | Args: 114 | state: A State object containing the current angle and velocity. 115 | reward_angle: An integer denoting the angle from vertical, in degrees, where 116 | the pendulum is rewarding. 117 | 118 | Returns: 119 | A reward of 1 if the angle of the pendulum is within the range, and 120 | a reward of 0 otherwise. 121 | """ 122 | reward_angle_radians = reward_angle * _RADIAN_MULTIPLIER 123 | if (np.pi - reward_angle_radians < state.angle < 124 | np.pi + reward_angle_radians): 125 | return 1. 126 | else: 127 | return 0. 
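# A quick worked example of the reward region (illustrative note, not part of
# the original module): with the default _REWARD_ANGLE of 30 degrees, the
# rewarding interval is (pi - pi/6, pi + pi/6), roughly (2.618, 3.665) radians:
#   sparse_reward(State(angle=np.pi, velocity=0.), 0, 0)  # -> 1. (inverted)
#   sparse_reward(State(angle=0., velocity=0.), 0, 0)     # -> 0. (hanging down)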
128 | 129 | 130 | def _alias_angle(angle: float) -> float: 131 | """Returns an angle between 0 and 2*pi.""" 132 | return angle % (2 * np.pi) 133 | 134 | 135 | class Pendulum(base.Environment): 136 | """A continuing pendulum environment. 137 | 138 | Starting from a hanging down position, swing up a single pendulum to the 139 | inverted position and maintain the pendulum in this position. 140 | 141 | Most of the default arguments model the pendulum described in "A Comparison of 142 | Direct and Model-Based Reinforcement Learning" (Atkeson & Santamaria, 1997). 143 | The key differences are the following: 144 | 1) This environment is now continuing, i.e. the "trial length" is infinite. 145 | 2) Instead of directly returning the angle of the pendulum in the observation, 146 | this environment returns the cosine and sine of the angle. 147 | 3) There are only three discrete actions (apply a torque of -1, apply a torque 148 | of +1, and apply no-op) as opposed to a continuous torque value. 149 | 4) The default reward function is implemented as a sparse reward, i.e. there 150 | is only a reward of 1 attained when the angle is in the region specified by 151 | the interval (pi - reward_angle, pi + reward_angle). 152 | 153 | The pendulum's motion is described by the equation 154 | ``` 155 | theta'' = tau - mu * theta' - g * sin(theta), 156 | ``` 157 | 158 | where theta is the angle, mu is the friction coefficient, tau is the torque, 159 | and g is the acceleration due to gravity. 160 | 161 | Observations are NumPy arrays of the form [cos(angle), sin(angle), velocity] 162 | where angle is in [0, 2*pi] and velocity is in [-max_speed, max_speed]. 163 | """ 164 | 165 | def __init__(self, 166 | start_state_fn=start_from_bottom, 167 | friction=_FRICTION, 168 | gravity=_GRAVITY, 169 | simulation_step_size=_SIMULATION_STEP_SIZE, 170 | act_step_period=_ACT_STEP_PERIOD, 171 | max_speed=_MAX_SPEED, 172 | reward_fn=sparse_reward, 173 | seed=None): 174 | """Initializes a new pendulum environment. 175 | 176 | Args: 177 | start_state_fn: A callable which returns a `State` object giving the 178 | initial state. 179 | friction: A positive float giving the coefficient of friction. 180 | gravity: A float giving the acceleration due to gravity. 181 | simulation_step_size: The step size (in seconds) of the simulation. 182 | act_step_period: An integer giving the number of simulation steps for each 183 | action input. 184 | max_speed: A float giving the maximum speed (in radians/second) allowed in 185 | the simulation. 186 | reward_fn: A callable which returns a float reward given current state. 187 | seed: Seed for the internal random number generator. 188 | """ 189 | del seed 190 | self._params = Params( 191 | start_state_fn=start_state_fn, 192 | friction=friction, 193 | gravity=gravity, 194 | simulation_step_size=simulation_step_size, 195 | act_step_period=act_step_period, 196 | max_speed=max_speed, 197 | reward_fn=reward_fn) 198 | self._state = None 199 | self._torque = 0 200 | 201 | def start(self, seed: Optional[int] = None): 202 | """Initializes the environment and returns an initial observation.""" 203 | del seed 204 | self._state = self._params.start_state_fn() 205 | return np.array((np.cos(self._state.angle), 206 | np.sin(self._state.angle), 207 | self._state.velocity), 208 | dtype=np.float32) 209 | 210 | @property 211 | def started(self): 212 | """True if the environment has been started, False otherwise.""" 213 | # An unspecified state implies that the environment needs to be started. 
214 | return self._state is not None 215 | 216 | def step(self, action): 217 | """Updates the environment state and returns an observation and reward. 218 | 219 | Args: 220 | action: An integer in {0, 1, 2} indicating whether to subtract one unit of 221 | torque, do nothing, or add one unit of torque. 222 | 223 | Returns: 224 | A tuple giving the next observation in the form of a NumPy array 225 | and the reward as a float. 226 | 227 | Raises: 228 | RuntimeError: If state has not yet been initialized by `start`. 229 | """ 230 | # Check if state has been initialized. 231 | if not self.started: 232 | raise RuntimeError(base.STEP_WITHOUT_START_ERR) 233 | 234 | self._torque = Action(action).tau 235 | 236 | # Integrate over time steps to get new angle and velocity. 237 | new_angle = self._state.angle 238 | new_velocity = self._state.velocity 239 | 240 | for _ in range(self._params.act_step_period): 241 | new_velocity += ((self._torque - self._params.friction * new_velocity - 242 | self._params.gravity * np.sin(new_angle)) * 243 | self._params.simulation_step_size) 244 | new_angle += new_velocity * self._params.simulation_step_size 245 | 246 | # Ensure the angle is between 0 and 2*pi. 247 | new_angle = _alias_angle(new_angle) 248 | 249 | # Clip velocity to max_speed. 250 | new_velocity = np.clip(new_velocity, -self._params.max_speed, 251 | self._params.max_speed) 252 | 253 | self._state = State(angle=new_angle, velocity=new_velocity) 254 | return (np.array((np.cos(self._state.angle), 255 | np.sin(self._state.angle), 256 | self._state.velocity), 257 | dtype=np.float32), 258 | self._params.reward_fn(self._state, self._torque, 259 | self._params.simulation_step_size)) 260 | 261 | def observation_spec(self): 262 | """Describes the observation specs of the environment.""" 263 | return specs.BoundedArray((3,), 264 | dtype=np.float32, 265 | minimum=[-1, -1, -self._params.max_speed], 266 | maximum=[1, 1, self._params.max_speed]) 267 | 268 | def action_spec(self): 269 | """Describes the action specs of the environment.""" 270 | return specs.DiscreteArray(_NUM_ACTIONS, dtype=int, name="action") 271 | 272 | def get_state(self): 273 | """Returns a copy of the current environment state.""" 274 | return copy.deepcopy(self._state) if self._state is not None else None 275 | 276 | def set_state(self, state): 277 | """Sets environment state to state provided. 278 | 279 | Args: 280 | state: A State object which overrides the current state. 281 | 282 | Returns: 283 | A NumPy array for the observation including the angle and velocity. 284 | """ 285 | # Check that input state values are valid. 286 | if not 0 <= state.angle <= 2 * np.pi: 287 | raise ValueError(_INVALID_ANGLE) 288 | 289 | self._state = copy.deepcopy(state) 290 | 291 | return np.array((np.cos(self._state.angle), 292 | np.sin(self._state.angle), 293 | self._state.velocity), 294 | dtype=np.float32) 295 | 296 | def render(self): 297 | image = Image.new("RGB", (_IMAGE_SIZE, _IMAGE_SIZE), "white") 298 | dct = ImageDraw.Draw(image) 299 | # Get x and y positions of the pendulum tip relative to the center. 300 | x_pos = np.sin(self._state.angle) 301 | y_pos = np.cos(self._state.angle) 302 | 303 | def abs_coordinates(x, y): 304 | """Return absolute coordinates given coordinates relative to center.""" 305 | return (x * _SCALE_FACTOR * _CENTER_IMAGE + _CENTER_IMAGE, 306 | y * _SCALE_FACTOR * _CENTER_IMAGE + _CENTER_IMAGE) 307 | 308 | # Draw reward range region. 
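    # (With the defaults _IMAGE_SIZE = 256, _CENTER_IMAGE = 127 and
    # _SCALE_FACTOR = 0.75, the bounding box below is inset by
    # boundary_x = 127 * 0.25 = 31.75 pixels from each image edge.)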
309 | boundary_x = _CENTER_IMAGE * (1 - _SCALE_FACTOR) 310 | pendulum_bounding_box = [(boundary_x, boundary_x), 311 | (_IMAGE_SIZE - boundary_x, 312 | _IMAGE_SIZE - boundary_x)] 313 | dct.pieslice( 314 | pendulum_bounding_box, 315 | start=(270 - _REWARD_ANGLE), 316 | end=(270 + _REWARD_ANGLE), 317 | fill=_LIGHT_GREEN) 318 | 319 | # Get absolute coordinates of the pendulum tip. 320 | tip_coords = abs_coordinates(x_pos, y_pos) 321 | # Draw pendulum line. 322 | dct.line([(_CENTER_IMAGE, _CENTER_IMAGE), tip_coords], 323 | fill="black", 324 | width=_PENDULUM_WIDTH) 325 | # Draw circular pendulum tip. 326 | x, y = tip_coords 327 | tip_bounding_box = [(x - _TIP_RADIUS, y - _TIP_RADIUS), 328 | (x + _TIP_RADIUS, y + _TIP_RADIUS)] 329 | dct.ellipse(tip_bounding_box, fill="red") 330 | 331 | # Draw torque arrow. 332 | if self._torque > 0: 333 | dct.arc( 334 | pendulum_bounding_box, 335 | start=360 - _TORQUE_ANGLE, 336 | end=_TORQUE_ANGLE, 337 | fill="blue", 338 | width=_ARROW_WIDTH) 339 | # Draw arrow heads. 340 | arrow_x, arrow_y = abs_coordinates( 341 | np.cos(_TORQUE_ANGLE * _RADIAN_MULTIPLIER), 342 | -np.sin(_TORQUE_ANGLE * _RADIAN_MULTIPLIER)) 343 | dct.regular_polygon((arrow_x, arrow_y, _ARROW_WIDTH * 1.5), 344 | n_sides=3, 345 | rotation=_TORQUE_ANGLE, 346 | fill="blue") 347 | 348 | elif self._torque < 0: 349 | dct.arc( 350 | pendulum_bounding_box, 351 | start=180 - _TORQUE_ANGLE, 352 | end=180 + _TORQUE_ANGLE, 353 | fill="blue", 354 | width=_ARROW_WIDTH) 355 | # Draw arrow heads. 356 | arrow_x, arrow_y = abs_coordinates( 357 | -np.cos(_TORQUE_ANGLE * _RADIAN_MULTIPLIER), 358 | -np.sin(_TORQUE_ANGLE * _RADIAN_MULTIPLIER)) 359 | dct.regular_polygon((arrow_x, arrow_y, _ARROW_WIDTH * 1.5), 360 | n_sides=3, 361 | rotation=-_TORQUE_ANGLE, 362 | fill="blue") 363 | 364 | return np.asarray(image, dtype=np.uint8) 365 | -------------------------------------------------------------------------------- /csuite/environments/pendulum_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for pendulum.""" 17 | 18 | from absl.testing import absltest 19 | from csuite.environments import pendulum 20 | 21 | 22 | class PendulumTest(absltest.TestCase): 23 | 24 | def test_environment_setup(self): 25 | """Tests environment initialization.""" 26 | env = pendulum.Pendulum() 27 | self.assertIsNotNone(env) 28 | 29 | def test_start(self): 30 | """Tests environment start.""" 31 | env = pendulum.Pendulum() 32 | 33 | with self.subTest(name='step_without_start'): 34 | # Calling step before start should raise an error. 35 | with self.assertRaises(RuntimeError): 36 | env.step(pendulum.Action.NEGATIVE) 37 | 38 | with self.subTest(name='start_state'): 39 | start_obs = env.start() 40 | # Initial cosine of the angle should be 1. 
41 | # Initial sine of the angle and initial velocity should be 0. 42 | self.assertEqual(start_obs[0], 1.) 43 | self.assertEqual(start_obs[1], 0.) 44 | self.assertEqual(start_obs[2], 0.) 45 | 46 | def test_one_step(self): 47 | """Tests one environment step.""" 48 | env = pendulum.Pendulum() 49 | env.start() 50 | _, reward = env.step(pendulum.Action.NEGATIVE) 51 | self.assertEqual(reward, 0.) 52 | _, reward = env.step(pendulum.Action.POSITIVE) 53 | self.assertEqual(reward, 0.) 54 | _, reward = env.step(pendulum.Action.STAY) 55 | self.assertEqual(reward, 0.) 56 | 57 | def test_setting_state(self): 58 | """Tests setting environment state and solver.""" 59 | env = pendulum.Pendulum() 60 | old_obs = env.start() 61 | # Take two steps adding +1 torque, then set state to downwards position. 62 | for _ in range(2): 63 | old_obs, _ = env.step(pendulum.Action.POSITIVE) 64 | new_state = pendulum.State(angle=0., velocity=0.) 65 | new_obs = env.set_state(new_state) 66 | for _ in range(2): 67 | new_obs, _ = env.step(pendulum.Action.POSITIVE) 68 | 69 | # If the solver was properly updated, the two observations are the same. 70 | self.assertLessEqual(abs(new_obs[0]), old_obs[0]) 71 | self.assertLessEqual(abs(new_obs[1]), old_obs[1]) 72 | self.assertLessEqual(abs(new_obs[2]), old_obs[2]) 73 | 74 | 75 | if __name__ == '__main__': 76 | absltest.main() 77 | -------------------------------------------------------------------------------- /csuite/environments/taxi.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Implementation of the tabular Taxi environment. 17 | 18 | Environment description and details can be found in the `Taxi` 19 | environment class. 20 | 21 | The 5x5 gridworld is depicted below, where we use xy-coordinates to describe the 22 | position of the squares; the coordinate (0, 0) represents the top left corner. 23 | 24 | ||||||||||| 25 | |R: | : :G| 26 | | : | : : | 27 | | : : : : | 28 | | | : | : | 29 | |Y| : |B: | 30 | ||||||||||| 31 | """ 32 | 33 | import copy 34 | import dataclasses 35 | import enum 36 | import itertools 37 | from typing import Optional 38 | 39 | from csuite.environments import base 40 | from dm_env import specs 41 | 42 | import numpy as np 43 | from PIL import Image 44 | from PIL import ImageDraw 45 | 46 | # Default environment variables. 47 | _NUM_ROWS = 5 48 | _NUM_COLUMNS = 5 49 | # Passenger positions: one of the four pickup locations, or in the taxi. 50 | _NUM_POSITIONS = 5 51 | _NUM_DEST = 4 52 | _NUM_STATES = _NUM_ROWS * _NUM_COLUMNS * _NUM_POSITIONS * _NUM_DEST 53 | _NUM_ACTIONS = 6 54 | 55 | # Error messages. 56 | _INVALID_ACTION = "Invalid action: expected value in [0,5] but received {}." 57 | _INVALID_TAXI_LOC = "Invalid state: expected taxi coordinates in range [0,4]." 
58 | _INVALID_PASS_LOC = ("Invalid state: expected passenger location as an integer"
59 |                      " in [0,4].")
60 | _INVALID_DEST = ("Invalid state: expected destination location as an integer"
61 |                  " in [0,3].")
62 | 
63 | # Dictionary mapping the four colored squares to their xy-coordinate
64 | # on the 5x5 grid, with keys 0 (Red), 1 (Green), 2 (Yellow), 3 (Blue).
65 | _COLOR_POSITIONS = {
66 |     0: (0, 0),
67 |     1: (4, 0),
68 |     2: (0, 4),
69 |     3: (3, 4),
70 | }
71 | 
72 | # List of pairs of (x, y) squares for which transitioning
73 | # across the pair's shared edge is forbidden.
74 | _BLOCKED_TUPLE = (
75 |     ((1, 0), (2, 0)),
76 |     ((1, 1), (2, 1)),
77 |     ((0, 3), (1, 3)),
78 |     ((0, 4), (1, 4)),
79 |     ((2, 3), (3, 3)),
80 |     ((2, 4), (3, 4)),
81 | )
82 | # Variables for pixel visualization of the environment.
83 | _PIXELS_PER_SQ = 50  # size of each grid square.
84 | _RED_HEX = "#ff9999"
85 | _GREEN_HEX = "#9cff9c"
86 | _BLUE_HEX = "#99e2ff"
87 | _YELLOW_HEX = "#fff899"
88 | _PASS_LOC_HEX = "#d400ff"
89 | _DEST_HEX = "#008c21"
90 | _EMPTY_TAXI_HEX = "#8f8f8f"
91 | 
92 | # Other derived constants used for pixel visualization.
93 | _HEIGHT = _PIXELS_PER_SQ * (_NUM_ROWS + 2)
94 | _WIDTH = _PIXELS_PER_SQ * (_NUM_COLUMNS + 2)
95 | _BORDER = [  # Bounding box for perimeter of the taxi grid.
96 |     (_PIXELS_PER_SQ, _PIXELS_PER_SQ),
97 |     (_PIXELS_PER_SQ * (_NUM_COLUMNS + 1), _PIXELS_PER_SQ * (_NUM_ROWS + 1))
98 | ]
99 | _OFFSET = _PIXELS_PER_SQ // 5  # To make the taxi bounding box slightly smaller.
100 | _LINE_WIDTH_THIN = _PIXELS_PER_SQ // 50
101 | _LINE_WIDTH_THICK = _PIXELS_PER_SQ // 10
102 | 
103 | # Dictionary mapping the four colored squares to their rectangle bounding boxes
104 | # used for visualization, with keys 0 (Red), 1 (Green), 2 (Yellow), 3 (Blue).
105 | _BOUNDING_BOXES = {
106 |     idx: [(_PIXELS_PER_SQ * (x + 1), _PIXELS_PER_SQ * (y + 1)),
107 |           (_PIXELS_PER_SQ * (x + 2), _PIXELS_PER_SQ * (y + 2))]
108 |     for idx, (x, y) in _COLOR_POSITIONS.items()
109 | }
110 | 
111 | 
112 | class Action(enum.IntEnum):
113 |   """Actions for the Taxi environment.
114 | 
115 |   There are six actions:
116 |     0: move North.
117 |     1: move West.
118 |     2: move South.
119 |     3: move East.
120 |     4: pickup the passenger.
121 |     5: dropoff the passenger.
122 |   """
123 |   NORTH, WEST, SOUTH, EAST, PICKUP, DROPOFF = list(range(_NUM_ACTIONS))
124 | 
125 |   @property
126 |   def dx(self):
127 |     """Maps EAST to 1, WEST to -1, and other actions to 0."""
128 |     if self.name == "EAST":
129 |       return 1
130 |     elif self.name == "WEST":
131 |       return -1
132 |     else:
133 |       return 0
134 | 
135 |   @property
136 |   def dy(self):
137 |     """Maps NORTH to -1, SOUTH to 1, and other actions to 0."""
138 |     if self.name == "NORTH":
139 |       return -1
140 |     elif self.name == "SOUTH":
141 |       return 1
142 |     else:
143 |       return 0
144 | 
145 | 
146 | @dataclasses.dataclass
147 | class State:
148 |   """State of a continuing Taxi environment.
149 | 
150 |   The coordinate system excludes the border of the map and gives the location
151 |   of the taxi. The coordinate (0, 0) corresponds to the top left corner,
152 |   i.e. the Red pickup location, and the coordinate (4, 4) corresponds to the
153 |   bottom right corner.
154 | 
155 |   The passenger location is an integer in [0, 4]: one of the four colored
156 |   squares, or the fifth possibility of being in the taxi. The destination is
157 |   numbered the same way but only includes the four colored squares:
158 |     0 - Red square.
159 |     1 - Green square.
160 |     2 - Yellow square.
161 |     3 - Blue square.
162 |     4 - In taxi (passenger location only).
163 |   """
164 |   taxi_x: int
165 |   taxi_y: int
166 |   passenger_loc: int
167 |   destination: int
168 |   rng: np.random.Generator
169 | 
170 | 
171 | class Taxi(base.Environment):
172 |   """A continuing Taxi environment.
173 | 
174 |   This environment originates from the paper "Hierarchical Reinforcement
175 |   Learning with the MAXQ Value Function Decomposition" by Tom Dietterich.
176 | 
177 |   In a square grid world with four colored squares (R(ed), G(reen), Y(ellow),
178 |   B(lue)), the agent must drive a taxi to various passengers' locations and
179 |   drop them off at the passenger's desired location, one of the four squares.
180 |   The agent receives positive reward for each passenger successfully picked up
181 |   and dropped off, and receives negative reward for doing an inappropriate
182 |   action (e.g. dropping off the passenger at the incorrect location, attempting
183 |   to pick up a passenger on an empty square, etc.).
184 | 
185 |   There are six possible actions, corresponding to four navigation actions
186 |   (move North, South, East, and West), a pickup action, and a dropoff action.
187 | 
188 |   The observation space is a single state index, which encodes the possible
189 |   states accounting for the taxi position, location of the passenger, and four
190 |   desired destination locations.
191 |   """
192 | 
193 |   def __init__(self, seed=None):
194 |     """Initializes the Taxi environment.
195 | 
196 |     Args:
197 |       seed: Seed for the internal random number generator.
198 |     """
199 |     self._seed = seed
200 |     self._state = None
201 | 
202 |     # Populate lookup table for observations.
203 |     self.lookup_table = {}
204 |     for idx, state in enumerate(
205 |         itertools.product(
206 |             range(_NUM_ROWS), range(_NUM_COLUMNS), range(_NUM_POSITIONS),
207 |             range(_NUM_DEST))):
208 |       self.lookup_table[state] = idx
209 | 
210 |   def start(self, seed: Optional[int] = None):
211 |     """Initializes the environment and returns an initial observation."""
212 |     rng = np.random.default_rng(self._seed if seed is None else seed)
213 |     self._state = State(
214 |         taxi_x=rng.integers(_NUM_COLUMNS),
215 |         taxi_y=rng.integers(_NUM_ROWS),
216 |         passenger_loc=rng.integers(_NUM_POSITIONS - 1),
217 |         destination=rng.integers(_NUM_DEST),
218 |         rng=rng,
219 |     )
220 |     return self._get_observation()
221 | 
222 |   @property
223 |   def started(self):
224 |     """True if the environment has been started, False otherwise."""
225 |     # An unspecified state implies that the environment needs to be started.
226 |     return self._state is not None
227 | 
228 |   def step(self, action):
229 |     """Updates the environment state and returns an observation and reward.
230 | 
231 |     Args:
232 |       action: An integer in [0,5] indicating whether the taxi moves, picks up
233 |         the passenger, or drops off the passenger.
234 | 
235 |     Returns:
236 |       A tuple of type (int, float) giving the next observation and the reward.
237 | 
238 |     Raises:
239 |       RuntimeError: If state has not yet been initialized by `start`.
240 |       ValueError: If input action has an invalid value.
241 |     """
242 |     # Check if state has been initialized.
243 |     if not self.started:
244 |       raise RuntimeError(base.STEP_WITHOUT_START_ERR)
245 | 
246 |     # Check if input action is valid.
247 |     if action not in [a.value for a in Action]:
248 |       raise ValueError(_INVALID_ACTION.format(action))
249 | 
250 |     reward = 0
251 |     # Move taxi according to the action.
252 |     self._state.taxi_y = np.clip(self._state.taxi_y + Action(action).dy, 0,
253 |                                  _NUM_ROWS - 1)
254 |     # If moving East or West, check that the taxi does not hit a barrier.
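    # (For example, moving EAST from (1, 0) is blocked because the pair
    # ((1, 0), (2, 0)) appears in _BLOCKED_TUPLE; WEST moves reverse the pair
    # before the lookup, so moving WEST from (2, 0) is blocked as well.)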
255 |     if action in [Action.EAST, Action.WEST]:
256 |       move = ((self._state.taxi_x, self._state.taxi_y),
257 |               (self._state.taxi_x + Action(action).dx, self._state.taxi_y))
258 |       if action == Action.WEST:  # Need to reverse the tuple.
259 |         move = move[::-1]
260 |       if move not in _BLOCKED_TUPLE:
261 |         self._state.taxi_x = np.clip(self._state.taxi_x + Action(action).dx, 0,
262 |                                      _NUM_COLUMNS - 1)
263 | 
264 |     # If action is pickup, check if passenger location matches current location.
265 |     if action == Action.PICKUP:
266 |       if self._state.passenger_loc == 4:  # Passenger was already picked up.
267 |         reward = -10
268 |       else:
269 |         passenger_coordinates = _COLOR_POSITIONS[self._state.passenger_loc]
270 |         # Check if passenger and taxi are at the same location.
271 |         if passenger_coordinates != (self._state.taxi_x, self._state.taxi_y):
272 |           reward = -10
273 |         else:
274 |           # Passenger has been successfully picked up.
275 |           self._state.passenger_loc = 4
276 | 
277 |     # If action is dropoff, check if passenger is present in taxi and
278 |     # desired destination matches current location.
279 |     if action == Action.DROPOFF:
280 |       dest_coordinates = _COLOR_POSITIONS[self._state.destination]
281 |       if (self._state.passenger_loc != 4 or dest_coordinates !=
282 |           (self._state.taxi_x, self._state.taxi_y)):
283 |         reward = -10
284 |       else:
285 |         reward = 20
286 |         # Add new passenger.
287 |         self._state.passenger_loc = self._state.rng.integers(_NUM_POSITIONS - 1)
288 |         self._state.destination = self._state.rng.integers(_NUM_DEST)
289 | 
290 |     return self._get_observation(), reward
291 | 
292 |   def _get_observation(self):
293 |     """Returns an observation index uniquely identifying the current state."""
294 |     state_tuple = (self._state.taxi_x, self._state.taxi_y,
295 |                    self._state.passenger_loc, self._state.destination)
296 |     return self.lookup_table[state_tuple]
297 | 
298 |   def observation_spec(self):
299 |     """Describes the observation specs of the environment."""
300 |     return specs.DiscreteArray(_NUM_STATES, dtype=int, name="observation")
301 | 
302 |   def action_spec(self):
303 |     """Describes the action specs of the environment."""
304 |     return specs.DiscreteArray(_NUM_ACTIONS, dtype=int, name="action")
305 | 
306 |   def get_state(self):
307 |     """Returns a copy of the current environment state."""
308 |     return copy.deepcopy(self._state) if self._state is not None else None
309 | 
310 |   def set_state(self, state):
311 |     """Sets environment state to the state provided.
312 | 
313 |     Args:
314 |       state: A State object which overrides the current state.
315 |     """
316 |     # Check that input state values are valid.
317 |     if not (0 <= state.taxi_x < _NUM_COLUMNS and 0 <= state.taxi_y < _NUM_ROWS):
318 |       raise ValueError(_INVALID_TAXI_LOC)
319 |     elif not 0 <= state.passenger_loc < _NUM_POSITIONS:
320 |       raise ValueError(_INVALID_PASS_LOC)
321 |     elif not 0 <= state.destination < _NUM_DEST:
322 |       raise ValueError(_INVALID_DEST)
323 | 
324 |     self._state = copy.deepcopy(state)
325 | 
326 |   def render(self) -> np.ndarray:
327 |     """Creates an image of the current environment state.
328 | 
329 |     The underlying grid with the four colored squares is drawn. The taxi is
330 |     drawn as a circle which is _PASS_LOC_HEX (purple) when the passenger is
331 |     present, and _EMPTY_TAXI_HEX (grey) otherwise. The passenger location before
332 |     being picked up is outlined with the color _PASS_LOC_HEX (purple), and the
333 |     destination location is similarly outlined with the color _DEST_HEX
334 |     (dark green).
335 | 
336 |     If the passenger location and the destination are identical, only the
337 |     destination outline is visible.
338 | 
339 |     Returns:
340 |       A NumPy array giving an image of the environment state.
341 |     """
342 |     image = Image.new("RGB", (_WIDTH, _HEIGHT), "white")
343 |     dct = ImageDraw.Draw(image)
344 | 
345 |     # First place the four colored squares so grid lines appear on top.
346 |     # Red, green, yellow, and blue squares.
347 |     dct.rectangle(_BOUNDING_BOXES[0], fill=_RED_HEX)
348 |     dct.rectangle(_BOUNDING_BOXES[1], fill=_GREEN_HEX)
349 |     dct.rectangle(_BOUNDING_BOXES[2], fill=_YELLOW_HEX)
350 |     dct.rectangle(_BOUNDING_BOXES[3], fill=_BLUE_HEX)
351 | 
352 |     # Draw basic grid.
353 |     for row in range(1, _NUM_ROWS + 2):  # horizontal grid lines.
354 |       line_coordinates = [(_PIXELS_PER_SQ, _PIXELS_PER_SQ * row),
355 |                           (_PIXELS_PER_SQ * (_NUM_ROWS + 1),
356 |                            _PIXELS_PER_SQ * row)]
357 |       dct.line(line_coordinates, fill="black", width=_LINE_WIDTH_THIN)
358 |     for col in range(1, _NUM_COLUMNS + 2):  # vertical grid lines.
359 |       line_coordinates = [(_PIXELS_PER_SQ * col, _PIXELS_PER_SQ),
360 |                           (_PIXELS_PER_SQ * col,
361 |                            _PIXELS_PER_SQ * (_NUM_ROWS + 1))]
362 |       dct.line(line_coordinates, fill="black", width=_LINE_WIDTH_THIN)
363 | 
364 |     # Draw barriers.
365 |     dct.rectangle(
366 |         _BORDER,  # Grid perimeter.
367 |         outline="black",
368 |         width=_LINE_WIDTH_THICK)
369 | 
370 |     def get_barrier_coordinates(x, y):
371 |       """Returns endpoints of a vertical barrier two squares long from (x, y)."""
372 |       return [(x, y), (x, y + 2 * _PIXELS_PER_SQ)]
373 | 
374 |     # Top barrier, bottom left barrier, bottom right barrier.
375 |     dct.line(
376 |         get_barrier_coordinates(3 * _PIXELS_PER_SQ, _PIXELS_PER_SQ),
377 |         fill="black",
378 |         width=_LINE_WIDTH_THICK)
379 |     dct.line(
380 |         get_barrier_coordinates(2 * _PIXELS_PER_SQ, 4 * _PIXELS_PER_SQ),
381 |         fill="black",
382 |         width=_LINE_WIDTH_THICK)
383 |     dct.line(
384 |         get_barrier_coordinates(4 * _PIXELS_PER_SQ, 4 * _PIXELS_PER_SQ),
385 |         fill="black",
386 |         width=_LINE_WIDTH_THICK)
387 | 
388 |     # Draw passenger location.
389 |     if self._state.passenger_loc in range(4):
390 |       taxi_color = _EMPTY_TAXI_HEX
391 |       dct.rectangle(
392 |           _BOUNDING_BOXES[self._state.passenger_loc],
393 |           outline=_PASS_LOC_HEX,
394 |           width=_LINE_WIDTH_THICK)
395 |     else:
396 |       taxi_color = _PASS_LOC_HEX
397 | 
398 |     # Draw taxi.
399 |     def get_circle_coordinates(x, y):
400 |       return [((x + 1) * _PIXELS_PER_SQ + _OFFSET,
401 |                (y + 1) * _PIXELS_PER_SQ + _OFFSET),
402 |               ((x + 2) * _PIXELS_PER_SQ - _OFFSET,
403 |                (y + 2) * _PIXELS_PER_SQ - _OFFSET)]
404 | 
405 |     dct.ellipse(
406 |         get_circle_coordinates(self._state.taxi_x, self._state.taxi_y),
407 |         fill=taxi_color)
408 | 
409 |     # Draw destination location.
410 |     dct.rectangle(
411 |         _BOUNDING_BOXES[self._state.destination],
412 |         outline=_DEST_HEX,
413 |         width=_LINE_WIDTH_THICK)
414 |     return np.asarray(image, dtype=np.uint8)
415 | 
--------------------------------------------------------------------------------
/csuite/environments/taxi_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for taxi.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from csuite.environments import taxi 21 | 22 | 23 | class TaxiTest(parameterized.TestCase): 24 | 25 | def test_environment_setup(self): 26 | """Tests environment initialization.""" 27 | env = taxi.Taxi() 28 | self.assertIsNotNone(env) 29 | 30 | def test_start(self): 31 | """Tests environment start.""" 32 | env = taxi.Taxi() 33 | 34 | with self.subTest(name='step_without_start'): 35 | # Calling step before start should raise an error. 36 | with self.assertRaises(RuntimeError): 37 | env.step(taxi.Action.NORTH) 38 | 39 | with self.subTest(name='start_state'): 40 | _ = env.start() 41 | state = env.get_state() 42 | self.assertIn(state.taxi_x, range(5)) 43 | self.assertIn(state.taxi_y, range(5)) 44 | # Original passenger location should not be in the taxi. 45 | self.assertIn(state.passenger_loc, range(4)) 46 | self.assertIn(state.destination, range(4)) 47 | 48 | @parameterized.parameters( 49 | (2, 2, taxi.Action.NORTH, True), (2, 0, taxi.Action.NORTH, False), 50 | (2, 2, taxi.Action.SOUTH, True), (2, 4, taxi.Action.SOUTH, False), 51 | (1, 4, taxi.Action.EAST, True), (1, 1, taxi.Action.EAST, False), 52 | (2, 4, taxi.Action.WEST, True), (2, 1, taxi.Action.WEST, False)) 53 | def test_one_movement_step(self, x, y, action, can_move): 54 | """Tests one step with a movement action (North, East, South, West).""" 55 | env = taxi.Taxi() 56 | env.start() 57 | cur_state = env.get_state() 58 | 59 | # Create new state from input parameters and set environment to this state. 60 | cur_state.taxi_x = x 61 | cur_state.taxi_y = y 62 | cur_state.passenger_loc = 0 63 | cur_state.destination = 2 64 | env.set_state(cur_state) 65 | 66 | # Take movement step provided. 67 | env.step(action) 68 | next_state = env.get_state() 69 | if can_move: 70 | self.assertEqual(next_state.taxi_x, 71 | cur_state.taxi_x + taxi.Action(action).dx) 72 | self.assertEqual(next_state.taxi_y, 73 | cur_state.taxi_y + taxi.Action(action).dy) 74 | else: 75 | self.assertEqual(next_state.taxi_x, cur_state.taxi_x) 76 | self.assertEqual(next_state.taxi_y, cur_state.taxi_y) 77 | 78 | @parameterized.parameters((0, 0, 0, 2, taxi.Action.PICKUP, True), 79 | (0, 1, 0, 2, taxi.Action.PICKUP, False), 80 | (0, 1, 4, 2, taxi.Action.PICKUP, False), 81 | (3, 4, 4, 3, taxi.Action.DROPOFF, True), 82 | (2, 4, 4, 3, taxi.Action.DROPOFF, False), 83 | (3, 4, 3, 3, taxi.Action.DROPOFF, False)) 84 | def test_pickup_dropoff(self, x, y, pass_loc, dest, action, is_success): 85 | """Tests the two passenger actions (pickup and dropoff).""" 86 | env = taxi.Taxi() 87 | env.start() 88 | cur_state = env.get_state() 89 | 90 | # Create new state from input parameters and set environment to this state. 
91 | cur_state.taxi_x = x 92 | cur_state.taxi_y = y 93 | cur_state.passenger_loc = pass_loc 94 | cur_state.destination = dest 95 | env.set_state(cur_state) 96 | _, reward = env.step(action) 97 | 98 | # Check correct reward: for successful dropoffs, reward is 20. For 99 | # successful pickups, reward is 0. For incorrect actions, reward is -10. 100 | if is_success and action == taxi.Action.DROPOFF: 101 | self.assertEqual(reward, 20) 102 | elif is_success: 103 | self.assertEqual(reward, 0) 104 | else: 105 | self.assertEqual(reward, -10) 106 | 107 | if is_success and action == taxi.Action.PICKUP: 108 | # Check passenger is in the taxi. 109 | next_state = env.get_state() 110 | self.assertEqual(next_state.passenger_loc, 4) 111 | 112 | def test_runs_from_start(self): 113 | """Tests running a full passenger pickup and dropoff sequence.""" 114 | env = taxi.Taxi() 115 | env.start() 116 | cur_state = env.get_state() 117 | 118 | # Set state to have passenger and taxi on the red square, and destination 119 | # on the blue square. 120 | cur_state.taxi_x = 0 121 | cur_state.taxi_y = 0 122 | cur_state.passenger_loc = 0 123 | cur_state.destination = 3 124 | 125 | env.set_state(cur_state) 126 | # Pick up the passenger. 127 | env.step(taxi.Action.PICKUP) 128 | 129 | for _ in range(2): 130 | _, reward = env.step(taxi.Action.SOUTH) 131 | self.assertEqual(reward, 0) 132 | for _ in range(3): 133 | _, reward = env.step(taxi.Action.EAST) 134 | self.assertEqual(reward, 0) 135 | for _ in range(2): 136 | _, reward = env.step(taxi.Action.SOUTH) 137 | self.assertEqual(reward, 0) 138 | 139 | # Drop off the passenger. 140 | _, reward = env.step(taxi.Action.DROPOFF) 141 | self.assertEqual(reward, 20) 142 | 143 | 144 | if __name__ == '__main__': 145 | absltest.main() 146 | -------------------------------------------------------------------------------- /csuite/environments/windy_catch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """WindyCatch environment. 17 | 18 | Environment description can be found in the `WindyCatch` environment class. 19 | """ 20 | import copy 21 | import dataclasses 22 | import enum 23 | from typing import Optional 24 | 25 | from csuite.environments import base 26 | from csuite.environments import common 27 | from dm_env import specs 28 | 29 | import numpy as np 30 | 31 | # Error messages. 32 | _INVALID_ACTION = "Invalid action: expected 0, 1, or 2 but received {action}." 
33 | _INVALID_PADDLE_POS = ("Invalid state: paddle should be positioned at the"
34 |                        " bottom of the board.")
35 | _INVALID_BALLS_RANGE = (
36 |     "Invalid state: positions of balls and paddle not in expected"
37 |     " row range [0, {rows}) and column range [0, {columns}).")
38 | _INVALID_WIND_DIRECTION = (
39 |     "Invalid state: expected exactly one entry of wind_direction to be True.")
40 | 
41 | # Default environment variables.
42 | _ROWS = 10
43 | _COLUMNS = 5
44 | _SPAWN_PROBABILITY = 0.1
45 | _CHANGE_EVERY = 100000
46 | _WIND_DELTA = [0, -1, 1]
47 | 
48 | 
49 | class Action(enum.IntEnum):
50 |   LEFT = 0
51 |   STAY = 1
52 |   RIGHT = 2
53 | 
54 |   @property
55 |   def dx(self):
56 |     """Maps LEFT to -1, STAY to 0, and RIGHT to 1."""
57 |     return self.value - 1
58 | 
59 | 
60 | @dataclasses.dataclass
61 | class Params:
62 |   """Windy catch parameters.
63 | 
64 |   Attributes:
65 |     rows: Integer number of rows.
66 |     columns: Integer number of columns.
67 |     observation_dim: Integer dimension of the observation features.
68 |     spawn_probability: Probability of a new ball spawning.
69 |     change_every: Integer giving the interval at which the direction of the
70 |       wind changes.
71 |   """
72 |   rows: int
73 |   columns: int
74 |   observation_dim: int
75 |   spawn_probability: float
76 |   change_every: int
77 | 
78 | 
79 | @dataclasses.dataclass
80 | class State:
81 |   """Windy catch state.
82 | 
83 |   Attributes:
84 |     paddle_x: An integer denoting the x-coordinate of the paddle.
85 |     paddle_y: An integer denoting the y-coordinate of the paddle.
86 |     balls: A list of (x, y) coordinates representing the present balls.
87 |     wind_direction: List of three booleans (no wind, left wind, right wind);
88 |       only one is True.
89 |     time_since_wind_change: An integer denoting how many timesteps have elapsed
90 |       since the last change in wind direction.
91 |     rng: Internal NumPy pseudo-random number generator, included here for
92 |       reproducibility purposes.
93 |   """
94 |   paddle_x: int
95 |   paddle_y: int
96 |   balls: list[tuple[int, int]]
97 |   wind_direction: list[bool]
98 |   time_since_wind_change: int
99 |   rng: np.random.Generator
100 | 
101 | 
102 | class WindyCatch(base.Environment):
103 |   """A windy catch environment.
104 | 
105 |   Wind moves a falling ball sideways by one column, depending on the
106 |   direction: leftward wind moves the ball left, rightward wind moves it
107 |   right, and balls wrap around the board edges. With no wind, the ball stays
108 |   in its column. The direction of the wind (or absence thereof) is observable
109 |   through three mutually exclusive bits. Every K steps, the wind changes to
110 |   one of the three possibilities.
111 | 
112 |   The environment is fully observable and has stationary dynamics.
113 |   """
114 | 
115 |   def __init__(self,
116 |                rows=_ROWS,
117 |                columns=_COLUMNS,
118 |                spawn_probability=_SPAWN_PROBABILITY,
119 |                seed=None,
120 |                change_every=_CHANGE_EVERY):
121 |     """Initialize the windy catch environment.
122 | 
123 |     Args:
124 |       rows: A positive integer denoting the number of rows.
125 |       columns: A positive integer denoting the number of columns.
126 |       spawn_probability: Float giving the probability of a new ball appearing.
127 |       seed: Seed for the internal random number generator.
128 |       change_every: A positive integer denoting the interval at which wind
129 |         changes.
130 | """ 131 | self._seed = seed 132 | self._params = Params( 133 | rows=rows, 134 | columns=columns, 135 | observation_dim=rows * columns + 3, 136 | spawn_probability=spawn_probability, 137 | change_every=change_every) 138 | self._state = None 139 | 140 | def start(self, seed: Optional[int] = None): 141 | """Initializes the environment and returns an initial observation.""" 142 | 143 | # The initial state has one ball appearing in a random column at the top, 144 | # and the paddle centered at the bottom. 145 | 146 | rng = np.random.default_rng(self._seed if seed is None else seed) 147 | self._state = State( 148 | paddle_x=self._params.columns // 2, 149 | paddle_y=self._params.rows - 1, 150 | balls=[(rng.integers(self._params.columns), 0)], 151 | wind_direction=[True, False, False], 152 | time_since_wind_change=0, 153 | rng=rng, 154 | ) 155 | return self._get_observation() 156 | 157 | @property 158 | def started(self): 159 | """True if the environment has been started, False otherwise.""" 160 | # An unspecified state implies that the environment needs to be started. 161 | return self._state is not None 162 | 163 | def step(self, action): 164 | """Updates the environment state and returns an observation and reward. 165 | 166 | Args: 167 | action: An integer equalling 0, 1, or 2 indicating whether to move the 168 | paddle left, stay, or move the paddle right respectively. 169 | 170 | Returns: 171 | A tuple of type (int, float) giving the next observation and the reward. 172 | 173 | Raises: 174 | RuntimeError: If state has not yet been initialized by `start`. 175 | """ 176 | # Check if state has been initialized. 177 | if not self.started: 178 | raise RuntimeError(base.STEP_WITHOUT_START_ERR) 179 | 180 | # Check if input action is valid. 181 | if action not in [Action.LEFT, Action.STAY, Action.RIGHT]: 182 | raise ValueError(_INVALID_ACTION.format(action=action)) 183 | 184 | # Move the paddle. 185 | self._state.paddle_x = np.clip(self._state.paddle_x + Action(action).dx, 0, 186 | self._params.columns - 1) 187 | 188 | # Move all balls down by one unit, with wind. 189 | wd = _WIND_DELTA[self._state.wind_direction.index(True)] 190 | self._state.balls = [ 191 | # x coord: applies wind; y coord: gravity 192 | ((x + wd) % self._params.columns, y + 1) for x, y in self._state.balls 193 | ] 194 | 195 | # Since at most one ball is added at each timestep, at most one ball 196 | # can be at the bottom of the board, and must be the 'oldest' ball. 197 | reward = 0. 198 | if self._state.balls and self._state.balls[0][1] == self._state.paddle_y: 199 | reward = 1. if self._state.balls[0][0] == self._state.paddle_x else -1. 200 | # Remove ball from list. 201 | self._state.balls = self._state.balls[1:] 202 | 203 | # Add new ball with given probability. 204 | if self._state.rng.random() < self._params.spawn_probability: 205 | self._state.balls.append( 206 | (self._state.rng.integers(self._params.columns), 0)) 207 | 208 | # Update time since last change in wind. 209 | self._state.time_since_wind_change += 1 210 | 211 | # Update the wind direction. 
212 |     if self._state.time_since_wind_change % self._params.change_every == 0:
213 |       self._state.wind_direction = [False, False, False]
214 |       self._state.wind_direction[self._state.rng.integers(3)] = True
215 |       self._state.time_since_wind_change = 0
216 | 
217 |     return self._get_observation(), reward
218 | 
219 |   def _get_board(self) -> np.ndarray:
220 |     # Use the configured size so custom rows/columns render correctly.
221 |     board = np.zeros((self._params.rows, self._params.columns), dtype=int)
222 |     board[self._state.paddle_y, self._state.paddle_x] = 1
223 |     for x, y in self._state.balls:
224 |       board[y, x] = 1
225 |     return board
226 | 
227 |   def _get_observation(self) -> np.ndarray:
228 |     """Converts internal environment state to an array observation.
229 | 
230 |     Returns:
231 |       A binary array of size (rows * columns + 3,).
232 |     """
233 |     return np.concatenate([
234 |         self._get_board().flatten(), self._state.wind_direction])
235 | 
236 |   def observation_spec(self):
237 |     """Describes the observation specs of the environment."""
238 |     return specs.BoundedArray(
239 |         shape=(self._params.observation_dim,),
240 |         dtype=int,
241 |         minimum=0,
242 |         maximum=1,
243 |         name="board")
244 | 
245 |   def action_spec(self):
246 |     """Describes the action specs of the environment."""
247 |     return specs.DiscreteArray(num_values=3, dtype=int, name="action")
248 | 
249 |   def get_state(self):
250 |     """Returns a copy of the current environment state."""
251 |     return copy.deepcopy(self._state) if self._state is not None else None
252 | 
253 |   def set_state(self, state):
254 |     """Sets environment state to the state provided.
255 | 
256 |     Args:
257 |       state: A State object which overrides the current state.
258 |     """
259 |     # Check that input state values are valid.
260 |     if not (0 <= state.paddle_x < self._params.columns and
261 |             state.paddle_y == self._params.rows - 1):
262 |       raise ValueError(_INVALID_PADDLE_POS)
263 | 
264 |     for x, y in state.balls:
265 |       if not (0 <= x < self._params.columns and 0 <= y < self._params.rows):
266 |         raise ValueError(
267 |             _INVALID_BALLS_RANGE.format(
268 |                 rows=self._params.rows, columns=self._params.columns))
269 | 
270 |     if sum(state.wind_direction) != 1:
271 |       raise ValueError(_INVALID_WIND_DIRECTION)
272 | 
273 |     self._state = copy.deepcopy(state)
274 | 
275 |   def get_config(self):
276 |     """Returns a copy of the environment configuration."""
277 |     return copy.deepcopy(self._params)
278 | 
279 |   def render(self) -> np.ndarray:
280 |     board = self._get_board()
281 |     num_cols = board.shape[1]
282 |     header = np.ones((2, num_cols), dtype=np.uint8)
283 |     center = num_cols // 2
284 |     header[0, center-1:center+2] = self._state.wind_direction
285 |     return common.binary_board_to_rgb(np.concatenate([board, header]))
286 | 
--------------------------------------------------------------------------------
/csuite/environments/windy_catch_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================== 15 | 16 | """Tests for WindyCatch.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from csuite.environments import windy_catch 21 | 22 | 23 | class WindyCatchTest(parameterized.TestCase): 24 | 25 | def test_environment_setup(self): 26 | """Tests environment initialization.""" 27 | env = windy_catch.WindyCatch() 28 | self.assertIsNotNone(env) 29 | 30 | def test_start(self): 31 | """Tests environment start.""" 32 | env = windy_catch.WindyCatch() 33 | params = env.get_config() 34 | 35 | with self.subTest(name='step_without_start'): 36 | # Calling step before start should raise an error. 37 | with self.assertRaises(RuntimeError): 38 | env.step(windy_catch.Action.LEFT) 39 | 40 | with self.subTest(name='start_state'): 41 | start_obs = env.start() 42 | state = env.get_state() 43 | # Paddle should be positioned at the bottom of the board. 44 | self.assertEqual(state.paddle_y, params.rows - 1) 45 | paddle_idx = state.paddle_y * params.columns + state.paddle_x 46 | self.assertEqual(start_obs[paddle_idx], 1) 47 | 48 | # First ball should be positioned at the top of the board. 49 | ball_x = state.balls[0][0] 50 | ball_y = state.balls[0][1] 51 | self.assertEqual(ball_y, 0) 52 | ball_idx = ball_y * params.columns + ball_x 53 | self.assertEqual(start_obs[ball_idx], 1) 54 | 55 | def test_invalid_state(self): 56 | """Tests setting environment state with invalid fields.""" 57 | env = windy_catch.WindyCatch() 58 | env.start() 59 | 60 | with self.subTest(name='paddle_out_of_range'): 61 | new_state = env.get_state() 62 | new_state.paddle_x = 5 63 | with self.assertRaises(ValueError): 64 | env.set_state(new_state) 65 | 66 | with self.subTest(name='balls_out_of_range'): 67 | new_state = env.get_state() 68 | new_state.balls = [(0, -1)] 69 | with self.assertRaises(ValueError): 70 | env.set_state(new_state) 71 | 72 | @parameterized.parameters((0, 0, 1), (2, 1, 3), (4, 3, 4)) 73 | def test_one_step(self, paddle_x, expected_left_x, expected_right_x): 74 | """Tests one environment step given the x-position of the paddle.""" 75 | env = windy_catch.WindyCatch() 76 | env.start() 77 | 78 | with self.subTest(name='invalid_action'): 79 | with self.assertRaises(ValueError): 80 | env.step(3) 81 | 82 | with self.subTest(name='move_left_step'): 83 | current_state = env.get_state() 84 | current_state.paddle_x = paddle_x 85 | env.set_state(current_state) 86 | 87 | env.step(windy_catch.Action.LEFT) 88 | state = env.get_state() 89 | 90 | # Paddle x-position should have moved left by 1 unless at the edge. 91 | self.assertEqual(state.paddle_x, expected_left_x) 92 | 93 | with self.subTest(name='move_right_step'): 94 | current_state = env.get_state() 95 | current_state.paddle_x = paddle_x 96 | env.set_state(current_state) 97 | 98 | env.step(windy_catch.Action.RIGHT) 99 | state = env.get_state() 100 | 101 | # Paddle x-position should have moved right by 1 unless at the edge. 
102 | self.assertEqual(state.paddle_x, expected_right_x) 103 | 104 | with self.subTest(name='stay_step'): 105 | current_state = env.get_state() 106 | current_state.paddle_x = paddle_x 107 | env.set_state(current_state) 108 | 109 | env.step(windy_catch.Action.STAY) 110 | state = env.get_state() 111 | self.assertEqual(state.paddle_x, paddle_x) 112 | 113 | def test_wind(self): 114 | """Tests the wind.""" 115 | env = windy_catch.WindyCatch(spawn_probability=0.0) 116 | env.start() 117 | 118 | with self.subTest(name='wind_stay'): 119 | state = env.get_state() 120 | state.balls = [(0, 0), (2, 0), (4, 0)] 121 | state.wind_direction = [True, False, False] 122 | env.set_state(state) 123 | 124 | env.step(windy_catch.Action.STAY) 125 | 126 | b0, b1, b2 = env.get_state().balls 127 | self.assertEqual(b0[0], 0) 128 | self.assertEqual(b0[1], 1) 129 | self.assertEqual(b1[0], 2) 130 | self.assertEqual(b1[1], 1) 131 | self.assertEqual(b2[0], 4) 132 | self.assertEqual(b2[1], 1) 133 | 134 | with self.subTest(name='wind_left'): 135 | state = env.get_state() 136 | state.balls = [(0, 0), (2, 0), (4, 0)] 137 | state.wind_direction = [False, True, False] 138 | env.set_state(state) 139 | 140 | env.step(windy_catch.Action.STAY) 141 | 142 | b0, b1, b2 = env.get_state().balls 143 | self.assertEqual(b0[0], 4) 144 | self.assertEqual(b0[1], 1) 145 | self.assertEqual(b1[0], 1) 146 | self.assertEqual(b1[1], 1) 147 | self.assertEqual(b2[0], 3) 148 | self.assertEqual(b2[1], 1) 149 | 150 | with self.subTest(name='wind_right'): 151 | state = env.get_state() 152 | state.balls = [(0, 0), (2, 0), (4, 0)] 153 | state.wind_direction = [False, False, True] 154 | env.set_state(state) 155 | 156 | env.step(windy_catch.Action.STAY) 157 | 158 | b0, b1, b2 = env.get_state().balls 159 | self.assertEqual(b0[0], 1) 160 | self.assertEqual(b0[1], 1) 161 | self.assertEqual(b1[0], 3) 162 | self.assertEqual(b1[1], 1) 163 | self.assertEqual(b2[0], 0) 164 | self.assertEqual(b2[1], 1) 165 | 166 | 167 | if __name__ == '__main__': 168 | absltest.main() 169 | -------------------------------------------------------------------------------- /csuite/utils/dm_env_wrapper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Wrapper for converting a csuite base.Environment to dm_env.Environment.""" 17 | 18 | from csuite.environments import base 19 | import dm_env 20 | 21 | 22 | class DMEnvFromCSuite(dm_env.Environment): 23 | """A wrapper to convert a CSuite environment to a dm_env.Environment.""" 24 | 25 | def __init__(self, csuite_env: base.Environment): 26 | self._csuite_env = csuite_env 27 | self._started = False 28 | 29 | def reset(self) -> dm_env.TimeStep: 30 | observation = self._csuite_env.start() 31 | self._started = True 32 | return dm_env.restart(observation) 33 | 34 | def step(self, action) -> dm_env.TimeStep: 35 | if not self._started: 36 | return self.reset() 37 | # Convert the csuite step result to a dm_env TimeStep. 38 | observation, reward = self._csuite_env.step(action) 39 | return dm_env.TimeStep( 40 | step_type=dm_env.StepType.MID, 41 | observation=observation, 42 | reward=reward, 43 | discount=1.0) 44 | 45 | def observation_spec(self): 46 | return self._csuite_env.observation_spec() 47 | 48 | def action_spec(self): 49 | return self._csuite_env.action_spec() 50 | 51 | def get_state(self): 52 | return self._csuite_env.get_state() 53 | 54 | def set_state(self, state): 55 | self._csuite_env.set_state(state) 56 | 57 | def render(self): 58 | return self._csuite_env.render() 59 | -------------------------------------------------------------------------------- /csuite/utils/dm_env_wrapper_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Test for DMEnvFromCSuite.""" 17 | 18 | from absl.testing import absltest 19 | import csuite 20 | from dm_env import test_utils 21 | 22 | 23 | class DMEnvFromCSuiteTest(test_utils.EnvironmentTestMixin, absltest.TestCase): 24 | 25 | def make_object_under_test(self): 26 | csuite_env = csuite.load('catch') 27 | return csuite.dm_env_wrapper.DMEnvFromCSuite(csuite_env) 28 | 29 | 30 | if __name__ == '__main__': 31 | absltest.main() 32 | -------------------------------------------------------------------------------- /csuite/utils/gym_wrapper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """Wrapper for adapting a csuite base.Environment to the OpenAI Gym interface."""
17 | 
18 | import typing
19 | from typing import Any, Dict, Tuple, Optional, Union
20 | 
21 | from csuite.environments import base
22 | from dm_env import specs
23 | import gym
24 | from gym import spaces
25 | import numpy as np
26 | 
27 | # OpenAI gym step format = obs, reward, is_finished, other_info
28 | _GymTimestep = Tuple[np.ndarray, float, bool, Dict[str, Any]]
29 | 
30 | 
31 | class GymFromCSuite(gym.Env):
32 |   """A wrapper to convert a CSuite environment to an OpenAI gym.Env."""
33 | 
34 |   metadata = {'render.modes': ['human', 'rgb_array']}
35 | 
36 |   def __init__(self, csuite_env: base.Environment):
37 |     self._csuite_env = csuite_env
38 |     self._np_random = None  # GYM env_checker requires a _np_random attr
39 |     self.viewer = None
40 | 
41 |   def step(self, action) -> _GymTimestep:
42 |     # Convert the csuite step result to a gym timestep.
43 |     try:
44 |       observation, reward = self._csuite_env.step(action)
45 |     except RuntimeError as e:
46 |       # The gym.utils.env_checker expects the following assertion.
47 |       if str(e) == base.STEP_WITHOUT_START_ERR:
48 |         assert False, 'Cannot call env.step() before calling reset()'
49 |       raise  # Re-raise any other runtime error.
50 |     return observation, reward, False, {}
51 | 
52 |   def reset(self,
53 |             seed: Optional[int] = None,
54 |             options: Optional[dict] = None):  # pylint: disable=g-bare-generic
55 |     if options:
56 |       raise NotImplementedError('options not supported with the gym wrapper.')
57 |     del options
58 |     observation = self._csuite_env.start(seed)
59 |     state = self._csuite_env.get_state()
60 |     self._np_random = state.rng if hasattr(
61 |         state, 'rng') else np.random.default_rng(seed)
62 |     if gym.__version__ == '0.19.0':
63 |       return observation
64 |     else:
65 |       return observation, {}
66 | 
67 |   def render(self, mode: str = 'rgb_array') -> Union[np.ndarray, bool]:
68 | 
69 |     if mode == 'rgb_array':
70 |       return self._csuite_env.render()
71 | 
72 |     if mode == 'human':
73 |       if self.viewer is None:
74 |         # pylint: disable=import-outside-toplevel
75 |         # pylint: disable=g-import-not-at-top
76 |         from gym.envs.classic_control import rendering
77 |         self.viewer = rendering.SimpleImageViewer()
78 |       self.viewer.imshow(self._csuite_env.render())
79 |       return self.viewer.isopen
80 | 
81 |   @property
82 |   def action_space(self) -> spaces.Discrete:
83 |     action_spec = self._csuite_env.action_spec()
84 |     if isinstance(action_spec, specs.DiscreteArray):
85 |       action_spec = typing.cast(specs.DiscreteArray, action_spec)
86 |       return spaces.Discrete(action_spec.num_values)
87 |     else:
88 |       raise NotImplementedError(
89 |           'The gym wrapper only supports environments with discrete action '
90 |           'spaces. Please raise an issue if you want to work with a '
91 |           'non-discrete action space.')
92 | 
93 |   @property
94 |   def observation_space(self) -> spaces.Box:
95 |     obs_spec = self._csuite_env.observation_spec()
96 |     if isinstance(obs_spec, specs.BoundedArray):
97 |       return spaces.Box(
98 |           low=float(obs_spec.minimum),
99 |           high=float(obs_spec.maximum),
100 |           shape=obs_spec.shape,
101 |           dtype=obs_spec.dtype)
102 |     return spaces.Box(
103 |         low=-float('inf'),
104 |         high=float('inf'),
105 |         shape=obs_spec.shape,
106 |         dtype=obs_spec.dtype)
107 | 
108 |   @property
109 |   def reward_range(self) -> Tuple[float, float]:
110 |     # CSuite does not return reward range.
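    # Report the widest possible range, which matches Gym's own default.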
111 | return -float('inf'), float('inf') 112 | 113 | def __getattr__(self, attr): 114 | """Delegate attribute access to underlying environment.""" 115 | return getattr(self._csuite_env, attr) 116 | -------------------------------------------------------------------------------- /csuite/utils/gym_wrapper_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for GymFromCSuite.""" 17 | 18 | from absl.testing import absltest 19 | import csuite 20 | from gym.utils import env_checker as gym_env_checker 21 | 22 | 23 | class GymFromCSuiteTest(absltest.TestCase): 24 | 25 | def test_env_checker(self): 26 | csuite_env = csuite.load('catch') 27 | gym_env = csuite.gym_wrapper.GymFromCSuite(csuite_env) 28 | # Gym's env_checker.check_env throws an exception if the env does not 29 | # conform to the Gym API: 30 | gym_env_checker.check_env(gym_env) 31 | 32 | 33 | if __name__ == '__main__': 34 | absltest.main() 35 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = . 8 | BUILDDIR = _build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /docs/_static/img/taxi_grid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/csuite/b74faed74685dc852d43c2022150c4186c244929/docs/_static/img/taxi_grid.png -------------------------------------------------------------------------------- /docs/_static/img/taxi_pickup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/csuite/b74faed74685dc852d43c2022150c4186c244929/docs/_static/img/taxi_pickup.png -------------------------------------------------------------------------------- /docs/api.md: -------------------------------------------------------------------------------- 1 | # Environment Interface 2 | All CSuite environments adhere to the Python interface defined in the abstract 3 | base class `csuite.Environment`. 
This base class specifies the standard methods 4 | to implement in a CSuite environment, which are outlined below. 5 | 6 | ```{eval-rst} 7 | .. autoclass:: csuite.Environment 8 | ``` 9 | 10 | ## Loading a CSuite Environment 11 | Environments in CSuite are specified by an identifying string and can be 12 | initialized using the `load` function. 13 | 14 | ```python 15 | import csuite 16 | 17 | env = csuite.load('catch') 18 | ``` 19 | 20 | The list of available environments and their associated loading strings is 21 | given in the `csuite.EnvName` class. 22 | 23 | ```{eval-rst} 24 | .. autoclass:: csuite.EnvName 25 | ``` 26 | 27 | ## API Methods 28 | ### Start 29 | After initialization, `start` must be called to set the environment state. Since 30 | all environments are continuing, `start` should only be called *once*, at the 31 | beginning of the agent-environment interaction. 32 | 33 | ```{eval-rst} 34 | .. automethod:: csuite.Environment.start 35 | ``` 36 | 37 | ### Step 38 | After `start` is called, `step` updates the environment by one timestep 39 | given the action taken. The resulting observation and reward are returned. 40 | 41 | ```{eval-rst} 42 | .. automethod:: csuite.Environment.step 43 | ``` 44 | 45 | ### Render 46 | All CSuite environments are expected to return an object serving to render 47 | the environment for visualization at the current timestep. 48 | 49 | ```{eval-rst} 50 | .. automethod:: csuite.Environment.render 51 | ``` 52 | 53 | ### Get and Set State 54 | The `get_state` and `set_state` methods permit environment state retrieval 55 | and manipulation. These methods should only be used for reproducibility or 56 | checkpointing purposes; thus, it is expected that these methods can sufficiently 57 | manipulate the internal state to provide full reproducibility of the environment 58 | dynamics (supplying the internal random number generator if applicable, 59 | for example). 60 | 61 | ```{eval-rst} 62 | .. automethod:: csuite.Environment.get_state 63 | 64 | .. automethod:: csuite.Environment.set_state 65 | ``` 66 | 67 | ### Observation and Action Specs 68 | Environments are expected to return the specifications of the observation and 69 | action space by calling `observation_spec` and `action_spec` respectively. 70 | These methods should return structures of dm_env 71 | [`Array` specs](https://github.com/deepmind/dm_env/blob/master/dm_env/specs.py) 72 | which adhere exactly to the format of observations and actions. 73 | 74 | ```{eval-rst} 75 | .. automethod:: csuite.Environment.observation_spec 76 | 77 | .. automethod:: csuite.Environment.action_spec 78 | ``` 79 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Configuration file for the Sphinx documentation builder.""" 17 | 18 | # This file only contains a selection of the most common options. For a full 19 | # list see the documentation: 20 | # http://www.sphinx-doc.org/en/master/config 21 | 22 | # -- Path setup -------------------------------------------------------------- 23 | 24 | # If extensions (or modules to document with autodoc) are in another directory, 25 | # add these directories to sys.path here. If the directory is relative to the 26 | # documentation root, use os.path.abspath to make it absolute, like shown here. 27 | 28 | # pylint: disable=g-bad-import-order 29 | # pylint: disable=g-import-not-at-top 30 | import inspect 31 | import os 32 | import sys 33 | import typing 34 | 35 | 36 | def _add_annotations_import(path): 37 | """Appends a future annotations import to the file at the given path.""" 38 | with open(path) as f: 39 | contents = f.read() 40 | if contents.startswith('from __future__ import annotations'): 41 | # If we run sphinx multiple times then we will append the future import 42 | # multiple times too. 43 | return 44 | 45 | assert contents.startswith('#'), (path, contents.split('\n')[0]) 46 | with open(path, 'w') as f: 47 | # NOTE: This is subtle and not unit tested, we're prefixing the first line 48 | # in each Python file with this future import. It is important to prefix 49 | # not insert a newline such that source code locations are accurate (we link 50 | # to GitHub). The assertion above ensures that the first line in the file is 51 | # a comment so it is safe to prefix it. 52 | f.write('from __future__ import annotations ') 53 | f.write(contents) 54 | 55 | 56 | def _recursive_add_annotations_import(): 57 | for path, _, files in os.walk('../csuite/'): 58 | for file in files: 59 | if file.endswith('.py'): 60 | _add_annotations_import(os.path.abspath(os.path.join(path, file))) 61 | 62 | 63 | if 'READTHEDOCS' in os.environ: 64 | _recursive_add_annotations_import() 65 | 66 | typing.get_type_hints = lambda obj, *unused: obj.__annotations__ 67 | sys.path.insert(0, os.path.abspath('../')) 68 | sys.path.append(os.path.abspath('ext')) 69 | 70 | import csuite 71 | from sphinxcontrib import katex 72 | 73 | # -- Project information ----------------------------------------------------- 74 | 75 | project = 'CSuite' 76 | copyright = '2022, DeepMind' # pylint: disable=redefined-builtin 77 | author = 'CSuite Contributors' 78 | 79 | # -- General configuration --------------------------------------------------- 80 | 81 | master_doc = 'index' 82 | 83 | # Add any Sphinx extension module names here, as strings. They can be 84 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 85 | # ones. 86 | extensions = [ 87 | 'myst_parser', 88 | 'sphinx.ext.autodoc', 89 | 'sphinx.ext.autosummary', 90 | 'sphinx.ext.doctest', 91 | 'sphinx.ext.inheritance_diagram', 92 | 'sphinx.ext.intersphinx', 93 | 'sphinx.ext.linkcode', 94 | 'sphinx.ext.napoleon', 95 | 'sphinxcontrib.katex', 96 | 'sphinx_autodoc_typehints', 97 | 'sphinx_rtd_theme', 98 | # 'coverage_check', 99 | ] 100 | 101 | # Add any paths that contain templates here, relative to this directory. 102 | templates_path = ['_templates'] 103 | 104 | # List of patterns, relative to source directory, that match files and 105 | # directories to ignore when looking for source files. 106 | # This pattern also affects html_static_path and html_extra_path. 
107 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
108 | 
109 | # -- Options for autodoc -----------------------------------------------------
110 | 
111 | autodoc_default_options = {
112 |     'member-order': 'bysource',
113 |     'special-members': True,
114 |     'exclude-members': '__repr__, __str__, __weakref__',
115 | }
116 | 
117 | # -- Options for HTML output -------------------------------------------------
118 | 
119 | # The theme to use for HTML and HTML Help pages. See the documentation for
120 | # a list of builtin themes.
121 | #
122 | html_theme = 'sphinx_rtd_theme'
123 | 
124 | # Add any paths that contain custom static files (such as style sheets) here,
125 | # relative to this directory. They are copied after the builtin static files,
126 | # so a file named "default.css" will overwrite the builtin "default.css".
127 | html_static_path = []
128 | # html_favicon = '_static/favicon.ico'
129 | 
130 | # -- Options for katex ------------------------------------------------------
131 | 
132 | # See: https://sphinxcontrib-katex.readthedocs.io/en/0.4.1/macros.html
133 | latex_macros = r"""
134 | \def \d #1{\operatorname{#1}}
135 | """
136 | 
137 | # Translate LaTeX macros to KaTeX and add to options for HTML builder
138 | katex_macros = katex.latex_defs_to_katex_macros(latex_macros)
139 | katex_options = 'macros: {' + katex_macros + '}'
140 | 
141 | # Add LaTeX macros for LATEX builder
142 | latex_elements = {'preamble': latex_macros}
143 | 
144 | # -- Source code links -------------------------------------------------------
145 | 
146 | 
147 | def linkcode_resolve(domain, info):
148 |   """Resolve a GitHub URL corresponding to a Python object."""
149 |   if domain != 'py':
150 |     return None
151 | 
152 |   try:
153 |     mod = sys.modules[info['module']]
154 |   except ImportError:
155 |     return None
156 | 
157 |   obj = mod
158 |   try:
159 |     for attr in info['fullname'].split('.'):
160 |       obj = getattr(obj, attr)
161 |   except AttributeError:
162 |     return None
163 |   else:
164 |     obj = inspect.unwrap(obj)
165 | 
166 |   try:
167 |     filename = inspect.getsourcefile(obj)
168 |   except TypeError:
169 |     return None
170 | 
171 |   try:
172 |     source, lineno = inspect.getsourcelines(obj)
173 |   except OSError:
174 |     return None
175 | 
176 |   # Use GitHub's #Lstart-Lend anchor syntax to link the full source range.
177 |   return 'https://github.com/deepmind/csuite/tree/master/csuite/%s#L%d-L%d' % (
178 |       os.path.relpath(filename, start=os.path.dirname(
179 |           csuite.__file__)), lineno, lineno + len(source) - 1)
180 | 
181 | 
182 | # -- Intersphinx configuration -----------------------------------------------
183 | 
184 | intersphinx_mapping = {
185 |     'jax': ('https://jax.readthedocs.io/en/latest/', None),
186 | }
187 | 
188 | source_suffix = ['.rst', '.md']
--------------------------------------------------------------------------------
/docs/environments/access_control.md:
--------------------------------------------------------------------------------
1 | # Access-Control
2 | | Overview          | Specification      |
3 | |-------------------|--------------------|
4 | | Observation Space | \{0, 1, ..., 43\}  |
5 | | Action Space      | \{0, 1\}           |
6 | | Reward Space      | \{0, 1, 2, 4, 8\}  |
7 | | Loading String    | `'access_control'` |
8 | 
9 | ## Description
10 | 
11 | Given access to a set of 10 servers and an infinite queue of customers with
12 | different priorities, decide whether to accept or reject the next customer in
13 | line based on their priority and the number of free servers.
14 | 
15 | Each customer has a uniformly random priority of 1, 2, 4, or 8. At each time
16 | step, the priority of the next customer is revealed and the customer is either
17 | accepted or rejected. If accepted, customers provide the agent with a reward
18 | equal to their priority and are assigned to a server. If rejected, the reward is
19 | zero. In either case, on the next time step the next customer in the queue is
20 | considered. When all servers are busy, the customer will be rejected no matter
21 | the intended choice. Busy servers are freed with probability 0.06 at each step.
22 | 
23 | This problem is based on the Access-Control queueing task in *Reinforcement
24 | Learning: An Introduction* by Sutton and Barto (see
25 | [Example 10.2, 2nd ed.](http://incompleteideas.net/book/RLbook2020.pdf#page=274)).
26 | 
27 | ## Observations
28 | The state space is represented as tuples `(num_busy_servers, incoming_priority)`
29 | and an observation is a single integer encoding the current state. With there
30 | being 0 to 10 busy servers and 4 possible customer priorities, there is a total
31 | of 44 discrete states.
32 | 
33 | ## Actions
34 | There are two discrete actions.
35 | 
36 | * 0: Reject the incoming customer.
37 | * 1: Accept the incoming customer.
38 | 
39 | ## Rewards
40 | The incoming customer has a priority equal to 1, 2, 4, or 8.
41 | * If the customer is accepted, the reward equals the customer's priority.
42 | * Otherwise the customer is rejected and the reward equals 0.
43 | 
44 | 
45 | 
--------------------------------------------------------------------------------
/docs/environments/catch.md:
--------------------------------------------------------------------------------
1 | # Catch
2 | | Overview          | Specification                             |
3 | |-------------------|-------------------------------------------|
4 | | Observation Space | Array of shape (10, 5) with binary values |
5 | | Action Space      | \{0, 1, 2\}                               |
6 | | Reward Space      | \{-1, 0, 1\}                              |
7 | | Loading String    | `'catch'`                                 |
8 | 
9 | ## Description
10 | Catch as many falling objects as possible by controlling a breakout-like paddle
11 | positioned at the bottom of a 10 x 5 board.
12 | 
13 | In the episodic version of Catch, an episode terminates after a single ball
14 | reaches the bottom of the screen. This is a continuing version of Catch, in
15 | which a new ball appears at the top of the screen with 10% probability, in a
16 | column chosen uniformly at random. This means that multiple balls can be present
17 | on the screen, but only one new ball can be *added* at each timestep.
18 | 
19 | At each timestep, balls present on the screen will fall by one pixel. Balls only
20 | move downwards in the column they are in. The paddle can either stay in place,
21 | or move by one pixel to the left or right at each timestep. Balls successfully
22 | caught by the paddle give a reward of +1, and balls the paddle fails to catch
23 | give a reward of -1.
24 | 
25 | ## Observations
26 | The observation is an array of shape `(10, 5)`, with binary values:
27 | 0 if a space is empty, 1 if it contains the paddle or a ball. The initial
28 | observation has one ball present at the top of the screen, in a column
29 | chosen uniformly at random.
30 | 
31 | ## Actions
32 | There are three discrete actions.
33 | * 0: Move the paddle one pixel to the left.
34 | * 1: Keep the paddle in place.
35 | * 2: Move the paddle one pixel to the right.
36 | 
37 | ## Rewards
38 | * If the paddle catches a ball, a reward of +1 is received.
39 | * If the paddle fails to catch a ball, a reward of -1 is received.
40 | * Otherwise the reward is 0.
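
## Example
A minimal interaction loop, using the loading string from the table above and
the `start`/`step` conventions described in the API documentation. The constant
action `1` simply keeps the paddle in place; a real agent would choose actions
from its policy:

```python
import csuite

env = csuite.load('catch')
observation = env.start()
for _ in range(100):
  # Action 1 keeps the paddle in place.
  observation, reward = env.step(1)
```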
41 | 
--------------------------------------------------------------------------------
/docs/environments/pendulum.md:
--------------------------------------------------------------------------------
1 | # Pendulum
2 | | Overview          | Specification                                             |
3 | |-------------------|-----------------------------------------------------------|
4 | | Observation Space | 3-d array: [ $\cos\theta$, $\sin\theta$, $\dot{\theta}$ ] |
5 | | Action Space      | \{0, 1, 2\}                                               |
6 | | Reward Space      | \{0, 1\}                                                  |
7 | | Loading String    | `'pendulum'`                                              |
8 | 
9 | 
10 | ## Description
11 | 
12 | Starting from a hanging down position, swing up a single pendulum to the
13 | inverted position and maintain the pendulum in this position.
14 | 
15 | Most of the default arguments model the pendulum described in "A Comparison of
16 | Direct and Model-Based Reinforcement Learning" (Atkeson & Santamaria, 1997).
17 | The key differences are:
18 | 1) This environment is now continuing, i.e. the "trial length" is infinite.
19 | 2) Instead of directly returning the angle of the pendulum in the observation,
20 | this environment returns the cosine and sine of the angle.
21 | 3) There are only three discrete actions (apply a torque of -1, apply a torque
22 | of +1, or apply no torque) as opposed to a continuous torque value.
23 | 4) The default reward function is sparse: when the pendulum is near upright
24 | the reward is 1, otherwise 0 (more details below).
25 | 
26 | The pendulum's motion is described by the equation:
27 | $\ddot{\theta} = \tau - \mu \dot{\theta} - g \sin\theta$ ,
28 | where $\theta$ is the angle with the vertical, $\mu$ is the friction coefficient,
29 | $\tau$ is the torque, $g$ is the acceleration due to gravity,
30 | $\dot{\theta}$ is the first derivative of theta w.r.t. time,
31 | and $\ddot{\theta}$ is the second derivative.
32 | 
33 | ## Observations
34 | 
35 | The observation is a 3-dimensional array encoding the cosine and sine of the
36 | pendulum's angle as well as its angular velocity: [ $\cos\theta$, $\sin\theta$, $\dot{\theta}$ ].
37 | The pendulum starts at the bottom ( $\theta=0$ ) with zero velocity.
38 | The first two elements of the observation are inherently bounded in [-1, 1];
39 | the third element is bounded in the code by the parameter `max_speed`, whose
40 | default value is `np.inf`.
41 | 
42 | ## Actions
43 | 
44 | There are three discrete actions:
45 | * 0: -1 unit of torque
46 | * 1: 0 units of torque
47 | * 2: +1 unit of torque
48 | 
49 | The sign indicates the direction of torque.
50 | 
51 | ## Rewards
52 | The default reward function is sparse:
53 | * +1 reward if the pendulum is within a small range of the upright position:
54 | ( $\pi$ - reward_angle, $\pi$ + reward_angle)
55 | * 0 reward otherwise
56 | 
57 | Variants of this problem can include a dense reward function
58 | (e.g., inversely proportional to the angular distance from the upright position).
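
To make the dynamics concrete, a single Euler-integration step of the equation
above might look as follows. This is only an illustrative sketch: the friction,
gravity, and timestep values are assumed placeholders, not the environment's
actual parameters or integrator.

```python
import numpy as np

def euler_step(theta, theta_dot, torque, mu=0.1, g=9.8, dt=0.01):
  """One Euler step of theta_ddot = torque - mu * theta_dot - g * sin(theta)."""
  theta_ddot = torque - mu * theta_dot - g * np.sin(theta)
  theta_dot = theta_dot + dt * theta_ddot
  theta = theta + dt * theta_dot
  return theta, theta_dot
```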
59 | 
--------------------------------------------------------------------------------
/docs/environments/taxi.md:
--------------------------------------------------------------------------------
1 | # Taxi
2 | | Overview          | Specification      |
3 | |-------------------|--------------------|
4 | | Observation Space | \{0, 1, ..., 499\} |
5 | | Action Space      | \{0, 1, ..., 5\}   |
6 | | Reward Space      | \{-10, 0, 20\}     |
7 | | Loading String    | `'taxi'`           |
8 | 
9 | ## Description
10 | 
11 | ```{image} /_static/img/taxi_grid.png
12 | :width: 25%
13 | :align: center
14 | ```
15 | 
16 | In a 5x5 grid world with four colored squares (Red, Green, Yellow, and Blue),
17 | taxi passengers spawn on one of the four squares, chosen uniformly at random.
18 | The passengers have a desired destination, also on one of the four squares
19 | chosen uniformly at random. The agent must pick up the current passenger
20 | and drop them off at their desired destination. Exactly one passenger is present
21 | on the grid at every timestep.
22 | 
23 | This environment is based on the corresponding episodic environment introduced in
24 | ["Hierarchical Reinforcement Learning with the MAXQ Value Function Decomposition"](https://jair.org/index.php/jair/article/view/10266)
25 | by Thomas G. Dietterich. In the episodic version of Taxi, an episode terminates
26 | after a single passenger is delivered. This is a continuing version of Taxi,
27 | where a new passenger spawns once the current passenger has been successfully
28 | dropped off.
29 | 
30 | At each timestep, the taxi can move in one of the four cardinal directions, or
31 | it can attempt to perform a pick-up or drop-off on its current square. The
32 | 'pick-up' action is legal only when the passenger and taxi are on the same
33 | square and the passenger has not already been picked up; similarly, the
34 | 'drop-off' action is legal only when the passenger has been picked up and the
35 | taxi has arrived at the passenger's destination. A passenger's pick-up and
36 | destination location *can be the same square*, but the taxi must still pick up
37 | and drop off the passenger for a successful delivery.
38 | 
39 | For each passenger successfully dropped off, the agent receives a reward of +20.
40 | If the agent performs an illegal 'pick-up' or 'drop-off' action (e.g. dropping
41 | off the passenger at an incorrect location), it receives a reward of -10. If a
42 | movement action causes the taxi to hit a barrier, the taxi stays in its
43 | current square.
44 | 
45 | ## Observations
46 | The observation space is a single state index, which encodes the possible
47 | states accounting for the taxi position, the location of the passenger, and the
48 | four possible destination locations. Since there are 25 possible taxi positions,
49 | 5 possible locations of the passenger (either on the colored squares or in the
50 | taxi), and 4 possible destinations, there is a total of 500 discrete states.
51 | 
52 | ## Actions
53 | There are six discrete actions.
54 | * 0: Move one square North.
55 | * 1: Move one square West.
56 | * 2: Move one square South.
57 | * 3: Move one square East.
58 | * 4: Pick up the passenger.
59 | * 5: Drop off the passenger.
60 | 
61 | ## Rewards
62 | * If a passenger is in the taxi and is successfully dropped off, a reward of +20
63 | is received.
64 | * If an illegal 'pick-up' or 'drop-off' action is performed, a reward of -10 is
65 | received.
66 | * Otherwise the reward equals 0.
67 | 
68 | ## Rendering
69 | Example visualizations of a few states are provided below.
67 | 
68 | ## Rendering
69 | Example visualizations of a few states are provided below.
70 | 
71 | ```{image} /_static/img/taxi_pickup.png
72 | :width: 75%
73 | :align: center
74 | ```
75 | * (Left) The empty taxi (grey circle) located on the same square as the
76 | passenger (purple outline), prior to pick-up.
77 | * (Middle) The taxi containing the passenger (purple circle) upon pick-up.
78 | * (Right) The taxi containing the passenger on the same square as its
79 | destination (dark green outline), prior to drop-off.
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | Welcome to CSuite
2 | -----------------
3 | 
4 | CSuite is a collection of synthetic reinforcement learning environments that
5 | are continuing: the agent-environment interaction goes on indefinitely, with
6 | no natural episode boundaries.
7 | 
8 | 
9 | .. toctree::
10 |    :caption: User Guide
11 |    :maxdepth: 1
12 | 
13 |    api
14 | 
15 | 
16 | .. toctree::
17 |    :caption: Environments
18 |    :maxdepth: 1
19 | 
20 |    environments/access_control
21 |    environments/catch
22 |    environments/pendulum
23 |    environments/taxi
24 | 
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx==4.5.0
2 | sphinx_rtd_theme==0.5.0
3 | sphinxcontrib-katex==0.8.6
4 | sphinx-autodoc-typehints==1.11.1
5 | IPython==7.16.3
6 | ipykernel==5.3.4
7 | pandoc==1.0.2
8 | docutils==0.16
9 | myst-parser
--------------------------------------------------------------------------------
/readthedocs.yml:
--------------------------------------------------------------------------------
1 | # Read the Docs configuration file
2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
3 | 
4 | version: 2
5 | 
6 | sphinx:
7 |   builder: html
8 |   configuration: docs/conf.py
9 |   fail_on_warning: false
10 | 
11 | python:
12 |   version: 3.8
13 |   install:
14 |     - requirements: docs/requirements.txt
15 |     - requirements: requirements.txt
16 |     - method: setuptools
17 |       path: .
18 | 
19 | 
--------------------------------------------------------------------------------
/requirements-test.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.2.0
2 | pytest==7.1.2
3 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | dm_env==1.5
2 | gym==0.26.2
3 | numpy==1.23.1
4 | Pillow==9.2.0
5 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """Install script for setuptools."""
17 | 
18 | import os
19 | from setuptools import find_namespace_packages
20 | from setuptools import setup
21 | 
22 | _CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
23 | 
24 | setup(
25 |     name='csuite',
26 |     version='0.1.0',
27 |     url='https://github.com/deepmind/csuite',
28 |     license='Apache 2.0',
29 |     author='DeepMind',
30 |     description=(
31 |         'A collection of continuing environments for reinforcement learning.'),
32 |     long_description=open(os.path.join(_CURRENT_DIR, 'README.md')).read(),
33 |     long_description_content_type='text/markdown',
34 |     author_email='csuite@google.com',
35 |     keywords='reinforcement-learning environment suite python machine learning',
36 |     packages=find_namespace_packages(exclude=['*_test.py']),  # Note: `exclude` matches package names, not file names.
37 |     install_requires=[
38 |         'dm_env>=1.5',
39 |         'gym>=0.19.0',
40 |         'numpy>=1.18.0',
41 |         'Pillow>=9.0.1',
42 |         'absl-py>=0.7.1',
43 |         'pytest>=6.2.5',
44 |     ],
45 |     zip_safe=False,  # Required for full installation.
46 |     python_requires='>=3.9,<3.11',
47 |     classifiers=[
48 |         # TODO(b/241264065): list classifiers.
49 |         'Topic :: Scientific/Engineering :: Artificial Intelligence',
50 |         'Programming Language :: Python :: 3.9',
51 |         'Programming Language :: Python :: 3.10',
52 |     ],
53 | )
54 | 
--------------------------------------------------------------------------------