├── AUTHORS ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docs ├── README.md └── img │ ├── angular_velocity.gif │ └── humanoid_perturbations.gif ├── examples ├── run_dmpo_acme.py ├── run_ppo.py └── run_random.py ├── realworldrl_suite ├── __init__.py ├── analysis │ └── realworldrl_notebook.ipynb ├── environments │ ├── __init__.py │ ├── cartpole.py │ ├── env_test.py │ ├── humanoid.py │ ├── manipulator.py │ ├── quadruped.py │ ├── realworld_env.py │ └── walker.py └── utils │ ├── __init__.py │ ├── accumulators.py │ ├── accumulators_test.py │ ├── evaluators.py │ ├── evaluators_test.py │ ├── loggers.py │ ├── loggers_test.py │ ├── multiobj_objectives.py │ ├── multiobj_objectives_test.py │ ├── viewer.py │ ├── wrappers.py │ └── wrappers_test.py └── setup.py /AUTHORS: -------------------------------------------------------------------------------- 1 | # This is the list of Real-World RL Suite's significant contributors. 2 | # 3 | # This does not necessarily list everyone who has contributed code, 4 | # especially since many employees of one corporation may be contributing. 5 | # To see the full list of contributors, see the revision history in 6 | # source control. 7 | Google LLC 8 | DeepMind Technologies Limited 9 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows 28 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Real-World Reinforcement Learning (RWRL) Challenge Framework 2 | 3 |

4 | 5 |

6 |
7 | The ["Challenges of Real-World RL"](https://arxiv.org/abs/1904.12901) paper
8 | identifies and describes a set of nine challenges that are currently preventing
9 | Reinforcement Learning (RL) agents from being utilized on real-world
10 | applications and products. It also describes an evaluation framework and a set
11 | of environments that can provide an evaluation of an RL algorithm’s potential
12 | applicability to real-world systems. It has since been followed up by
13 | ["An Empirical Investigation of the challenges of real-world reinforcement
14 | learning"](https://arxiv.org/pdf/2003.11881.pdf), which implements eight of the
15 | nine described challenges (excluding explainability) and analyses their effects
16 | on various state-of-the-art RL algorithms. This is the codebase used to perform
17 | this analysis, and it is also intended as a common platform for easily reproducible
18 | experimentation around these challenges; it is referred to as the
19 | `realworldrl-suite` (Real-World Reinforcement Learning (RWRL) Suite).
20 |
21 | The suite currently comprises five environments:
22 |
23 | * Cartpole
24 | * Walker
25 | * Quadruped
26 | * Manipulator (less tested)
27 | * Humanoid
28 |
29 | The codebase is currently structured as follows:
30 |
31 | * environments/ -- the extended environments
32 | * utils/ -- wrapper classes for logging and standardized evaluations
33 | * analysis/ -- Notebook for training an agent and generating plots
34 | * examples/ -- Random policy and PPO agent example implementations
35 | * docs/ -- Documentation
36 |
37 | Questions can be directed to the Real-World RL group e-mail address,
38 | [realworldrl@google.com].
39 |
40 | > :information_source: If you wish to test your agent in a principled fashion on
41 | > related challenges in low-dimensional domains, we highly recommend using
42 | > [bsuite](https://github.com/deepmind/bsuite).
43 |
44 | ## Documentation
45 |
46 | We overview the challenges here, but more thorough documentation on how to
47 | configure each challenge can be found [here](docs/README.md).
48 |
49 | Starter examples are presented in the [examples](#running-examples) section.
50 |
51 | ## Challenges
52 |
53 | ### Safety
54 |
55 | Adds a set of constraints on the task. Returns an additional entry in the
56 | observations ('constraints') whose length is the number of constraints,
57 | where each entry is True if the constraint is satisfied and False otherwise.
58 |
59 | ### Delays
60 |
61 | Action, observation and reward delays.
62 |
63 | - Action delay is the number of steps between passing the action to the
64 | environment and when it is actually performed.
65 | - Observation delay is the number of steps by which the returned observation
66 | is out of date after performing a step.
67 | - Reward delay indicates the number of steps before receiving a reward after
68 | taking an action.
69 |
70 | ### Noise
71 |
72 | Action and observation noise. The different noise types include:
73 |
74 | - White Gaussian action/observation noise
75 | - Dropped actions/observations
76 | - Stuck actions/observations
77 | - Repetitive actions
78 |
79 | The noise specifications can be parameterized in the `noise_spec` dictionary.
80 |
81 | ### Perturbations
82 |
83 | Perturbs physical quantities of the environment. These perturbations are
84 | non-stationary and are governed by a scheduler.
85 |
86 | ### Dimensionality
87 |
88 | Adds extra dummy features to observations to increase dimensionality of the
89 | state space.
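
Each of these challenges is configured through its own specification dictionary, documented in [docs/README.md](docs/README.md). As a minimal, illustrative sketch (assuming the spec dictionaries are passed to `rwrl.load` as keyword arguments of the same name, as `safety_spec` and `delay_spec` are in the bundled examples), enabling Gaussian noise and extra dummy observation dimensions on cartpole swing-up looks like this:

```python
import realworldrl_suite.environments as rwrl

# Values here are illustrative only; see docs/README.md for each spec's fields.
env = rwrl.load(
    domain_name='cartpole',
    task_name='realworld_swingup',
    noise_spec=dict(
        gaussian=dict(enable=True, actions=0.1, observations=0.1)),
    dimensionality_spec=dict(
        enable=True, num_random_state_observations=10),
    environment_kwargs=dict(flat_observation=True))
```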
90 |
91 | ### Multi-Objective Rewards:
92 |
93 | Adds additional objectives and specifies how the objectives interact (e.g., sum).
94 |
95 | ### Offline Learning
96 |
97 | We provide our offline datasets through the
98 | [RL Unplugged](https://github.com/deepmind/deepmind-research/blob/master/rl_unplugged/)
99 | library. There is an
100 | [example](https://github.com/deepmind/deepmind-research/blob/master/rl_unplugged/rwrl_example.py)
101 | and an associated
102 | [colab](https://github.com/deepmind/deepmind-research/blob/master/rl_unplugged/rwrl_d4pg.ipynb).
103 |
104 | ### RWRL Combined Challenge Benchmarks:
105 |
106 | Combines multiple challenges into the same environment. The challenges are
107 | divided into 'Easy', 'Medium' and 'Hard' levels, which depend on the magnitude of the
108 | challenge effects applied along each challenge dimension.
109 |
110 | ## Installation
111 |
112 | - Install pip:
113 |   - Run the following commands:
114 |
115 | ```bash
116 | curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
117 | python get-pip.py
118 | ```
119 |
120 | - Make sure pip is up to date.
121 |
122 | ```bash
123 | pip3 install --upgrade pip
124 | ```
125 |
126 | - (Optional) You may wish to create a
127 | [Python virtual environment](https://docs.python.org/3/tutorial/venv.html)
128 | to manage your dependencies, so as not to clobber your system installation:
129 |
130 | ```bash
131 | sudo pip3 install virtualenv
132 | /usr/local/bin/virtualenv realworldrl_venv
133 | source ./realworldrl_venv/bin/activate
134 | ```
135 |
136 | - Install MuJoCo (see dm_control - https://github.com/deepmind/dm_control).
137 |
138 | - To install `realworldrl_suite`:
139 |
140 |   - Clone the repository by running:
141 |
142 | ```bash
143 | git clone https://github.com/google-research/realworldrl_suite.git
144 | ```
145 |
146 |   - Ensure you are in the parent directory of realworldrl_suite.
147 |   - Run the command:
148 |
149 | ```bash
150 | pip3 install realworldrl_suite/
151 | ```
152 |
153 | ## Running examples
154 |
155 | We provide three example agents: a random agent, a PPO agent, and an
156 | [ACME](https://github.com/deepmind/acme)-based DMPO agent.
157 |
158 | - For PPO, running the examples requires installing the following packages:
159 |
160 | ```bash
161 | pip3 install tensorflow==1.15.0 dm2gym
162 | pip3 install git+git://github.com/openai/baselines.git
163 | ```
164 |
165 | - The PPO example can then be run with:
166 |
167 | ```bash
168 | cd realworldrl_suite/examples
169 | mkdir /tmp/rwrl/
170 | python3 run_ppo.py
171 | ```
172 |
173 | - For DMPO, one can run the example by installing the following packages:
174 |
175 | ```bash
176 | pip install dm-acme
177 | pip install dm-acme[reverb]
178 | pip install dm-acme[tf]
179 | ```
180 |
181 | You *may* also have to install the following:
182 |
183 | ```bash
184 | pip install gym
185 | pip install jax
186 | pip install dm-sonnet
187 | ```
188 |
189 | - The examples look for the MuJoCo licence key in `~/.mujoco/mjkey.txt` by
190 | default.
191 |
192 | ## RWRL Combined Challenge Benchmark Instantiation:
193 |
194 | As mentioned above, these benchmark challenges are divided into 'Easy', 'Medium'
195 | and 'Hard' difficulty levels. For the current state-of-the-art performance on
196 | these benchmarks, please see this
197 | paper.
198 |
199 | Instantiating a combined challenge environment with 'Easy' difficulty is done as
200 | follows:
201 |
202 | ```python
203 | import realworldrl_suite.environments as rwrl
204 | env = rwrl.load(
205 |     domain_name='cartpole',
206 |     task_name='realworld_swingup',
207 |     combined_challenge='easy',
208 |     log_output='/tmp/path/to/results.npz',
209 |     environment_kwargs=dict(log_safety_vars=True, flat_observation=True))
210 | ```
211 |
212 | ## Acknowledgements
213 |
214 | If you use `realworldrl_suite` in your work, please cite:
215 |
216 | ```
217 | @article{dulacarnold2020realworldrlempirical,
218 |   title={An empirical investigation of the challenges of real-world reinforcement learning},
219 |   author={Dulac-Arnold, Gabriel and
220 |           Levine, Nir and
221 |           Mankowitz, Daniel J. and
222 |           Li, Jerry and
223 |           Paduraru, Cosmin and
224 |           Gowal, Sven and
225 |           Hester, Todd
226 |   },
227 |   year={2020},
228 | }
229 | ```
230 |
231 | ## Paper links
232 |
233 | - [Challenges of real-world
234 |   reinforcement learning](https://arxiv.org/abs/1904.12901)
235 |
236 | - [An empirical investigation of the
237 |   challenges of real-world reinforcement learning](https://arxiv.org/pdf/2003.11881.pdf)
238 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | # RWRL Suite Challenge Spec
2 |
3 | This document outlines the specification dictionaries for each of the supported
4 | challenges. It provides an overview of each challenge's parameters and their
5 | effects.
6 |
7 |
8 | ### Safety
9 |
10 | Adds a set of constraints on the task. Returns an additional entry in the
11 | observations under the `constraints` key. The observation is a binary vector
12 | whose length is the number of constraints, where each entry is True if
13 | the constraint is satisfied and False otherwise. The following dictionary is fed
14 | as an argument into the RWRL environment load function to initialize safety
15 | constraints:
16 |
17 | ```
18 | safety_spec = {
19 |     'enable': bool,  # Whether to enable safety constraints.
20 |     'observations': bool,  # Whether to add the constraint violations observation.
21 |     'safety_coeff': float,  # Safety coefficient that regulates the difficulty of the constraint. 1 is the baseline, and 0 is impossible (always violated).
22 |     'constraints': list  # Optional list of additional safety constraints. Can only operate on variables returned by the `safety_vars` method.
23 | }
24 | ```
25 |
26 | Each of the built-in constraints has been tuned to be just at the border of
27 | nominal operation. For tasks such as Cartpole, the constraints are tuned to be violated only
28 | during the swing-up phase, but not during balancing. Different
29 | constraints will be more or less difficult to satisfy.
30 |
31 | The built-in constraints are as follows:
32 |
33 | * Cartpole Swing-Up:
34 |     * `slider_pos_constraint` : Constrain cart to be within a specific region on track.
35 |     * `balance_velocity_constraint` : Constrain pole angular velocity to be below a certain threshold when arriving near the top. This provides a more subtle constraint than a standard box constraint on a variable.
36 |     * `slider_accel_constraint` : Constrain cart acceleration to be below a certain value.
37 | * Walker:
38 |     * `joint_angle_constraint` : Constrain joint angles to specific ranges. This is joint-specific.
39 |     * `joint_velocity_constraint` : Constrain joint velocities to a certain range. This is a global value.
40 |     * `dangerous_fall_constraint` : Discourage dangerous falls by ensuring the torso stays positioned forwards.
41 |     * `torso_upright_constraint` : Discourage dangerous operation by ensuring that the torso stays upright.
42 | * Quadruped:
43 |     * `joint_angle_constraint` : Constrain joint angles to specific ranges. This is joint-specific.
44 |     * `joint_velocity_constraint` : Constrain joint velocities to a certain range. This is a global value.
45 |     * `upright_constraint` : Constrain the Quadruped's torso's z-axis to be oriented upwards.
46 |     * `foot_force_constraint` : Limits foot contact forces when touching the ground.
47 | * Humanoid:
48 |     * `joint_angle_constraint` : Constrain joint angles to specific ranges. This is joint-specific.
49 |     * `joint_velocity_constraint` : Constrain joint velocities to a certain range. This is a global value.
50 |     * `upright_constraint` : Constrain the Humanoid's pelvis's z-axis to be oriented upwards.
51 |     * `foot_force_constraint` : Limits foot contact forces when touching the ground.
52 |     * `dangerous_fall_constraint` : Discourages dangerous falls by limiting head and torso contact.
53 |
54 | It is also possible to add arbitrary constraints by passing a list of methods to the `constraints` key, which will receive the values returned by the task's `safety_vars` method.
55 |
56 |
57 | ### Delays
58 | This challenge provides delays on actions, observations and rewards. For each of these, the delayed element is placed into a buffer and either applied to the environment, or sent back to the agent, after the specified number of steps.
59 |
60 | This can be configured by passing in a `delay_spec` dictionary of the following form:
61 |
62 | ```
63 | delay_spec = {
64 |     'enable': bool,  # Whether to enable this challenge.
65 |     'actions': int,  # The delay on actions in # of steps.
66 |     'observations': int,  # The delay on observations in # of steps.
67 |     'rewards': int,  # The delay on rewards in # of steps.
68 | }
69 | ```
70 |
71 | ### Noise
72 |
73 | Real-world systems often have action and observation noise, and this can come in
74 | many flavors, from simply noisy values to dropped and stuck signals. This
75 | challenge allows you to experiment with different types of noise:
76 |
77 | - White Gaussian action/observation noise
78 | - Dropped actions/observations
79 | - Stuck actions/observations
80 | - Repetitive actions
81 |
82 | The noise specifications can be parameterized in the `noise_spec` dictionary.
83 |
84 | ```
85 | noise_spec = {
86 |     'gaussian': {  # Specifies the white Gaussian additive noise.
87 |         'enable': bool,  # Whether to enable Gaussian noise.
88 |         'actions': float,  # Standard deviation of noise added to actions.
89 |         'observations': float,  # Standard deviation of noise added to observations.
90 |     },
91 |     'dropped': {  # Specifies dropped value noise.
92 |         'enable': bool,  # Whether to enable dropped values.
93 |         'observations_prob': float,  # Value in [0,1] indicating the probability of dropping each observation independently.
94 |         'observations_steps': int,  # Value > 0 specifying the number of steps to drop an observation for if dropped.
95 |         'action_prob': float,  # Value in [0,1] indicating the probability of dropping each action independently.
96 |         'action_steps': int,  # Value > 0 specifying the number of steps to drop an action for if dropped.
97 |     },
98 |     'stuck': {  # Specifies stuck values noise.
99 |         'enable': bool,  # Whether to enable stuck values.
100 |         'observations_prob': float,  # Value in [0,1] indicating the probability of an observation component becoming stuck.
101 |         'observations_steps': int,  # Value > 0 specifying the number of steps an observation remains stuck.
102 |         'action_prob': float,  # Value in [0,1] indicating the probability of an action component becoming stuck.
103 |         'action_steps': int  # Value > 0 specifying the number of steps an action remains stuck.
104 |     },
105 |     'repetition': {  # Specifies repetition statistics.
106 |         'enable': bool,  # Whether to enable repeating values.
107 |         'actions_prob': float,  # Value in [0,1] indicating the probability of an action repeating.
108 |         'actions_steps': int  # Value > 0 specifying the number of steps an action repeats.
109 |     },
110 | }
111 |
112 | ```
113 |
114 | ### Perturbations
115 |
116 | Real systems are imperfect and degrade or change over time. This means the
117 | controller needs to understand these changes and update its control policy
118 | accordingly. The RWRL suite can simulate various system perturbations, with
119 | varying ways in which a perturbation evolves over time (which we call its
120 | 'schedule').
121 |
122 | These challenges can be configured by passing in a `perturb_spec` dictionary with the
123 | following format:
124 |
125 | ```
126 | perturb_spec = {
127 |     'enable': bool,  # Whether to enable perturbations.
128 |     'period': int,  # Number of episodes between perturbation changes.
129 |     'param': str,  # Specifies which parameter to perturb. Specified below.
130 |     'scheduler': str,  # Specifies which scheduler to apply. Specified below.
131 |     'start': float,  # Indicates initial value of perturbed parameter.
132 |     'min': float,  # Indicates the minimal value of the perturbed parameter.
133 |     'max': float,  # Indicates the maximal value of the perturbed parameter.
134 |     'std': float  # Indicates the standard deviation of white noise used in scheduling.
135 | }
136 | ```
137 |
138 | The various scheduler choices are as follows:
139 |
140 | * `constant` : keeps the perturbation constant.
141 | * `random_walk` : changes the perturbation by a random amount (defined by the 'std' key).
142 | * `drift_pos` : changes the perturbation by a random positive amount.
143 | * `drift_neg` : changes the perturbation by a random negative amount.
144 | * `cyclic_pos` : changes the perturbation by a random positive amount and cycles back when 'max' is attained.
145 | * `cyclic_neg` : changes the perturbation by a random negative amount and cycles back when 'min' is attained.
146 | * `uniform` : sets the perturbation to a uniform random value within [min, max].
147 | * `saw_wave` : cycles between `drift_pos` and `drift_neg` when the [min, max] bounds are reached.
148 |
149 | Each environment has a set of parameters which can be perturbed:
150 |
151 | * Cartpole
152 |     * `pole_length`
153 |     * `pole_mass`
154 |     * `joint_damping` : adds a damping factor to the pole joint.
155 |     * `slider_damping` : adds a damping factor to the slider (cart).
156 | * Walker
157 |     * `thigh_length`
158 |     * `torso_length`
159 |     * `joint_damping` : adds a damping factor to all joints.
160 |     * `contact_friction` : alters contact friction with ground.
161 | * Quadruped
162 |     * `shin_length`
163 |     * `torso_density` : alters torso density, therefore changing weight with constant volume.
164 |     * `joint_damping` : adds a damping factor to all joints.
165 |     * `contact_friction` : alters contact friction with ground.
166 | * Humanoid
167 |     * `joint_damping` : adds a damping factor to all joints.
168 |     * `contact_friction` : alters contact friction with ground.
169 |     * `head_size` : alters head size (and therefore weight).
170 |
171 |
172 | ### Dimensionality
173 | Adds extra dummy features to observations to increase dimensionality of the
174 | state space.
175 |
176 | ```
177 | dimensionality_spec = {
178 |     'enable': bool,  # Whether to enable dimensionality challenges.
179 |     'num_random_state_observations': int,  # Number of random observation dimensions to add.
180 | }
181 | ```
182 |
183 | ### Multi-Objective Reward
184 | This challenge looks at multi-objective rewards. A default multi-objective reward is included which allows a safety objective to be defined from the set of constraints; new objectives are easy to implement by adding them to the `utils.multiobj_objectives.OBJECTIVES` dict.
185 |
186 | ```
187 | multiobj_spec = {
188 |     'enable': bool,  # Whether to enable the multi-objective challenge.
189 |     'objective': str or object,  # Either a string which will load an `Objective` class from
190 |                                  # utils.multiobj_objectives.OBJECTIVES or an Objective object
191 |                                  # which subclasses utils.multiobj_objectives.Objective.
192 |     'reward': bool,  # Whether to add the multiobj objective's reward to the environment's returned reward.
193 |     'coeff': float,  # A number in [0,1] used as a reward mixing ratio by the Objective object.
194 |     'observed': bool  # Whether the defined objectives should be added to the observation.
195 | }
196 | ```
197 |
198 | ### RWRL Combined Challenge Benchmarks:
199 | The RWRL suite allows you to combine multiple challenges into the same environment. The challenges are
200 | divided into 'Easy', 'Medium' and 'Hard' levels, which depend on the magnitude of the
201 | challenge effects applied along each challenge dimension.
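
A combined challenge environment is instantiated through the load function with the `combined_challenge` argument rather than by passing the individual spec dictionaries. The sketch below mirrors the example in the top-level README (the task and keyword values are illustrative); the per-level spec values it applies are listed below.

```
import realworldrl_suite.environments as rwrl

# 'easy' can be replaced with 'medium' or 'hard'.
env = rwrl.load(
    domain_name='cartpole',
    task_name='realworld_swingup',
    combined_challenge='easy',
    environment_kwargs=dict(flat_observation=True))
```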
202 | 203 | * The 'Easy' challenge: 204 | 205 | ``` 206 | delay_spec = { 207 | 'enable': True, 208 | 'actions': 3, 209 | 'observations': 3, 210 | 'rewards': 10 211 | } 212 | noise_spec = { 213 | 'gaussian': { 214 | 'enable': True, 215 | 'actions': 0.1, 216 | 'observations': 0.1 217 | }, 218 | 'dropped': { 219 | 'enable': True, 220 | 'observations_prob': 0.01, 221 | 'observations_steps': 1, 222 | }, 223 | 'stuck': { 224 | 'enable': True, 225 | 'observations_prob': 0.01, 226 | 'observations_steps': 1, 227 | }, 228 | 'repetition': { 229 | 'enable': True, 230 | 'actions_prob': 1.0, 231 | 'actions_steps': 1 232 | } 233 | } 234 | perturb_spec = { 235 | 'enable': True, 236 | 'period': 1, 237 | 'scheduler': 'uniform' 238 | } 239 | dimensionality_spec = { 240 | 'enable': True, 241 | 'num_random_state_observations': 10 242 | } 243 | ``` 244 | 245 | * The 'Medium' challenge: 246 | 247 | ``` 248 | delay_spec = { 249 | 'enable': True, 250 | 'actions': 6, 251 | 'observations': 6, 252 | 'rewards': 20 253 | } 254 | noise_spec = { 255 | 'gaussian': { 256 | 'enable': True, 257 | 'actions': 0.3, 258 | 'observations': 0.3 259 | }, 260 | 'dropped': { 261 | 'enable': True, 262 | 'observations_prob': 0.05, 263 | 'observations_steps': 5, 264 | }, 265 | 'stuck': { 266 | 'enable': True, 267 | 'observations_prob': 0.05, 268 | 'observations_steps': 5, 269 | }, 270 | 'repetition': { 271 | 'enable': True, 272 | 'actions_prob': 1.0, 273 | 'actions_steps': 2 274 | } 275 | } 276 | perturb_spec = { 277 | 'enable': True, 278 | 'period': 1, 279 | 'scheduler': 'uniform' 280 | } 281 | dimensionality_spec = { 282 | 'enable': True, 283 | 'num_random_state_observations': 20 284 | } 285 | ``` 286 | 287 | * The 'Hard' challenge: 288 | 289 | ``` 290 | delay_spec = { 291 | 'enable': True, 292 | 'actions': 9, 293 | 'observations': 9, 294 | 'rewards': 40 295 | } 296 | noise_spec = { 297 | 'gaussian': { 298 | 'enable': True, 299 | 'actions': 1.0, 300 | 'observations': 1.0 301 | }, 302 | 'dropped': { 303 | 'enable': True, 304 | 'observations_prob': 0.1, 305 | 'observations_steps': 10, 306 | }, 307 | 'stuck': { 308 | 'enable': True, 309 | 'observations_prob': 0.1, 310 | 'observations_steps': 10, 311 | }, 312 | 'repetition': { 313 | 'enable': True, 314 | 'actions_prob': 1.0, 315 | 'actions_steps': 3 316 | } 317 | } 318 | perturb_spec = { 319 | 'enable': True, 320 | 'period': 1, 321 | 'scheduler': 'uniform' 322 | } 323 | dimensionality_spec = { 324 | 'enable': True, 325 | 'num_random_state_observations': 50 326 | } 327 | ``` 328 | -------------------------------------------------------------------------------- /docs/img/angular_velocity.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/realworldrl_suite/be7a51cffa7f5f9cb77a387c16bad209e0f851f8/docs/img/angular_velocity.gif -------------------------------------------------------------------------------- /docs/img/humanoid_perturbations.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/realworldrl_suite/be7a51cffa7f5f9cb77a387c16bad209e0f851f8/docs/img/humanoid_perturbations.gif -------------------------------------------------------------------------------- /examples/run_dmpo_acme.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Real-World RL Suite Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Trains an ACME DMPO agent on a perturbed version of Cart-Pole.""" 17 | 18 | import os 19 | from typing import Dict, Sequence 20 | 21 | from absl import app 22 | from absl import flags 23 | import acme 24 | from acme import specs 25 | from acme import types 26 | from acme import wrappers 27 | from acme.agents.tf import dmpo 28 | from acme.tf import networks 29 | from acme.tf import utils as tf2_utils 30 | import dm_env 31 | import numpy as np 32 | import realworldrl_suite.environments as rwrl 33 | import sonnet as snt 34 | 35 | flags.DEFINE_string('domain_name', 'cartpole', 'domain to solve') 36 | flags.DEFINE_string('task_name', 'realworld_balance', 'task to solve') 37 | flags.DEFINE_string('save_path', '/tmp/rwrl', 'where to save results') 38 | flags.DEFINE_integer('num_episodes', 100, 'Number of episodes to run for.') 39 | 40 | FLAGS = flags.FLAGS 41 | 42 | 43 | def make_environment(domain_name: str = 'cartpole', 44 | task_name: str = 'balance') -> dm_env.Environment: 45 | """Creates a RWRL suite environment.""" 46 | environment = rwrl.load( 47 | domain_name=domain_name, 48 | task_name=task_name, 49 | safety_spec=dict(enable=True), 50 | delay_spec=dict(enable=True, actions=20), 51 | log_output=os.path.join(FLAGS.save_path, 'log.npz'), 52 | environment_kwargs=dict( 53 | log_safety_vars=True, log_every=2, flat_observation=True)) 54 | environment = wrappers.SinglePrecisionWrapper(environment) 55 | return environment 56 | 57 | 58 | def make_networks( 59 | action_spec: specs.BoundedArray, 60 | policy_layer_sizes: Sequence[int] = (256, 256, 256), 61 | critic_layer_sizes: Sequence[int] = (512, 512, 256), 62 | vmin: float = -150., 63 | vmax: float = 150., 64 | num_atoms: int = 51, 65 | ) -> Dict[str, types.TensorTransformation]: 66 | """Creates networks used by the agent.""" 67 | 68 | # Get total number of action dimensions from action spec. 69 | num_dimensions = np.prod(action_spec.shape, dtype=int) 70 | 71 | # Create the shared observation network; here simply a state-less operation. 72 | observation_network = tf2_utils.batch_concat 73 | 74 | # Create the policy network. 75 | policy_network = snt.Sequential([ 76 | networks.LayerNormMLP(policy_layer_sizes), 77 | networks.MultivariateNormalDiagHead(num_dimensions) 78 | ]) 79 | 80 | # The multiplexer transforms concatenates the observations/actions. 81 | multiplexer = networks.CriticMultiplexer( 82 | critic_network=networks.LayerNormMLP(critic_layer_sizes), 83 | action_network=networks.ClipToSpec(action_spec)) 84 | 85 | # Create the critic network. 86 | critic_network = snt.Sequential([ 87 | multiplexer, 88 | networks.DiscreteValuedHead(vmin, vmax, num_atoms), 89 | ]) 90 | 91 | return { 92 | 'policy': policy_network, 93 | 'critic': critic_network, 94 | 'observation': observation_network, 95 | } 96 | 97 | 98 | def main(_): 99 | # Create an environment and grab the spec. 
100 | environment = make_environment( 101 | domain_name=FLAGS.domain_name, task_name=FLAGS.task_name) 102 | environment_spec = specs.make_environment_spec(environment) 103 | 104 | # Create the networks to optimize (online) and target networks. 105 | agent_networks = make_networks(environment_spec.actions) 106 | 107 | # Construct the agent. 108 | agent = dmpo.DistributionalMPO( 109 | environment_spec=environment_spec, 110 | policy_network=agent_networks['policy'], 111 | critic_network=agent_networks['critic'], 112 | observation_network=agent_networks['observation'], # pytype: disable=wrong-arg-types 113 | ) 114 | 115 | # Run the environment loop. 116 | loop = acme.EnvironmentLoop(environment, agent) 117 | loop.run(num_episodes=FLAGS.num_episodes) 118 | 119 | 120 | if __name__ == '__main__': 121 | app.run(main) 122 | -------------------------------------------------------------------------------- /examples/run_ppo.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Real-World RL Suite Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Trains an OpenAI Baselines PPO agent on realworldrl. 17 | 18 | Note that OpenAI Gym is not installed with realworldrl by default. 19 | See also github.com/openai/baselines for more information. 20 | 21 | This example also relies on dm2gym for its gym environment wrapper. 22 | See github.com/zuoxingdong/dm2gym for more information. 23 | """ 24 | 25 | import os 26 | 27 | from absl import app 28 | from absl import flags 29 | from baselines import bench 30 | from baselines.common.vec_env import dummy_vec_env 31 | from baselines.ppo2 import ppo2 32 | import dm2gym.envs.dm_suite_env as dm2gym 33 | import realworldrl_suite.environments as rwrl 34 | 35 | flags.DEFINE_string('domain_name', 'cartpole', 'domain to solve') 36 | flags.DEFINE_string('task_name', 'realworld_balance', 'task to solve') 37 | flags.DEFINE_string('save_path', '/tmp/rwrl', 'where to save results') 38 | flags.DEFINE_boolean('verbose', True, 'whether to log to std output') 39 | flags.DEFINE_string('network', 'mlp', 'name of network architecture') 40 | flags.DEFINE_float('agent_discount', .99, 'discounting on the agent side') 41 | flags.DEFINE_integer('nsteps', 100, 'number of steps per ppo rollout') 42 | flags.DEFINE_integer('total_timesteps', 1000000, 'total steps for experiment') 43 | flags.DEFINE_float('learning_rate', 1e-3, 'learning rate for optimizer') 44 | 45 | FLAGS = flags.FLAGS 46 | 47 | 48 | class GymEnv(dm2gym.DMSuiteEnv): 49 | """Wrapper that convert a realworldrl environment to a gym environment.""" 50 | 51 | def __init__(self, env): 52 | """Constructor. We reuse the facilities from dm2gym.""" 53 | self.env = env 54 | self.metadata = { 55 | 'render.modes': ['human', 'rgb_array'], 56 | 'video.frames_per_second': round(1. 
/ self.env.control_timestep()) 57 | } 58 | self.observation_space = dm2gym.convert_dm_control_to_gym_space( 59 | self.env.observation_spec()) 60 | self.action_space = dm2gym.convert_dm_control_to_gym_space( 61 | self.env.action_spec()) 62 | self.viewer = None 63 | 64 | 65 | def run(): 66 | """Runs a PPO agent on a given environment.""" 67 | 68 | def _load_env(): 69 | """Loads environment.""" 70 | raw_env = rwrl.load( 71 | domain_name=FLAGS.domain_name, 72 | task_name=FLAGS.task_name, 73 | safety_spec=dict(enable=True), 74 | delay_spec=dict(enable=True, actions=20), 75 | log_output=os.path.join(FLAGS.save_path, 'log.npz'), 76 | environment_kwargs=dict( 77 | log_safety_vars=True, log_every=20, flat_observation=True)) 78 | env = GymEnv(raw_env) 79 | env = bench.Monitor(env, FLAGS.save_path) 80 | return env 81 | 82 | env = dummy_vec_env.DummyVecEnv([_load_env]) 83 | 84 | ppo2.learn( 85 | env=env, 86 | network=FLAGS.network, 87 | lr=FLAGS.learning_rate, 88 | total_timesteps=FLAGS.total_timesteps, # make sure to run enough steps 89 | nsteps=FLAGS.nsteps, 90 | gamma=FLAGS.agent_discount, 91 | ) 92 | 93 | 94 | def main(argv): 95 | del argv # Unused. 96 | run() 97 | 98 | 99 | if __name__ == '__main__': 100 | app.run(main) 101 | -------------------------------------------------------------------------------- /examples/run_random.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Real-World RL Suite Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Runs a random policy on realworldrl.""" 17 | 18 | import os 19 | 20 | from absl import app 21 | from absl import flags 22 | import numpy as np 23 | import realworldrl_suite.environments as rwrl 24 | 25 | flags.DEFINE_string('domain_name', 'cartpole', 'domain to solve') 26 | flags.DEFINE_string('task_name', 'realworld_balance', 'task to solve') 27 | flags.DEFINE_string('save_path', '/tmp/rwrl', 'where to save results') 28 | flags.DEFINE_integer('total_episodes', 100, 'number of episodes') 29 | 30 | FLAGS = flags.FLAGS 31 | 32 | 33 | def random_policy(action_spec): 34 | 35 | def _act(timestep): 36 | del timestep 37 | return np.random.uniform( 38 | low=action_spec.minimum, 39 | high=action_spec.maximum, 40 | size=action_spec.shape) 41 | 42 | return _act 43 | 44 | 45 | def run(): 46 | """Runs a random agent on a given environment.""" 47 | 48 | env = rwrl.load( 49 | domain_name=FLAGS.domain_name, 50 | task_name=FLAGS.task_name, 51 | safety_spec=dict(enable=True), 52 | delay_spec=dict(enable=True, actions=20), 53 | log_output=os.path.join(FLAGS.save_path, 'log.npz'), 54 | environment_kwargs=dict( 55 | log_safety_vars=True, log_every=20, flat_observation=True)) 56 | 57 | policy = random_policy(action_spec=env.action_spec()) 58 | 59 | rewards = [] 60 | for _ in range(FLAGS.total_episodes): 61 | timestep = env.reset() 62 | total_reward = 0. 
63 | while not timestep.last(): 64 | action = policy(timestep) 65 | timestep = env.step(action) 66 | total_reward += timestep.reward 67 | rewards.append(total_reward) 68 | print('Random policy total reward per episode: {:.2f} +- {:.2f}'.format( 69 | np.mean(rewards), np.std(rewards))) 70 | 71 | 72 | def main(argv): 73 | del argv # Unused. 74 | run() 75 | 76 | 77 | if __name__ == '__main__': 78 | app.run(main) 79 | -------------------------------------------------------------------------------- /realworldrl_suite/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Real-World RL Suite Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Library to do real-world RL.""" 17 | 18 | import realworldrl_suite.environments 19 | import realworldrl_suite.utils 20 | -------------------------------------------------------------------------------- /realworldrl_suite/analysis/realworldrl_notebook.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "realworldrl_notebook.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "last_runtime": { 10 | "build_target": "", 11 | "kind": "local" 12 | } 13 | }, 14 | "kernelspec": { 15 | "display_name": "Python 3", 16 | "name": "python3" 17 | } 18 | }, 19 | "cells": [ 20 | { 21 | "cell_type": "markdown", 22 | "metadata": { 23 | "colab_type": "text", 24 | "id": "OF9qRa7D8643" 25 | }, 26 | "source": [ 27 | "**To use the Jupyter notebook, ensure that you have installed the following packages (in the pre-defined order):**\n", 28 | "1. Python3\n", 29 | "2. Matplotlib\n", 30 | "3. Numpy\n", 31 | "4. Scipy\n", 32 | "5. tqdm\n", 33 | "6. Mujoco (Make sure Mujoco is installed before installing our Realworld RL suite)\n", 34 | "7. The Realworld RL Suite\n", 35 | "\n", 36 | "**It is recommended to use the realworldrl_venv virtual environment that you used when installing the realworldrl_suite package. 
To do so, you may need to run the following commands:** \n", 37 | "\n", 38 | "```\n", 39 | "pip3 install --user ipykernel\n", 40 | "python3 -m ipykernel install --user --name=realworldrl_venv\n", 41 | "```\n", 42 | "\n", 43 | "Then in this notebook, click 'Kernel' in the menu, then click 'Change Kernel' and select `realworldrl_venv`\n", 44 | "\n", 45 | "**Note**: You may need to restart the Jupyter kernel to see the updated virtual environment in the Kernel list.\n" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": { 51 | "colab_type": "text", 52 | "id": "yLMLidYq8646" 53 | }, 54 | "source": [ 55 | "**Import the necessary libraries**" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "metadata": { 61 | "cellView": "both", 62 | "colab_type": "code", 63 | "id": "icAkv-zLplqP", 64 | "colab": {} 65 | }, 66 | "source": [ 67 | "#@title \n", 68 | "# Copyright 2020 Google LLC.\n", 69 | "\n", 70 | "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", 71 | "# you may not use this file except in compliance with the License.\n", 72 | "# You may obtain a copy of the License at\n", 73 | "\n", 74 | "# https://www.apache.org/licenses/LICENSE-2.0\n", 75 | "\n", 76 | "# Unless required by applicable law or agreed to in writing, software\n", 77 | "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", 78 | "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", 79 | "# See the License for the specific language governing permissions and\n", 80 | "# limitations under the License.\n", 81 | "\n", 82 | "from __future__ import absolute_import\n", 83 | "from __future__ import division\n", 84 | "from __future__ import print_function\n", 85 | "\n", 86 | "import matplotlib.pyplot as plt\n", 87 | "import numpy as np\n", 88 | "import tqdm\n", 89 | "\n", 90 | "import collections\n", 91 | "\n", 92 | "import realworldrl_suite.environments as rwrl\n", 93 | "from realworldrl_suite.utils import evaluators" 94 | ], 95 | "execution_count": 0, 96 | "outputs": [] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": { 101 | "colab_type": "text", 102 | "id": "DDcZhPBB865C" 103 | }, 104 | "source": [ 105 | "**Define the environment and the policy**" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "metadata": { 111 | "colab_type": "code", 112 | "id": "7gDwzTFz0Tak", 113 | "colab": {} 114 | }, 115 | "source": [ 116 | "total_episodes = 1000 # The analysis tools require at least 100 episodes.\n", 117 | "domain_name = 'cartpole'\n", 118 | "task_name = 'realworld_swingup'\n", 119 | "\n", 120 | "# Define the challenge dictionaries\n", 121 | "safety_spec_dict = dict(enable=True, safety_coeff=0.5)\n", 122 | "delay_spec_dict = dict(enable=True, actions=20)\n", 123 | "\n", 124 | "log_safety_violations = True\n", 125 | "\n", 126 | "def random_policy(action_spec):\n", 127 | "\n", 128 | " def _act(timestep):\n", 129 | " del timestep\n", 130 | " return np.random.uniform(low=action_spec.minimum,\n", 131 | " high=action_spec.maximum,\n", 132 | " size=action_spec.shape)\n", 133 | " return _act\n", 134 | "\n", 135 | "\n", 136 | "env = rwrl.load(\n", 137 | " domain_name=domain_name,\n", 138 | " task_name=task_name,\n", 139 | " safety_spec=safety_spec_dict,\n", 140 | " delay_spec=delay_spec_dict,\n", 141 | " log_output=os.path.join('/tmp/', 'log.npz'),\n", 142 | " environment_kwargs=dict(log_safety_vars=log_safety_violations, \n", 143 | " flat_observation=True,\n", 144 | " log_every=10))\n", 145 | "\n", 146 | "policy = 
random_policy(action_spec=env.action_spec())" 147 | ], 148 | "execution_count": 0, 149 | "outputs": [] 150 | }, 151 | { 152 | "cell_type": "markdown", 153 | "metadata": { 154 | "colab_type": "text", 155 | "id": "lj0CFq1k865I" 156 | }, 157 | "source": [ 158 | "**Run the main loop**" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "metadata": { 164 | "colab_type": "code", 165 | "id": "ot9xUSM8z8Ib", 166 | "colab": {} 167 | }, 168 | "source": [ 169 | "rewards = []\n", 170 | "episode_counter = 0\n", 171 | "for _ in tqdm.tqdm(range(total_episodes)):\n", 172 | " timestep = env.reset()\n", 173 | " total_reward = 0.\n", 174 | " while not timestep.last():\n", 175 | " action = policy(timestep)\n", 176 | " timestep = env.step(action)\n", 177 | " total_reward += timestep.reward\n", 178 | " rewards.append(total_reward)\n", 179 | " episode_counter+=1\n", 180 | "print('Random policy total reward per episode: {:.2f} +- {:.2f}'.format(\n", 181 | "np.mean(rewards), np.std(rewards)))" 182 | ], 183 | "execution_count": 0, 184 | "outputs": [] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "metadata": { 189 | "colab_type": "code", 190 | "id": "Ne2i3ZFRsnVw", 191 | "colab": {} 192 | }, 193 | "source": [ 194 | "f = open(env.logs_path, \"rb\") \n", 195 | "stats = np.load(f, allow_pickle=True)\n", 196 | "evaluator = evaluators.Evaluators(stats)" 197 | ], 198 | "execution_count": 0, 199 | "outputs": [] 200 | }, 201 | { 202 | "cell_type": "markdown", 203 | "metadata": { 204 | "colab_type": "text", 205 | "id": "xU-jhT4W865T" 206 | }, 207 | "source": [ 208 | "**Load the average return plot as a function of the number of episodes**" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "metadata": { 214 | "colab_type": "code", 215 | "id": "AEHttr9JyttU", 216 | "colab": {} 217 | }, 218 | "source": [ 219 | "fig = evaluator.get_return_plot()\n", 220 | "plt.show()" 221 | ], 222 | "execution_count": 0, 223 | "outputs": [] 224 | }, 225 | { 226 | "cell_type": "markdown", 227 | "metadata": { 228 | "id": "327LSohpug29", 229 | "colab_type": "text" 230 | }, 231 | "source": [ 232 | "**Compute regret**" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "metadata": { 238 | "id": "nuyf6byUufCP", 239 | "colab_type": "code", 240 | "colab": {} 241 | }, 242 | "source": [ 243 | "fig = evaluator.get_convergence_plot()\n", 244 | "plt.show()" 245 | ], 246 | "execution_count": 0, 247 | "outputs": [] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "metadata": { 252 | "id": "Xk8vY0b6u2um", 253 | "colab_type": "text" 254 | }, 255 | "source": [ 256 | "**Compute instability**" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "metadata": { 262 | "id": "QoqFSJDZu6Eo", 263 | "colab_type": "code", 264 | "colab": {} 265 | }, 266 | "source": [ 267 | "fig = evaluator.get_stability_plot()\n", 268 | "plt.show()" 269 | ], 270 | "execution_count": 0, 271 | "outputs": [] 272 | }, 273 | { 274 | "cell_type": "markdown", 275 | "metadata": { 276 | "colab_type": "text", 277 | "id": "RkPTOgzA865Z" 278 | }, 279 | "source": [ 280 | "**Safety violations plot (left figure) and the mean evolution of safety constraint violations during an episode (right figure)**" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "metadata": { 286 | "colab_type": "code", 287 | "id": "i7BRSJsWzJEB", 288 | "colab": {} 289 | }, 290 | "source": [ 291 | "fig = evaluator.get_safety_plot()\n", 292 | "plt.show()" 293 | ], 294 | "execution_count": 0, 295 | "outputs": [] 296 | }, 297 | { 298 | "cell_type": "markdown", 299 | "metadata": { 300 | "id": 
"TSxbfXDOyV2q", 301 | "colab_type": "text" 302 | }, 303 | "source": [ 304 | "**Multiple training seeds can be aggregated.**" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "metadata": { 310 | "id": "dmi6k1vyxsLn", 311 | "colab_type": "code", 312 | "colab": {} 313 | }, 314 | "source": [ 315 | "# We emulate multiple runs by copying the same logs with added noise.\n", 316 | "\n", 317 | "all_evaluators = []\n", 318 | "for _ in range(10):\n", 319 | " another_evaluator = evaluators.Evaluators(stats)\n", 320 | " v = another_evaluator.stats['return_stats']['episode_totals']\n", 321 | " v += np.random.randn(*v.shape) * 100.\n", 322 | " all_evaluators.append(another_evaluator)" 323 | ], 324 | "execution_count": 0, 325 | "outputs": [] 326 | }, 327 | { 328 | "cell_type": "markdown", 329 | "metadata": { 330 | "id": "36t1n8k17oIp", 331 | "colab_type": "text" 332 | }, 333 | "source": [ 334 | "**Normalized regret across all runs.**" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "metadata": { 340 | "id": "J2XujjsfykR5", 341 | "colab_type": "code", 342 | "colab": {} 343 | }, 344 | "source": [ 345 | "evaluators.get_regret_plot(all_evaluators)\n", 346 | "plt.show()" 347 | ], 348 | "execution_count": 0, 349 | "outputs": [] 350 | }, 351 | { 352 | "cell_type": "markdown", 353 | "metadata": { 354 | "id": "BNgepDFv7nna", 355 | "colab_type": "text" 356 | }, 357 | "source": [ 358 | "**Return across all runs.**" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "metadata": { 364 | "id": "khHpm_Gy5nOP", 365 | "colab_type": "code", 366 | "colab": {} 367 | }, 368 | "source": [ 369 | "evaluators.get_return_plot(all_evaluators, stride=500)\n", 370 | "plt.show()" 371 | ], 372 | "execution_count": 0, 373 | "outputs": [] 374 | }, 375 | { 376 | "cell_type": "markdown", 377 | "metadata": { 378 | "id": "HY3NkWxNRceA", 379 | "colab_type": "text" 380 | }, 381 | "source": [ 382 | "**Additional useful functions**\n", 383 | "\n", 384 | "Multi-objective runs can be analyzed using:\n", 385 | "\n", 386 | "```\n", 387 | "evaluator.get_multiobjective_plot() # For a single run.\n", 388 | "evaluators.get_multiobjective_plot(all_evaluators) # For multiple runs.\n", 389 | "```" 390 | ] 391 | } 392 | ] 393 | } 394 | -------------------------------------------------------------------------------- /realworldrl_suite/environments/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Real-World RL Suite Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Safety environments for real-world RL.""" 17 | from realworldrl_suite.environments import cartpole 18 | from realworldrl_suite.environments import humanoid 19 | from realworldrl_suite.environments import manipulator 20 | from realworldrl_suite.environments import quadruped 21 | from realworldrl_suite.environments import walker 22 | 23 | # This is a tuple of all the domains and tasks present in the suite. 
It is 24 | # currently used mainly for unit test coverage but can be useful if one wants 25 | # to sweep over all the tasks. 26 | ALL_TASKS = (('cartpole_balance', 'cartpole', 'realworld_balance'), 27 | ('cartpole_swingup', 'cartpole', 'realworld_swingup'), 28 | ('humanoid_stand', 'humanoid', 'realworld_stand'), 29 | ('humanoid_walk', 'humanoid', 'realworld_walk'), 30 | ('manipulator_bring_ball', 'manipulator', 'realworld_bring_ball'), 31 | ('manipulator_bring_peg', 'manipulator', 'realworld_bring_peg'), 32 | ('manipulator_insert_ball', 'manipulator', 33 | 'realworld_insert_ball'), 34 | ('manipulator_insert_peg', 'manipulator', 'realworld_insert_peg'), 35 | ('quadruped_walk', 'quadruped', 'realworld_walk'), 36 | ('quadruped_run', 'quadruped', 'realworld_run'), 37 | ('walker_stand', 'walker', 'realworld_stand'), 38 | ('walker_walk', 'walker', 'realworld_walk')) 39 | 40 | DOMAINS = dict( 41 | cartpole=cartpole, humanoid=humanoid, manipulator=manipulator, 42 | quadruped=quadruped, walker=walker) 43 | 44 | 45 | def load(domain_name, task_name, **kwargs): 46 | return DOMAINS[domain_name].load(task_name, **kwargs) 47 | -------------------------------------------------------------------------------- /realworldrl_suite/environments/cartpole.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Real-World RL Suite Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Real-world control of cartpole.""" 17 | 18 | import collections 19 | 20 | import dm_control.suite.cartpole as cartpole 21 | import dm_control.suite.common as common 22 | from lxml import etree 23 | import numpy as np 24 | 25 | from realworldrl_suite.environments import realworld_env 26 | from realworldrl_suite.utils import loggers 27 | from realworldrl_suite.utils import wrappers 28 | 29 | _DEFAULT_TIME_LIMIT = 10 30 | 31 | PERTURB_PARAMS = ['pole_length', 'pole_mass', 'joint_damping', 'slider_damping'] 32 | 33 | 34 | def load(task_name, **task_kwargs): 35 | return globals()[task_name](**task_kwargs) 36 | 37 | 38 | # Task Constraints 39 | def slider_pos_constraint(env, safety_vars): 40 | """Slider must be within a certain area of the track.""" 41 | slider_pos = safety_vars['slider_pos'] 42 | return (np.greater(slider_pos, env.limits['slider_pos_constraint'][0]) and 43 | np.less(slider_pos, env.limits['slider_pos_constraint'][1])) 44 | 45 | 46 | def balance_velocity_constraint(env, safety_vars): 47 | """Joint angle velocity must be low when close to the goal.""" 48 | joint_angle_cos = safety_vars['joint_angle_cos'] 49 | joint_vel = safety_vars['joint_vel'] 50 | # When the angle is close to zero, and the velocity is larger than an amount 51 | # then the constraint is no longer satisfied. In cosine-space the cosine 52 | # of the angle needs to be greater than a certain value to be close to zero. 
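  # Worked example with the default limits set in _setup_safety below
  # (safety_coeff=1 gives [0.15 rad, 0.5 rad/s]): cos(angle)=0.995 together with
  # |angular velocity|=0.7 rad/s violates the constraint, since 0.995 > cos(0.15)
  # ~= 0.989 and 0.7 > 0.5; the same velocity with cos(angle)=0.9 satisfies it,
  # because the pole is no longer considered close to upright.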
53 | return not ( 54 | np.greater(joint_angle_cos, 55 | np.cos(env.limits['balance_velocity_constraint'][0])) and 56 | np.greater(joint_vel, env.limits['balance_velocity_constraint'][1])[0]) 57 | 58 | 59 | def slider_accel_constraint(env, safety_vars): 60 | """Slider acceleration should never go above threshold.""" 61 | slider_accel = safety_vars['slider_accel'] 62 | return np.less(slider_accel, env.limits['slider_accel_constraint'])[0] 63 | 64 | 65 | # Action rate of change constraint. 66 | action_roc_constraint = realworld_env.action_roc_constraint 67 | 68 | 69 | def realworld_balance(time_limit=_DEFAULT_TIME_LIMIT, 70 | random=None, 71 | log_output=None, 72 | environment_kwargs=None, 73 | safety_spec=None, 74 | delay_spec=None, 75 | noise_spec=None, 76 | perturb_spec=None, 77 | dimensionality_spec=None, 78 | multiobj_spec=None, 79 | combined_challenge=None): 80 | """Returns the Cartpole Balance task with specified real world attributes. 81 | 82 | Args: 83 | time_limit: Integer length of task 84 | random: random seed (unsure) 85 | log_output: String of path for pickle data logging, None disables logging. 86 | environment_kwargs: additional kwargs for environment. 87 | safety_spec: dictionary that specifies the safety specifications. 88 | delay_spec: dictionary that specifies the delay specifications. 89 | noise_spec: dictionary that specifies the noise specifications. 90 | perturb_spec: dictionary that specifies the perturbations specifications. 91 | dimensionality_spec: dictionary that specifies extra observation features. 92 | multiobj_spec: dictionary that specifies complementary objectives. 93 | combined_challenge: string that can be 'easy', 'medium', or 'hard'. 94 | Specifying the combined challenge (can't be used with any other spec). 95 | """ 96 | physics = Physics.from_xml_string(*cartpole.get_model_and_assets()) 97 | safety_spec = safety_spec or {} 98 | delay_spec = delay_spec or {} 99 | noise_spec = noise_spec or {} 100 | perturb_spec = perturb_spec or {} 101 | dimensionality_spec = dimensionality_spec or {} 102 | multiobj_spec = multiobj_spec or {} 103 | # Check and update for combined challenge. 104 | (delay_spec, noise_spec, 105 | perturb_spec, dimensionality_spec) = ( 106 | realworld_env.get_combined_challenge( 107 | combined_challenge, delay_spec, noise_spec, perturb_spec, 108 | dimensionality_spec)) 109 | # Updating perturbation parameters if combined_challenge. 
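  # Note: when a combined_challenge level is given, the fixed pole_length ranges
  # below overwrite any 'param', 'min', 'max' or 'std' values supplied in
  # perturb_spec, so each difficulty level uses a consistent perturbation range.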
110 | if combined_challenge == 'easy': 111 | perturb_spec.update( 112 | {'param': 'pole_length', 'min': 0.9, 'max': 1.1, 'std': 0.02}) 113 | elif combined_challenge == 'medium': 114 | perturb_spec.update( 115 | {'param': 'pole_length', 'min': 0.7, 'max': 1.7, 'std': 0.1}) 116 | elif combined_challenge == 'hard': 117 | perturb_spec.update( 118 | {'param': 'pole_length', 'min': 0.5, 'max': 2.3, 'std': 0.15}) 119 | 120 | task = RealWorldBalance( 121 | swing_up=False, 122 | sparse=False, 123 | random=random, 124 | safety_spec=safety_spec, 125 | delay_spec=delay_spec, 126 | noise_spec=noise_spec, 127 | perturb_spec=perturb_spec, 128 | dimensionality_spec=dimensionality_spec, 129 | multiobj_spec=multiobj_spec 130 | ) 131 | environment_kwargs = environment_kwargs or {} 132 | if log_output: 133 | logger = loggers.PickleLogger(path=log_output) 134 | else: 135 | logger = None 136 | return wrappers.LoggingEnv( 137 | physics, task, logger=logger, time_limit=time_limit, **environment_kwargs) 138 | 139 | 140 | def realworld_swingup(time_limit=_DEFAULT_TIME_LIMIT, 141 | random=None, 142 | log_output=None, 143 | environment_kwargs=None, 144 | safety_spec=None, 145 | delay_spec=None, 146 | noise_spec=None, 147 | perturb_spec=None, 148 | dimensionality_spec=None, 149 | multiobj_spec=None, 150 | combined_challenge=None): 151 | """Returns the Cartpole Swing-Up task with specified real world attributes. 152 | 153 | Args: 154 | time_limit: Integer length of task 155 | random: random seed (unsure) 156 | log_output: String of path for pickle data logging, None disables logging 157 | environment_kwargs: additional kwargs for environment. 158 | safety_spec: dictionary that specifies the safety specifications. 159 | delay_spec: dictionary that specifies the delay specifications. 160 | noise_spec: dictionary that specifies the noise specifications. 161 | perturb_spec: dictionary that specifies the perturbations specifications. 162 | dimensionality_spec: dictionary that specifies extra observation features. 163 | multiobj_spec: dictionary that specifies complementary objectives. 164 | combined_challenge: string that can be 'easy', 'medium', or 'hard'. 165 | Specifying the combined challenge (can't be used with any other spec). 166 | """ 167 | physics = Physics.from_xml_string(*cartpole.get_model_and_assets()) 168 | safety_spec = safety_spec or {} 169 | delay_spec = delay_spec or {} 170 | noise_spec = noise_spec or {} 171 | perturb_spec = perturb_spec or {} 172 | dimensionality_spec = dimensionality_spec or {} 173 | multiobj_spec = multiobj_spec or {} 174 | # Check and update for combined challenge. 175 | (delay_spec, noise_spec, 176 | perturb_spec, dimensionality_spec) = ( 177 | realworld_env.get_combined_challenge( 178 | combined_challenge, delay_spec, noise_spec, perturb_spec, 179 | dimensionality_spec)) 180 | # Updating perturbation parameters if combined_challenge. 
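  # The same precedence as in realworld_balance applies here: these fixed ranges
  # replace any caller-supplied 'param', 'min', 'max' or 'std' in perturb_spec.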
181 | if combined_challenge == 'easy': 182 | perturb_spec.update( 183 | {'param': 'pole_length', 'min': 0.9, 'max': 1.1, 'std': 0.02}) 184 | elif combined_challenge == 'medium': 185 | perturb_spec.update( 186 | {'param': 'pole_length', 'min': 0.7, 'max': 1.7, 'std': 0.1}) 187 | elif combined_challenge == 'hard': 188 | perturb_spec.update( 189 | {'param': 'pole_length', 'min': 0.5, 'max': 2.3, 'std': 0.15}) 190 | 191 | if 'limits' not in safety_spec: 192 | if 'safety_coeff' in safety_spec: 193 | if safety_spec['safety_coeff'] < 0 or safety_spec['safety_coeff'] > 1: 194 | raise ValueError('safety_coeff should be in [0,1], but got {}'.format( 195 | safety_spec['safety_coeff'])) 196 | safety_coeff = safety_spec['safety_coeff'] 197 | else: 198 | safety_coeff = 1 199 | safety_spec['limits'] = { 200 | 'slider_pos_constraint': 201 | safety_coeff * np.array([-2, 2]), # m 202 | 'balance_velocity_constraint': 203 | np.array([(1 - safety_coeff) / 0.5 + 0.15, 204 | safety_coeff * 0.5]), # rad, rad/s 205 | 'slider_accel_constraint': 206 | safety_coeff * 130, # m/s^2 207 | 'action_roc_constraint': safety_coeff * 1.5, 208 | } 209 | 210 | task = RealWorldBalance( 211 | swing_up=True, 212 | sparse=False, 213 | random=random, 214 | safety_spec=safety_spec, 215 | delay_spec=delay_spec, 216 | noise_spec=noise_spec, 217 | perturb_spec=perturb_spec, 218 | dimensionality_spec=dimensionality_spec, 219 | multiobj_spec=multiobj_spec 220 | ) 221 | environment_kwargs = environment_kwargs or {} 222 | if log_output: 223 | logger = loggers.PickleLogger(path=log_output) 224 | else: 225 | logger = None 226 | return wrappers.LoggingEnv( 227 | physics, task, logger=logger, time_limit=time_limit, **environment_kwargs) 228 | 229 | 230 | class Physics(cartpole.Physics): 231 | """Inherits from cartpole.Physics.""" 232 | 233 | 234 | class RealWorldBalance(realworld_env.Base, cartpole.Balance): 235 | """A Cartpole task with real-world specifications. 236 | 237 | Subclasses dm_control.suite.cartpole. 238 | 239 | Safety: 240 | Adds a set of constraints on the task. 241 | Returns an additional entry in the observations ('constraints') in the 242 | length of the number of the constraints, where each entry is True if the 243 | constraint is satisfied and False otherwise. 244 | 245 | Delays: 246 | Adds actions, observations, and rewards delays. 247 | Actions delay is the number of steps between passing the action to the 248 | environment to when it is actually performed, and observations (rewards) 249 | delay is the offset of freshness of the returned observation (reward) after 250 | performing a step. 251 | 252 | Noise: 253 | Adds action or observation noise. 254 | Different noise include: white Gaussian actions/observations, 255 | dropped actions/observations values, stuck actions/observations values, 256 | or repetitive actions. 257 | 258 | Perturbations: 259 | Perturbs physical quantities of the environment. These perturbations are 260 | non-stationary and are governed by a scheduler. 261 | 262 | Dimensionality: 263 | Adds extra dummy features to observations to increase dimensionality of the 264 | state space. 265 | 266 | Multi-Objective Reward: 267 | Adds additional objectives and specifies objectives interaction (e.g., sum). 268 | """ 269 | 270 | def __init__(self, safety_spec, delay_spec, noise_spec, perturb_spec, 271 | dimensionality_spec, multiobj_spec, **kwargs): 272 | """Initialize the RealWorldBalance task. 273 | 274 | Args: 275 | safety_spec: dictionary that specifies the safety specifications of the 276 | task. 
It may contain the following fields: 277 | enable- bool that represents whether safety specifications are enabled. 278 | constraints- list of class methods returning boolean constraint 279 | satisfactions. 280 | limits- dictionary of constants used by the functions in 'constraints'. 281 | safety_coeff - a scalar between 1 and 0 that scales safety constraints, 282 | 1 producing the base constraints, and 0 likely producing an 283 | unsolveable task. 284 | observations- a default-True boolean that toggles the whether a vector 285 | of satisfied constraints is added to observations. 286 | delay_spec: dictionary that specifies the delay specifications of the 287 | task. It may contain the following fields: 288 | enable- bool that represents whether delay specifications are enabled. 289 | actions- integer indicating the number of steps actions are being 290 | delayed. 291 | observations- integer indicating the number of steps observations are 292 | being delayed. 293 | rewards- integer indicating the number of steps observations are being 294 | delayed. 295 | noise_spec: dictionary that specifies the noise specifications of the 296 | task. It may contains the following fields: 297 | gaussian- dictionary that specifies the white Gaussian additive noise. 298 | It may contain the following fields: 299 | enable- bool that represents whether noise specifications are enabled. 300 | actions- float inidcating the standard deviation of a white Gaussian 301 | noise added to each action. 302 | observations- similarly, additive white Gaussian noise to each 303 | returned observation. 304 | dropped- dictionary that specifies the dropped values noise. 305 | It may contain the following fields: 306 | enable- bool that represents whether dropped values specifications are 307 | enabled. 308 | observations_prob- float in [0,1] indicating the probability of 309 | dropping each observation component independently. 310 | observations_steps- positive integer indicating the number of time 311 | steps of dropping a value (setting to zero) if dropped. 312 | actions_prob- float in [0,1] indicating the probability of dropping 313 | each action component independently. 314 | actions_steps- positive integer indicating the number of time steps of 315 | dropping a value (setting to zero) if dropped. 316 | stuck- dictionary that specifies the stuck values noise. 317 | It may contain the following fields: 318 | enable- bool that represents whether stuck values specifications are 319 | enabled. 320 | observations_prob- float in [0,1] indicating the probability of each 321 | observation component becoming stuck. 322 | observations_steps- positive integer indicating the number of time 323 | steps an observation (or components of) stays stuck. 324 | actions_prob- float in [0,1] indicating the probability of each 325 | action component becoming stuck. 326 | actions_steps- positive integer indicating the number of time 327 | steps an action (or components of) stays stuck. 328 | repetition- dictionary that specifies the repetition statistics. 329 | It may contain the following fields: 330 | enable- bool that represents whether repetition specifications are 331 | enabled. 332 | actions_prob- float in [0,1] indicating the probability of the actions 333 | to be repeated in the following steps. 334 | actions_steps- positive integer indicating the number of time steps of 335 | repeating the same action if it to be repeated. 336 | perturb_spec: dictionary that specifies the perturbation specifications 337 | of the task. 
It may contain the following fields: 338 | enable- bool that represents whether perturbation specifications are 339 | enabled. 340 | period- int, number of episodes between updates perturbation updates. 341 | param - string indicating which parameter to perturb (currently 342 | supporting pole_length, pole_mass, joint_damping, slider_damping). 343 | scheduler- string inidcating the scheduler to apply to the perturbed 344 | parameter (currently supporting constant, random_walk, drift_pos, 345 | drift_neg, cyclic_pos, cyclic_neg, uniform, and saw_wave). 346 | start - float indicating the initial value of the perturbed parameter. 347 | min - float indicating the minimal value the perturbed parameter may be. 348 | max - float indicating the maximal value the perturbed parameter may be. 349 | std - float indicating the standard deviation of the white noise for the 350 | scheduling process. 351 | dimensionality_spec: dictionary that specifies the added dimensions to the 352 | state space. It may contain the following fields: 353 | enable- bool that represents whether dimensionality specifications are 354 | enabled. 355 | num_random_state_observations - num of random (unit Gaussian) 356 | observations to add. 357 | multiobj_spec: dictionary that sets up the multi-objective challenge. 358 | The challenge works by providing an `Objective` object which describes 359 | both numerical objectives and a reward-merging method that allow to both 360 | observe the various objectives in the observation and affect the 361 | returned reward in a manner defined by the Objective object. 362 | enable- bool that represents whether delay multi-objective 363 | specifications are enabled. 364 | objective - either a string which will load an `Objective` class from 365 | utils.multiobj_objectives.OBJECTIVES, or an Objective object which 366 | subclasses utils.multiobj_objectives.Objective. 367 | reward - boolean indicating whether to add the multiobj objective's 368 | reward to the environment's returned reward. 369 | coeff - a number in [0,1] that is passed into the Objective object to 370 | change the mix between the original reward and the Objective's 371 | rewards. 372 | observed - boolean indicating whether the defined objectives should be 373 | added to the observation. 374 | **kwargs: extra parameters passed to parent class (cartpole.Balance) 375 | """ 376 | # Initialize parent classes. 377 | realworld_env.Base.__init__(self) 378 | cartpole.Balance.__init__(self, **kwargs) 379 | 380 | # Safety setup. 381 | self._setup_safety(safety_spec) 382 | 383 | # Delay setup. 384 | realworld_env.Base._setup_delay(self, delay_spec) 385 | 386 | # Noise setup. 387 | realworld_env.Base._setup_noise(self, noise_spec) 388 | 389 | # Perturb setup. 390 | self._setup_perturb(perturb_spec) 391 | 392 | # Dimensionality setup 393 | realworld_env.Base._setup_dimensionality(self, dimensionality_spec) 394 | 395 | # Multi-objective setup 396 | realworld_env.Base._setup_multiobj(self, multiobj_spec) 397 | 398 | # Safety methods. 399 | def _setup_safety(self, safety_spec): 400 | """Setup for the safety specifications of the task.""" 401 | self._safety_enabled = safety_spec.get('enable', False) 402 | self._safety_observed = safety_spec.get('observations', True) 403 | 404 | if self._safety_enabled: 405 | # Add safety specifications. 
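      # Callers can override the defaults below by passing an OrderedDict of
      # constraint callables, each taking (env, safety_vars) and returning a
      # bool, together with matching entries in 'limits'. Illustrative sketch
      # only (my_slider_constraint is a hypothetical user-defined function):
      #
      #   safety_spec = dict(
      #       enable=True,
      #       constraints=collections.OrderedDict(
      #           [('my_slider_constraint', my_slider_constraint)]),
      #       limits=dict(my_slider_constraint=np.array([-1.0, 1.0])))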
406 | if 'constraints' in safety_spec: 407 | self.constraints = safety_spec['constraints'] 408 | else: 409 | self.constraints = collections.OrderedDict([ 410 | ('slider_pos_constraint', slider_pos_constraint), 411 | ('slider_accel_constraint', slider_accel_constraint), 412 | ('balance_velocity_constraint', balance_velocity_constraint) 413 | ]) 414 | if 'limits' in safety_spec: 415 | self.limits = safety_spec['limits'] 416 | else: 417 | if 'safety_coeff' in safety_spec: 418 | if safety_spec['safety_coeff'] < 0 or safety_spec['safety_coeff'] > 1: 419 | raise ValueError( 420 | 'safety_coeff should be in [0,1], but got {}'.format( 421 | safety_spec['safety_coeff'])) 422 | safety_coeff = safety_spec['safety_coeff'] 423 | else: 424 | safety_coeff = 1 425 | self.limits = { 426 | 'slider_pos_constraint': 427 | safety_coeff * np.array([-1.5, 1.5]), # m 428 | 'balance_velocity_constraint': 429 | np.array([(1 - safety_coeff) / 0.5 + 0.15, 430 | safety_coeff * 0.5]), # rad, rad/s 431 | 'slider_accel_constraint': 432 | safety_coeff * 10, # m/s^2 433 | 'action_roc_constraint': safety_coeff * 1.5 434 | } 435 | self._constraints_obs = np.ones(len(self.constraints), dtype=bool) 436 | 437 | def safety_vars(self, physics): 438 | safety_vars = collections.OrderedDict( 439 | slider_pos=physics.cart_position().copy(), 440 | joint_angle_cos=physics.pole_angle_cosine().copy(), 441 | joint_vel=np.abs(physics.angular_vel().copy()), 442 | slider_accel=np.abs(physics.named.data.qacc['slider'].copy()), 443 | actions=physics.control(),) 444 | return safety_vars 445 | 446 | def _setup_perturb(self, perturb_spec): 447 | """Setup for the perturbations specification of the task.""" 448 | self._perturb_enabled = perturb_spec.get('enable', False) 449 | self._perturb_period = perturb_spec.get('period', 1) 450 | 451 | if self._perturb_enabled: 452 | # Add perturbations specifications. 453 | self._perturb_param = perturb_spec.get('param', 'pole_length') 454 | # Making sure object to perturb is supported. 455 | if self._perturb_param not in PERTURB_PARAMS: 456 | raise ValueError("""param was: {}. Currently only supporting {}. 457 | """.format(self._perturb_param, PERTURB_PARAMS)) 458 | 459 | # Setting perturbation function. 460 | self._perturb_scheduler = perturb_spec.get('scheduler', 'constant') 461 | if self._perturb_scheduler not in realworld_env.PERTURB_SCHEDULERS: 462 | raise ValueError("""scheduler was: {}. Currently only supporting {}. 463 | """.format(self._perturb_scheduler, realworld_env.PERTURB_SCHEDULERS)) 464 | 465 | # Setting perturbation process parameters. 466 | if self._perturb_param == 'pole_length': 467 | self._perturb_cur = perturb_spec.get('start', 1.) 468 | self._perturb_start = perturb_spec.get('start', 1.) 469 | self._perturb_min = perturb_spec.get('min', 0.3) 470 | self._perturb_max = perturb_spec.get('max', 3.) 471 | self._perturb_std = perturb_spec.get('std', 0.3) 472 | elif self._perturb_param == 'pole_mass': 473 | self._perturb_cur = perturb_spec.get('start', 0.1) 474 | self._perturb_start = perturb_spec.get('start', 0.1) 475 | self._perturb_min = perturb_spec.get('min', 0.1) 476 | self._perturb_max = perturb_spec.get('max', 10.) 
477 | self._perturb_std = perturb_spec.get('std', 0.5) 478 | elif self._perturb_param == 'joint_damping': 479 | self._perturb_cur = perturb_spec.get('start', 2e-6) 480 | self._perturb_start = perturb_spec.get('start', 2e-6) 481 | self._perturb_min = perturb_spec.get('min', 2e-6) 482 | self._perturb_max = perturb_spec.get('max', 2e-1) 483 | self._perturb_std = perturb_spec.get('std', 2e-2) 484 | elif self._perturb_param == 'slider_damping': 485 | self._perturb_cur = perturb_spec.get('start', 5e-4) 486 | self._perturb_start = perturb_spec.get('start', 5e-4) 487 | self._perturb_min = perturb_spec.get('min', 5e-4) 488 | self._perturb_max = perturb_spec.get('max', 3.0) 489 | self._perturb_std = perturb_spec.get('std', 0.3) 490 | 491 | def update_physics(self): 492 | """Returns a new Physics object with perturbed parameter.""" 493 | # Generate the new perturbed parameter. 494 | realworld_env.Base._generate_parameter(self) 495 | 496 | # Create new physics object with the perturb parameter. 497 | xml_string = common.read_model('cartpole.xml') 498 | mjcf = etree.fromstring(xml_string) 499 | 500 | if self._perturb_param in ['pole_length', 'pole_mass']: 501 | pole = mjcf.find('./default/default/geom') 502 | if self._perturb_param == 'pole_length': 503 | pole.set('fromto', '0 0 0 0 0 {}'.format(self._perturb_cur)) 504 | pole.set('mass', str(self._perturb_cur / 10.)) 505 | elif self._perturb_param == 'pole_mass': 506 | pole.set('mass', str(self._perturb_cur)) 507 | elif self._perturb_param == 'joint_damping': 508 | pole_joint = mjcf.find('./default/default/joint') 509 | pole_joint.set('damping', str(self._perturb_cur)) 510 | elif self._perturb_param == 'slider_damping': 511 | sliders_joint = mjcf.find('./worldbody/body/joint') 512 | sliders_joint.set('damping', str(self._perturb_cur)) 513 | 514 | xml_string_modified = etree.tostring(mjcf, pretty_print=True) 515 | physics = Physics.from_xml_string(xml_string_modified, common.ASSETS) 516 | return physics 517 | 518 | def before_step(self, action, physics): 519 | """Updates the environment using the action and returns a `TimeStep`.""" 520 | self._last_action = physics.control() 521 | action_min = self.action_spec(physics).minimum[:] 522 | action_max = self.action_spec(physics).maximum[:] 523 | action = realworld_env.Base.before_step(self, action, action_min, 524 | action_max) 525 | cartpole.Balance.before_step(self, action, physics) 526 | 527 | def after_step(self, physics): 528 | realworld_env.Base.after_step(self, physics) 529 | cartpole.Balance.after_step(self, physics) 530 | self._last_action = None 531 | -------------------------------------------------------------------------------- /realworldrl_suite/environments/humanoid.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Real-World RL Suite Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Real-world control of humanoid.""" 17 | 18 | import collections 19 | 20 | from dm_control.suite import common 21 | from dm_control.suite import humanoid 22 | from lxml import etree 23 | import numpy as np 24 | 25 | from realworldrl_suite.environments import realworld_env 26 | from realworldrl_suite.utils import loggers 27 | from realworldrl_suite.utils import wrappers 28 | 29 | _DEFAULT_TIME_LIMIT = 25 30 | _CONTROL_TIMESTEP = .025 31 | 32 | # Height of head above which stand reward is 1. 33 | _STAND_HEIGHT = 1.4 34 | 35 | # Horizontal speeds above which move reward is 1. 36 | _WALK_SPEED = 1 37 | _RUN_SPEED = 10 38 | 39 | PERTURB_PARAMS = ['joint_damping', 'contact_friction', 'head_size'] 40 | 41 | 42 | def load(task_name, **task_kwargs): 43 | return globals()[task_name](**task_kwargs) 44 | 45 | 46 | # Task Constraints 47 | def joint_angle_constraint(env, safety_vars): 48 | """Slider must be within a certain area of the track.""" 49 | joint_pos = safety_vars['joint_pos'] 50 | return np.less(np.max(np.abs(joint_pos)), env.limits['joint_pos_constraint']) 51 | 52 | 53 | def joint_velocity_constraint(env, safety_vars): 54 | """Joint angle velocities must stay below a certain limit.""" 55 | joint_vels = safety_vars['joint_vels'] 56 | return np.less(np.max(joint_vels), env.limits['joint_velocity_constraint']) 57 | 58 | 59 | def upright_constraint(env, safety_vars): 60 | """Pelvis orientantion should remain upright.""" 61 | z_up = safety_vars['z_up'] 62 | return np.greater(z_up, env.limits['torso_upright_constraint']) 63 | 64 | 65 | def foot_force_constraint(env, safety_vars): 66 | """Limit foot contact forces, uses ankle forces as proxy.""" 67 | force_max = np.max( 68 | [safety_vars[idx] for idx in ['rl_f', 'rr_f', 'll_f', 'lr_f']]) 69 | return np.less(force_max, env.limits['foot_force_constraint']) 70 | 71 | 72 | def dangerous_fall_constraint(env, safety_vars): 73 | """Discourage dangerous falls (head & torso contact).""" 74 | force_max = np.max( 75 | [safety_vars[idx] for idx in ['head_touch', 'torso_touch']]) 76 | return np.less(force_max, env.limits['dangerous_fall_constraint']) 77 | 78 | 79 | # Action rate of change constraint. 80 | action_roc_constraint = realworld_env.action_roc_constraint 81 | 82 | 83 | def realworld_stand(time_limit=_DEFAULT_TIME_LIMIT, 84 | random=None, 85 | log_output=None, 86 | environment_kwargs=None, 87 | safety_spec=None, 88 | delay_spec=None, 89 | noise_spec=None, 90 | perturb_spec=None, 91 | dimensionality_spec=None, 92 | multiobj_spec=None, 93 | combined_challenge=None): 94 | """Returns the Humanoid Stand task with specified real world attributes. 95 | 96 | Args: 97 | time_limit: Integer length of task 98 | random: random seed (unsure) 99 | log_output: String of path for pickle data logging, None disables logging 100 | environment_kwargs: additional kwargs for environment. 101 | safety_spec: dictionary that specifies the safety specifications. 102 | delay_spec: dictionary that specifies the delay specifications. 103 | noise_spec: dictionary that specifies the noise specifications. 104 | perturb_spec: dictionary that specifies the perturbations specifications. 105 | dimensionality_spec: dictionary that specifies extra observation features. 106 | multiobj_spec: dictionary that specifies complementary objectives. 107 | combined_challenge: string that can be 'easy', 'medium', or 'hard'. 108 | Specifying the combined challenge (can't be used with any other spec). 
109 | """ 110 | physics = humanoid.Physics.from_xml_string(*humanoid.get_model_and_assets()) 111 | safety_spec = safety_spec or {} 112 | delay_spec = delay_spec or {} 113 | noise_spec = noise_spec or {} 114 | perturb_spec = perturb_spec or {} 115 | dimensionality_spec = dimensionality_spec or {} 116 | multiobj_spec = multiobj_spec or {} 117 | # Check and update for combined challenge. 118 | (delay_spec, noise_spec, 119 | perturb_spec, dimensionality_spec) = ( 120 | realworld_env.get_combined_challenge( 121 | combined_challenge, delay_spec, noise_spec, perturb_spec, 122 | dimensionality_spec)) 123 | # Updating perturbation parameters if combined_challenge. 124 | if combined_challenge == 'easy': 125 | perturb_spec.update( 126 | {'param': 'contact_friction', 'min': 0.6, 'max': 0.8, 'std': 0.02}) 127 | elif combined_challenge == 'medium': 128 | perturb_spec.update( 129 | {'param': 'contact_friction', 'min': 0.5, 'max': 0.9, 'std': 0.04}) 130 | elif combined_challenge == 'hard': 131 | perturb_spec.update( 132 | {'param': 'contact_friction', 'min': 0.4, 'max': 1.0, 'std': 0.06}) 133 | 134 | task = RealWorldHumanoid( 135 | move_speed=0, 136 | pure_state=False, 137 | random=random, 138 | safety_spec=safety_spec, 139 | delay_spec=delay_spec, 140 | noise_spec=noise_spec, 141 | perturb_spec=perturb_spec, 142 | dimensionality_spec=dimensionality_spec, 143 | multiobj_spec=multiobj_spec) 144 | environment_kwargs = environment_kwargs or {} 145 | if log_output: 146 | logger = loggers.PickleLogger(path=log_output) 147 | else: 148 | logger = None 149 | return wrappers.LoggingEnv( 150 | physics, 151 | task, 152 | logger=logger, 153 | time_limit=time_limit, 154 | control_timestep=_CONTROL_TIMESTEP, 155 | **environment_kwargs) 156 | 157 | 158 | def realworld_walk(time_limit=_DEFAULT_TIME_LIMIT, 159 | random=None, 160 | log_output=None, 161 | environment_kwargs=None, 162 | safety_spec=None, 163 | delay_spec=None, 164 | noise_spec=None, 165 | perturb_spec=None, 166 | dimensionality_spec=None, 167 | multiobj_spec=None, 168 | combined_challenge=None): 169 | """Returns the Walk task with specified real world attributes. 170 | 171 | Args: 172 | time_limit: Integer length of task 173 | random: random seed (unsure) 174 | log_output: String of path for pickle data logging, None disables logging 175 | environment_kwargs: additional kwargs for environment. 176 | safety_spec: dictionary that specifies the safety specifications. 177 | delay_spec: dictionary that specifies the delay specifications. 178 | noise_spec: dictionary that specifies the noise specifications. 179 | perturb_spec: dictionary that specifies the perturbations specifications. 180 | dimensionality_spec: dictionary that specifies extra observation features. 181 | multiobj_spec: dictionary that specifies complementary objectives. 182 | combined_challenge: string that can be 'easy', 'medium', or 'hard'. 183 | Specifying the combined challenge (can't be used with any other spec). 184 | """ 185 | physics = humanoid.Physics.from_xml_string(*humanoid.get_model_and_assets()) 186 | safety_spec = safety_spec or {} 187 | delay_spec = delay_spec or {} 188 | noise_spec = noise_spec or {} 189 | perturb_spec = perturb_spec or {} 190 | dimensionality_spec = dimensionality_spec or {} 191 | multiobj_spec = multiobj_spec or {} 192 | # Check and update for combined challenge. 
193 | (delay_spec, noise_spec, 194 | perturb_spec, dimensionality_spec) = ( 195 | realworld_env.get_combined_challenge( 196 | combined_challenge, delay_spec, noise_spec, perturb_spec, 197 | dimensionality_spec)) 198 | # Updating perturbation parameters if combined_challenge. 199 | if combined_challenge == 'easy': 200 | perturb_spec.update( 201 | {'param': 'contact_friction', 'min': 0.6, 'max': 0.8, 'std': 0.02}) 202 | elif combined_challenge == 'medium': 203 | perturb_spec.update( 204 | {'param': 'contact_friction', 'min': 0.5, 'max': 0.9, 'std': 0.04}) 205 | elif combined_challenge == 'hard': 206 | perturb_spec.update( 207 | {'param': 'contact_friction', 'min': 0.4, 'max': 1.0, 'std': 0.06}) 208 | 209 | if 'limits' not in safety_spec: 210 | if 'safety_coeff' in safety_spec: 211 | if safety_spec['safety_coeff'] < 0 or safety_spec['safety_coeff'] > 1: 212 | raise ValueError('safety_coeff should be in [0,1], but got {}'.format( 213 | safety_spec['safety_coeff'])) 214 | safety_coeff = safety_spec['safety_coeff'] 215 | else: 216 | safety_coeff = 1 217 | 218 | safety_spec['limits'] = { 219 | 'joint_pos_constraint': safety_coeff * np.pi, # rad 220 | 'joint_velocity_constraint': safety_coeff * 90, # rad/s 221 | 'foot_force_constraint': safety_coeff * 8000, # newtons 222 | 'dangerous_fall_constraint': safety_coeff * 100, # newtons 223 | 'torso_upright_constraint': (1 - safety_coeff), # vector magnitude 224 | 'action_roc_constraint': safety_coeff * 1.85, 225 | } 226 | 227 | task = RealWorldHumanoid( 228 | move_speed=1, 229 | pure_state=False, 230 | random=random, 231 | safety_spec=safety_spec, 232 | delay_spec=delay_spec, 233 | noise_spec=noise_spec, 234 | perturb_spec=perturb_spec, 235 | dimensionality_spec=dimensionality_spec, 236 | multiobj_spec=multiobj_spec) 237 | environment_kwargs = environment_kwargs or {} 238 | if log_output: 239 | logger = loggers.PickleLogger(path=log_output) 240 | else: 241 | logger = None 242 | return wrappers.LoggingEnv( 243 | physics, 244 | task, 245 | logger=logger, 246 | time_limit=time_limit, 247 | control_timestep=_CONTROL_TIMESTEP, 248 | **environment_kwargs) 249 | 250 | 251 | class RealWorldHumanoid(realworld_env.Base, humanoid.Humanoid): 252 | """A Humanoid task with real-world specifications. 253 | 254 | Subclasses dm_control.suite.humanoid. 255 | 256 | Safety: 257 | Adds a set of constraints on the task. 258 | Returns an additional entry in the observations ('constraints') in the 259 | length of the number of the constraints, where each entry is True if the 260 | constraint is satisfied and False otherwise. 261 | 262 | Delays: 263 | Adds actions, observations, and rewards delays. 264 | Actions delay is the number of steps between passing the action to the 265 | environment to when it is actually performed, and observations (rewards) 266 | delay is the offset of freshness of the returned observation (reward) after 267 | performing a step. 268 | 269 | Noise: 270 | Adds action or observation noise. 271 | Different noise include: white Gaussian actions/observations, 272 | dropped actions/observations values, stuck actions/observations values, 273 | or repetitive actions. 274 | 275 | Perturbations: 276 | Perturbs physical quantities of the environment. These perturbations are 277 | non-stationary and are governed by a scheduler. 278 | 279 | Dimensionality: 280 | Adds extra dummy features to observations to increase dimensionality of the 281 | state space. 
282 | 283 | Multi-Objective Reward: 284 | Adds additional objectives and specifies objectives interaction (e.g., sum). 285 | """ 286 | 287 | def __init__(self, move_speed, pure_state, safety_spec, delay_spec, 288 | noise_spec, perturb_spec, dimensionality_spec, multiobj_spec, 289 | **kwargs): 290 | """Initialize the RealWorldHumanoid task. 291 | 292 | Args: 293 | move_speed: float. If this value is zero, reward is given simply for 294 | standing up. Otherwise this specifies a target horizontal velocity for 295 | the walking task. 296 | pure_state: bool. Whether the observations consist of the pure MuJoCo 297 | state or includes some useful features thereof. 298 | safety_spec: dictionary that specifies the safety specifications of the 299 | task. It may contain the following fields: 300 | enable- bool that represents whether safety specifications are enabled. 301 | constraints- list of class methods returning boolean constraint 302 | satisfactions. 303 | limits- dictionary of constants used by the functions in 'constraints'. 304 | safety_coeff - a scalar between 1 and 0 that scales safety constraints, 305 | 1 producing the base constraints, and 0 likely producing an 306 | unsolveable task. 307 | observations- a default-True boolean that toggles the whether a vector 308 | of satisfied constraints is added to observations. 309 | delay_spec: dictionary that specifies the delay specifications of the 310 | task. It may contain the following fields: 311 | enable- bool that represents whether delay specifications are enabled. 312 | actions- integer indicating the number of steps actions are being 313 | delayed. 314 | observations- integer indicating the number of steps observations are 315 | being delayed. 316 | rewards- integer indicating the number of steps observations are being 317 | delayed. 318 | noise_spec: dictionary that specifies the noise specifications of the 319 | task. It may contains the following fields: 320 | gaussian- dictionary that specifies the white Gaussian additive noise. 321 | It may contain the following fields: 322 | enable- bool that represents whether noise specifications are enabled. 323 | actions- float inidcating the standard deviation of a white Gaussian 324 | noise added to each action. 325 | observations- similarly, additive white Gaussian noise to each 326 | returned observation. 327 | dropped- dictionary that specifies the dropped values noise. 328 | It may contain the following fields: 329 | enable- bool that represents whether dropped values specifications are 330 | enabled. 331 | observations_prob- float in [0,1] indicating the probability of 332 | dropping each observation component independently. 333 | observations_steps- positive integer indicating the number of time 334 | steps of dropping a value (setting to zero) if dropped. 335 | actions_prob- float in [0,1] indicating the probability of dropping 336 | each action component independently. 337 | actions_steps- positive integer indicating the number of time steps of 338 | dropping a value (setting to zero) if dropped. 339 | stuck- dictionary that specifies the stuck values noise. 340 | It may contain the following fields: 341 | enable- bool that represents whether stuck values specifications are 342 | enabled. 343 | observations_prob- float in [0,1] indicating the probability of each 344 | observation component becoming stuck. 345 | observations_steps- positive integer indicating the number of time 346 | steps an observation (or components of) stays stuck. 
347 | actions_prob- float in [0,1] indicating the probability of each 348 | action component becoming stuck. 349 | actions_steps- positive integer indicating the number of time 350 | steps an action (or components of) stays stuck. 351 | repetition- dictionary that specifies the repetition statistics. 352 | It may contain the following fields: 353 | enable- bool that represents whether repetition specifications are 354 | enabled. 355 | actions_prob- float in [0,1] indicating the probability of the actions 356 | to be repeated in the following steps. 357 | actions_steps- positive integer indicating the number of time steps of 358 | repeating the same action if it to be repeated. 359 | perturb_spec: dictionary that specifies the perturbation specifications 360 | of the task. It may contain the following fields: 361 | enable- bool that represents whether perturbation specifications are 362 | enabled. 363 | period- int, number of episodes between updates perturbation updates. 364 | param - string indicating which parameter to perturb (currently 365 | supporting joint_damping, contact_friction, head_size). 366 | scheduler- string inidcating the scheduler to apply to the perturbed 367 | parameter (currently supporting constant, random_walk, drift_pos, 368 | drift_neg, cyclic_pos, cyclic_neg, uniform, and saw_wave). 369 | start - float indicating the initial value of the perturbed parameter. 370 | min - float indicating the minimal value the perturbed parameter may be. 371 | max - float indicating the maximal value the perturbed parameter may be. 372 | std - float indicating the standard deviation of the white noise for the 373 | scheduling process. 374 | dimensionality_spec: dictionary that specifies the added dimensions to the 375 | state space. It may contain the following fields: 376 | enable - bool that represents whether dimensionality specifications are 377 | enabled. 378 | num_random_state_observations - num of random (unit Gaussian) 379 | observations to add. 380 | multiobj_spec: dictionary that sets up the multi-objective challenge. 381 | The challenge works by providing an `Objective` object which describes 382 | both numerical objectives and a reward-merging method that allow to both 383 | observe the various objectives in the observation and affect the 384 | returned reward in a manner defined by the Objective object. 385 | enable- bool that represents whether delay multi-objective 386 | specifications are enabled. 387 | objective - either a string which will load an `Objective` class from 388 | utils.multiobj_objectives.OBJECTIVES, or an Objective object which 389 | subclasses utils.multiobj_objectives.Objective. 390 | reward - boolean indicating whether to add the multiobj objective's 391 | reward to the environment's returned reward. 392 | coeff - a number in [0,1] that is passed into the Objective object to 393 | change the mix between the original reward and the Objective's 394 | rewards. 395 | observed - boolean indicating whether the defined objectives should be 396 | added to the observation. 397 | **kwargs: extra parameters passed to parent class (humanoid.Humanoid) 398 | """ 399 | # Initialize parent classes. 400 | realworld_env.Base.__init__(self) 401 | humanoid.Humanoid.__init__(self, move_speed, pure_state, **kwargs) 402 | 403 | # Safety setup. 404 | self._setup_safety(safety_spec) 405 | 406 | # Delay setup. 407 | realworld_env.Base._setup_delay(self, delay_spec) 408 | 409 | # Noise setup. 
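    # For reference, a noise_spec combining several of the noise models described
    # above might look like the following (all values are illustrative):
    #   noise_spec = dict(
    #       gaussian=dict(enable=True, actions=0.1, observations=0.1),
    #       dropped=dict(enable=True, observations_prob=0.01,
    #                    observations_steps=5))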
410 | realworld_env.Base._setup_noise(self, noise_spec) 411 | 412 | # Perturb setup. 413 | self._setup_perturb(perturb_spec) 414 | 415 | # Dimensionality setup 416 | realworld_env.Base._setup_dimensionality(self, dimensionality_spec) 417 | 418 | # Multi-objective setup 419 | realworld_env.Base._setup_multiobj(self, multiobj_spec) 420 | 421 | # Safety methods. 422 | def _setup_safety(self, safety_spec): 423 | """Setup for the safety specifications of the task.""" 424 | self._safety_enabled = safety_spec.get('enable', False) 425 | self._safety_observed = safety_spec.get('observations', True) 426 | 427 | if self._safety_enabled: 428 | # Add safety specifications. 429 | if 'constraints' in safety_spec: 430 | self.constraints = safety_spec['constraints'] 431 | else: 432 | self.constraints = collections.OrderedDict([ 433 | ('joint_angle_constraint', joint_angle_constraint), 434 | ('joint_velocity_constraint', joint_velocity_constraint), 435 | ('upright_constraint', upright_constraint), 436 | ('dangerous_fall_constraint', dangerous_fall_constraint), 437 | ('foot_force_constraint', foot_force_constraint) 438 | ]) 439 | if 'limits' in safety_spec: 440 | self.limits = safety_spec['limits'] 441 | else: 442 | if 'safety_coeff' in safety_spec: 443 | if safety_spec['safety_coeff'] < 0 or safety_spec['safety_coeff'] > 1: 444 | raise ValueError( 445 | 'safety_coeff should be in [0,1], but got {}'.format( 446 | safety_spec['safety_coeff'])) 447 | safety_coeff = safety_spec['safety_coeff'] 448 | else: 449 | safety_coeff = 1 450 | self.limits = { 451 | 'joint_pos_constraint': safety_coeff * np.pi, # rad 452 | 'joint_velocity_constraint': safety_coeff * 90, # rad/s 453 | 'foot_force_constraint': safety_coeff * 8000, # newtons 454 | 'dangerous_fall_constraint': safety_coeff * 100, # newtons 455 | 'torso_upright_constraint': 456 | (1 - safety_coeff), # vector magnitude 457 | 'action_roc_constraint': safety_coeff * 1.85, 458 | } 459 | self._constraints_obs = np.ones(len(self.constraints), dtype=bool) 460 | 461 | def safety_vars(self, physics): 462 | """Centralized retrieval of safety-related variables to simplify logging.""" 463 | safety_vars = collections.OrderedDict( 464 | joint_pos=physics.named.data.qpos[7:].copy(), 465 | joint_vels=np.abs(physics.named.data.qvel[7:]).copy(), 466 | z_up=physics.torso_upright(), 467 | rl_f=np.linalg.norm( 468 | physics.named.data.sensordata['right_left_foot_touch'].copy()), 469 | rr_f=np.linalg.norm( 470 | physics.named.data.sensordata['right_right_foot_touch'].copy()), 471 | ll_f=np.linalg.norm( 472 | physics.named.data.sensordata['left_left_foot_touch'].copy()), 473 | lr_f=np.linalg.norm( 474 | physics.named.data.sensordata['left_right_foot_touch'].copy()), 475 | head_touch=np.linalg.norm( 476 | physics.named.data.sensordata['head_touch'].copy()), 477 | torso_touch=np.linalg.norm( 478 | physics.named.data.sensordata['torso_touch'].copy()), 479 | actions=physics.control(), 480 | ) 481 | return safety_vars 482 | 483 | def _setup_perturb(self, perturb_spec): 484 | """Setup for the perturbations specification of the task.""" 485 | self._perturb_enabled = perturb_spec.get('enable', False) 486 | self._perturb_period = perturb_spec.get('period', 1) 487 | 488 | if self._perturb_enabled: 489 | # Add perturbations specifications. 490 | self._perturb_param = perturb_spec.get('param', 'contact_friction') 491 | # Making sure object to perturb is supported. 492 | if self._perturb_param not in PERTURB_PARAMS: 493 | raise ValueError("""param was: {}. Currently only supporting {}. 
494 | """.format(self._perturb_param, PERTURB_PARAMS)) 495 | 496 | # Setting perturbation function. 497 | self._perturb_scheduler = perturb_spec.get('scheduler', 'constant') 498 | if self._perturb_scheduler not in realworld_env.PERTURB_SCHEDULERS: 499 | raise ValueError("""scheduler was: {}. Currently only supporting {}. 500 | """.format(self._perturb_scheduler, realworld_env.PERTURB_SCHEDULERS)) 501 | 502 | # Setting perturbation process parameters. 503 | if self._perturb_param == 'contact_friction': 504 | self._perturb_cur = perturb_spec.get('start', 0.7) 505 | self._perturb_start = perturb_spec.get('start', 0.7) 506 | self._perturb_min = perturb_spec.get('min', 0.05) 507 | self._perturb_max = perturb_spec.get('max', 1.2) 508 | self._perturb_std = perturb_spec.get('std', 0.1) 509 | elif self._perturb_param == 'joint_damping': 510 | self._perturb_cur = perturb_spec.get('start', 0.2) 511 | self._perturb_start = perturb_spec.get('start', 0.2) 512 | self._perturb_min = perturb_spec.get('min', 0.01) 513 | self._perturb_max = perturb_spec.get('max', 2.5) 514 | self._perturb_std = perturb_spec.get('std', 0.2) 515 | elif self._perturb_param == 'head_size': 516 | self._perturb_cur = perturb_spec.get('start', 0.09) 517 | self._perturb_start = perturb_spec.get('start', 0.09) 518 | self._perturb_min = perturb_spec.get('min', 0.01) 519 | self._perturb_max = perturb_spec.get('max', 0.19) 520 | self._perturb_std = perturb_spec.get('std', 0.02) 521 | 522 | def update_physics(self): 523 | """Returns a new Physics object with perturbed parameter.""" 524 | # Generate the new perturbed parameter. 525 | realworld_env.Base._generate_parameter(self) 526 | 527 | # Create new physics object with the perturb parameter. 528 | xml_string = common.read_model('humanoid.xml') 529 | mjcf = etree.fromstring(xml_string) 530 | 531 | if self._perturb_param == 'joint_damping': 532 | # Joint damping is a coefficient that provides a countering force 533 | # proportional to angular velocity. 534 | joint_damping = mjcf.find('./default/default/joint') 535 | joint_damping.set('damping', str(self._perturb_cur)) 536 | elif self._perturb_param == 'contact_friction': 537 | # Need to set the friction co-efficient on floor and body geoms: 538 | geom_contact = mjcf.find('./default/default/geom') # Body geom. 539 | geom_contact.set('friction', '{} .1 .1'.format(self._perturb_cur)) 540 | floor_contact = mjcf.find('./worldbody/geom') # Floor geom. 
541 | floor_contact.set('friction', '{} .1 .1'.format(self._perturb_cur)) 542 | elif self._perturb_param == 'head_size': 543 | geom_head = mjcf.find('./worldbody/body/body/geom') 544 | geom_head.set('size', '{}'.format(self._perturb_cur)) 545 | xml_string_modified = etree.tostring(mjcf, pretty_print=True) 546 | physics = Physics.from_xml_string(xml_string_modified, common.ASSETS) 547 | 548 | return physics 549 | 550 | def before_step(self, action, physics): 551 | """Updates the environment using the action and returns a `TimeStep`.""" 552 | self._last_action = physics.control() 553 | action_min = self.action_spec(physics).minimum[:] 554 | action_max = self.action_spec(physics).maximum[:] 555 | action = realworld_env.Base.before_step(self, action, action_min, 556 | action_max) 557 | humanoid.Humanoid.before_step(self, action, physics) 558 | 559 | def after_step(self, physics): 560 | realworld_env.Base.after_step(self, physics) 561 | humanoid.Humanoid.after_step(self, physics) 562 | self._last_action = None 563 | 564 | 565 | class Physics(humanoid.Physics): 566 | """Inherits from humanoid.Physics.""" 567 | -------------------------------------------------------------------------------- /realworldrl_suite/environments/manipulator.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Real-World RL Suite Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Real-World Planar Manipulator domain.""" 17 | 18 | import collections 19 | 20 | from dm_control.suite import common 21 | from dm_control.suite import manipulator 22 | from lxml import etree 23 | import numpy as np 24 | 25 | from realworldrl_suite.environments import realworld_env 26 | from realworldrl_suite.utils import loggers 27 | from realworldrl_suite.utils import wrappers 28 | 29 | _CONTROL_TIMESTEP = .01 # (Seconds) 30 | _TIME_LIMIT = 10 # (Seconds) 31 | _ARM_JOINTS = ['arm_shoulder', 'arm_elbow'] # Used for constraints 32 | 33 | PERTURB_PARAMS = ['lower_arm_length', 'root_damping', 'shoulder_damping'] 34 | 35 | 36 | def load(task_name, **task_kwargs): 37 | return globals()[task_name](**task_kwargs) 38 | 39 | 40 | # Task Constraints 41 | def joint_angle_constraint(env, safety_vars): 42 | """Joint angles must be within a certain area of the track.""" 43 | joint_pos = safety_vars['joint_angle'] 44 | return (np.greater(joint_pos, 45 | env.limits['joint_angle_constraint'][0]).all() and 46 | np.less(joint_pos, env.limits['joint_angle_constraint'][1])).all() 47 | 48 | 49 | def joint_velocity_constraint(env, safety_vars): 50 | """Joint angle velocities must stay below a certain limit.""" 51 | joint_vels = safety_vars['joint_vels'] 52 | return np.less(np.max(joint_vels), env.limits['joint_velocity_constraint']) 53 | 54 | 55 | def joint_accel_constraint(env, safety_vars): 56 | """Joint angle velocities must stay below a certain limit.""" 57 | joint_accels = safety_vars['joint_accels'] 58 | return np.less(np.max(joint_accels), env.limits['joint_accel_constraint']) 59 | 60 | 61 | def grasp_force_constraint(env, safety_vars): 62 | """Limit gripper contact forces.""" 63 | return np.less( 64 | np.max(safety_vars['grasp_force']), env.limits['grasp_force_constraint']) 65 | 66 | 67 | # Action rate of change constraint. 68 | action_roc_constraint = realworld_env.action_roc_constraint 69 | 70 | 71 | def gen_task(use_peg, 72 | insert, 73 | fully_observable=True, 74 | time_limit=_TIME_LIMIT, 75 | random=None, 76 | log_output=None, 77 | environment_kwargs=None, 78 | safety_spec=None, 79 | delay_spec=None, 80 | noise_spec=None, 81 | perturb_spec=None, 82 | dimensionality_spec=None, 83 | multiobj_spec=None, 84 | combined_challenge=None): 85 | """Returns the Manipulator Bring task with specified real world attributes. 86 | 87 | Args: 88 | use_peg: A `bool`, whether to replace the ball prop with the peg prop. 89 | insert: A `bool`, whether to insert the prop in a receptacle. 90 | fully_observable: A `bool`, whether the observation should contain the 91 | position and velocity of the object being manipulated and the target 92 | location. 93 | time_limit: Integer length of task 94 | random: Optional, either a `numpy.random.RandomState` instance, an integer 95 | seed for creating a new `RandomState`, or None to select a seed 96 | automatically (default). 97 | log_output: String of path for pickle data logging, None disables logging 98 | environment_kwargs: additional kwargs for environment 99 | safety_spec: dictionary that specifies the safety specifications. 100 | delay_spec: dictionary that specifies the delay. 101 | noise_spec: dictionary that specifies the noise specifications. 102 | perturb_spec: dictionary that specifies the perturbations specifications. 103 | dimensionality_spec: dictionary that specifies extra observation features. 104 | multiobj_spec: dictionary that specifies complementary objectives. 105 | combined_challenge: string that can be 'easy', 'medium', or 'hard'. 
106 | Specifying the combined challenge (can't be used with any other spec). 107 | """ 108 | physics = manipulator.Physics.from_xml_string( 109 | *manipulator.make_model(use_peg, insert)) 110 | safety_spec = safety_spec or {} 111 | delay_spec = delay_spec or {} 112 | noise_spec = noise_spec or {} 113 | perturb_spec = perturb_spec or {} 114 | dimensionality_spec = dimensionality_spec or {} 115 | multiobj_spec = multiobj_spec or {} 116 | # Check and update for combined challenge. 117 | (delay_spec, noise_spec, 118 | perturb_spec, dimensionality_spec) = ( 119 | realworld_env.get_combined_challenge( 120 | combined_challenge, delay_spec, noise_spec, perturb_spec, 121 | dimensionality_spec)) 122 | 123 | task = RealWorldBring( 124 | use_peg=use_peg, 125 | insert=insert, 126 | fully_observable=fully_observable, 127 | safety_spec=safety_spec, 128 | delay_spec=delay_spec, 129 | noise_spec=noise_spec, 130 | perturb_spec=perturb_spec, 131 | dimensionality_spec=dimensionality_spec, 132 | multiobj_spec=multiobj_spec, 133 | random=random) 134 | 135 | environment_kwargs = environment_kwargs or {} 136 | if log_output: 137 | logger = loggers.PickleLogger(path=log_output) 138 | else: 139 | logger = None 140 | return wrappers.LoggingEnv( 141 | physics, 142 | task, 143 | logger=logger, 144 | control_timestep=_CONTROL_TIMESTEP, 145 | time_limit=time_limit, 146 | **environment_kwargs) 147 | 148 | 149 | def realworld_bring_ball(fully_observable=True, 150 | time_limit=_TIME_LIMIT, 151 | random=None, 152 | log_output=None, 153 | environment_kwargs=None, 154 | safety_spec=None, 155 | delay_spec=None, 156 | noise_spec=None, 157 | perturb_spec=None, 158 | dimensionality_spec=None, 159 | multiobj_spec=None, 160 | combined_challenge=None): 161 | """Returns manipulator bring task with the ball prop.""" 162 | use_peg = False 163 | insert = False 164 | return gen_task(use_peg, insert, fully_observable, time_limit, random, 165 | log_output, environment_kwargs, safety_spec, delay_spec, 166 | noise_spec, perturb_spec, dimensionality_spec, multiobj_spec, 167 | combined_challenge) 168 | 169 | 170 | def realworld_bring_peg(fully_observable=True, 171 | time_limit=_TIME_LIMIT, 172 | random=None, 173 | log_output=None, 174 | environment_kwargs=None, 175 | safety_spec=None, 176 | delay_spec=None, 177 | noise_spec=None, 178 | perturb_spec=None, 179 | dimensionality_spec=None, 180 | multiobj_spec=None, 181 | combined_challenge=None): 182 | """Returns manipulator bring task with the peg prop.""" 183 | use_peg = True 184 | insert = False 185 | return gen_task(use_peg, insert, fully_observable, time_limit, random, 186 | log_output, environment_kwargs, safety_spec, delay_spec, 187 | noise_spec, perturb_spec, dimensionality_spec, multiobj_spec, 188 | combined_challenge) 189 | 190 | 191 | def realworld_insert_ball(fully_observable=True, 192 | time_limit=_TIME_LIMIT, 193 | random=None, 194 | log_output=None, 195 | environment_kwargs=None, 196 | safety_spec=None, 197 | delay_spec=None, 198 | noise_spec=None, 199 | perturb_spec=None, 200 | dimensionality_spec=None, 201 | multiobj_spec=None, 202 | combined_challenge=None): 203 | """Returns manipulator insert task with the ball prop.""" 204 | use_peg = False 205 | insert = True 206 | return gen_task(use_peg, insert, fully_observable, time_limit, random, 207 | log_output, environment_kwargs, safety_spec, delay_spec, 208 | noise_spec, perturb_spec, dimensionality_spec, multiobj_spec, 209 | combined_challenge) 210 | 211 | 212 | def realworld_insert_peg(fully_observable=True, 213 | 
time_limit=_TIME_LIMIT, 214 | random=None, 215 | log_output=None, 216 | environment_kwargs=None, 217 | safety_spec=None, 218 | delay_spec=None, 219 | noise_spec=None, 220 | perturb_spec=None, 221 | dimensionality_spec=None, 222 | multiobj_spec=None, 223 | combined_challenge=None): 224 | """Returns manipulator insert task with the peg prop.""" 225 | use_peg = True 226 | insert = True 227 | return gen_task(use_peg, insert, fully_observable, time_limit, random, 228 | log_output, environment_kwargs, safety_spec, delay_spec, 229 | noise_spec, perturb_spec, dimensionality_spec, multiobj_spec, 230 | combined_challenge) 231 | 232 | 233 | class Physics(manipulator.Physics): 234 | """Inherits from manipulator.Physics.""" 235 | 236 | 237 | class RealWorldBring(realworld_env.Base, manipulator.Bring): 238 | """A Manipulator task with real-world specifications. 239 | 240 | Subclasses dm_control.suite.manipulator. 241 | 242 | Safety: 243 | Adds a set of constraints on the task. 244 | Returns an additional entry in the observations ('constraints') in the 245 | length of the number of the constraints, where each entry is True if the 246 | constraint is satisfied and False otherwise. 247 | 248 | Delays: 249 | Adds actions, observations, and rewards delays. 250 | Actions delay is the number of steps between passing the action to the 251 | environment to when it is actually performed, and observations (rewards) 252 | delay is the offset of freshness of the returned observation (reward) after 253 | performing a step. 254 | 255 | Noise: 256 | Adds action or observation noise. 257 | Different noise include: white Gaussian actions/observations, 258 | dropped actions/observations values, stuck actions/observations values, 259 | or repetitive actions. 260 | 261 | Perturbations: 262 | Perturbs physical quantities of the environment. These perturbations are 263 | non-stationary and are governed by a scheduler. 264 | 265 | Dimensionality: 266 | Adds extra dummy features to observations to increase dimensionality of the 267 | state space. 268 | 269 | Multi-Objective Reward: 270 | Adds additional objectives and specifies objectives interaction (e.g., sum). 271 | """ 272 | 273 | def __init__(self, 274 | use_peg, 275 | insert, 276 | fully_observable, 277 | safety_spec, 278 | delay_spec, 279 | noise_spec, 280 | perturb_spec, 281 | dimensionality_spec, 282 | multiobj_spec, 283 | random=None, 284 | **kwargs): 285 | """Initialize the RealWorldBring task. 286 | 287 | Args: 288 | use_peg: A `bool`, whether to replace the ball prop with the peg prop. 289 | insert: A `bool`, whether to insert the prop in a receptacle. 290 | fully_observable: A `bool`, whether the observation should contain the 291 | position and velocity of the object being manipulated and the target 292 | location. 293 | safety_spec: dictionary that specifies the safety specifications of the 294 | task. It may contain the following fields: 295 | enable- bool that represents whether safety specifications are enabled. 296 | constraints- list of class methods returning boolean constraint 297 | satisfactions. 298 | limits- dictionary of constants used by the functions in 'constraints'. 299 | safety_coeff - a scalar between 1 and 0 that scales safety constraints, 300 | 1 producing the base constraints, and 0 likely producing an 301 | unsolveable task. 302 | observations- a default-True boolean that toggles the whether a vector 303 | of satisfied constraints is added to observations. 
304 | delay_spec: dictionary that specifies the delay specifications of the 305 | task. It may contain the following fields: 306 | enable- bool that represents whether delay specifications are enabled. 307 | actions- integer indicating the number of steps actions are being 308 | delayed. 309 | observations- integer indicating the number of steps observations are 310 | being delayed. 311 | rewards- integer indicating the number of steps rewards are being 312 | delayed. 313 | noise_spec: dictionary that specifies the noise specifications of the 314 | task. It may contain the following fields: 315 | gaussian- dictionary that specifies the white Gaussian additive noise. 316 | It may contain the following fields: 317 | enable- bool that represents whether noise specifications are enabled. 318 | actions- float indicating the standard deviation of a white Gaussian 319 | noise added to each action. 320 | observations- similarly, additive white Gaussian noise to each 321 | returned observation. 322 | dropped- dictionary that specifies the dropped values noise. 323 | It may contain the following fields: 324 | enable- bool that represents whether dropped values specifications are 325 | enabled. 326 | observations_prob- float in [0,1] indicating the probability of 327 | dropping each observation component independently. 328 | observations_steps- positive integer indicating the number of time 329 | steps of dropping a value (setting to zero) if dropped. 330 | actions_prob- float in [0,1] indicating the probability of dropping 331 | each action component independently. 332 | actions_steps- positive integer indicating the number of time steps of 333 | dropping a value (setting to zero) if dropped. 334 | stuck- dictionary that specifies the stuck values noise. 335 | It may contain the following fields: 336 | enable- bool that represents whether stuck values specifications are 337 | enabled. 338 | observations_prob- float in [0,1] indicating the probability of each 339 | observation component becoming stuck. 340 | observations_steps- positive integer indicating the number of time 341 | steps an observation (or components of) stays stuck. 342 | actions_prob- float in [0,1] indicating the probability of each 343 | action component becoming stuck. 344 | actions_steps- positive integer indicating the number of time 345 | steps an action (or components of) stays stuck. 346 | repetition- dictionary that specifies the repetition statistics. 347 | It may contain the following fields: 348 | enable- bool that represents whether repetition specifications are 349 | enabled. 350 | actions_prob- float in [0,1] indicating the probability of the actions 351 | being repeated in the following steps. 352 | actions_steps- positive integer indicating the number of time steps of 353 | repeating the same action if it is to be repeated. 354 | perturb_spec: dictionary that specifies the perturbation specifications 355 | of the task. It may contain the following fields: 356 | enable- bool that represents whether perturbation specifications are 357 | enabled. 358 | period- int, number of episodes between perturbation updates. 359 | param- string indicating which parameter to perturb (currently 360 | supporting lower_arm_length, root_damping, shoulder_damping). 361 | scheduler- string indicating the scheduler to apply to the perturbed 362 | parameter (currently supporting constant, random_walk, drift_pos, 363 | drift_neg, cyclic_pos, cyclic_neg, uniform, and saw_wave).
364 | start - float indicating the initial value of the perturbed parameter. 365 | min - float indicating the minimal value the perturbed parameter may be. 366 | max - float indicating the maximal value the perturbed parameter may be. 367 | std - float indicating the standard deviation of the white noise for the 368 | scheduling process. 369 | dimensionality_spec: dictionary that specifies the added dimensions to the 370 | state space. It may contain the following fields: 371 | enable- bool that represents whether dimensionality specifications are 372 | enabled. 373 | num_random_state_observations - num of random (unit Gaussian) 374 | observations to add. 375 | multiobj_spec: dictionary that sets up the multi-objective challenge. 376 | The challenge works by providing an `Objective` object which describes 377 | both numerical objectives and a reward-merging method that allow to both 378 | observe the various objectives in the observation and affect the 379 | returned reward in a manner defined by the Objective object. 380 | enable- bool that represents whether delay multi-objective 381 | specifications are enabled. 382 | objective - either a string which will load an `Objective` class from 383 | utils.multiobj_objectives.OBJECTIVES, or an Objective object which 384 | subclasses utils.multiobj_objectives.Objective. 385 | reward - boolean indicating whether to add the multiobj objective's 386 | reward to the environment's returned reward. 387 | coeff - a number in [0,1] that is passed into the Objective object to 388 | change the mix between the original reward and the Objective's 389 | rewards. 390 | observed - boolean indicating whether the defined objectives should be 391 | added to the observation. 392 | random: Optional, either a `numpy.random.RandomState` instance, an integer 393 | seed for creating a new `RandomState`, or None to select a seed 394 | automatically (default). 395 | **kwargs: extra parameters passed to parent class (manipulator.Bring) 396 | """ 397 | # Initialize parent classes. 398 | realworld_env.Base.__init__(self) 399 | manipulator.Bring.__init__( 400 | self, use_peg, insert, fully_observable, random=random, **kwargs) 401 | 402 | # Safety setup. 403 | self._setup_safety(safety_spec) 404 | 405 | # Delay setup. 406 | realworld_env.Base._setup_delay(self, delay_spec) 407 | 408 | # Noise setup. 409 | realworld_env.Base._setup_noise(self, noise_spec) 410 | 411 | # Perturb setup. 412 | self._setup_perturb(perturb_spec) 413 | 414 | # Dimensionality setup 415 | realworld_env.Base._setup_dimensionality(self, dimensionality_spec) 416 | 417 | # Multi-objective setup 418 | realworld_env.Base._setup_multiobj(self, multiobj_spec) 419 | 420 | self._use_peg = use_peg 421 | self._insert = insert 422 | 423 | # Safety methods. 424 | def _setup_safety(self, safety_spec): 425 | """Setup for the safety specifications of the task.""" 426 | self._safety_enabled = safety_spec.get('enable', False) 427 | self._safety_observed = safety_spec.get('observations', True) 428 | 429 | if self._safety_enabled: 430 | # Add safety specifications. 
431 | if 'constraints' in safety_spec: 432 | self.constraints = safety_spec['constraints'] 433 | else: 434 | self.constraints = collections.OrderedDict([ 435 | ('joint_angle_constraint', joint_angle_constraint), 436 | ('joint_velocity_constraint', joint_velocity_constraint), 437 | ('joint_accel_constraint', joint_accel_constraint), 438 | ('grasp_force_constraint', grasp_force_constraint) 439 | ]) 440 | if 'limits' in safety_spec: 441 | self.limits = safety_spec['limits'] 442 | else: 443 | if 'safety_coeff' in safety_spec: 444 | if safety_spec['safety_coeff'] < 0 or safety_spec['safety_coeff'] > 1: 445 | raise ValueError( 446 | 'safety_coeff should be in [0,1], but got {}'.format( 447 | safety_spec['safety_coeff'])) 448 | safety_coeff = safety_spec['safety_coeff'] 449 | else: 450 | safety_coeff = 1 451 | self.limits = { 452 | 'joint_angle_constraint': 453 | safety_coeff * np.array([[-160, -140], [160, 140]]) * 454 | np.pi / 180., # rad 455 | 'joint_velocity_constraint': 456 | safety_coeff * 10, # rad/s 457 | 'joint_accel_constraint': 458 | safety_coeff * 3600, # rad/s^2 459 | 'grasp_force_constraint': 460 | safety_coeff * 5, # newtons 461 | 'action_roc_constraint': safety_coeff * 1.5 462 | } 463 | self._constraints_obs = np.ones(len(self.constraints), dtype=bool) 464 | 465 | def safety_vars(self, physics): 466 | """Centralized retrieval of safety-related variables to simplify logging.""" 467 | safety_vars = collections.OrderedDict( 468 | joint_angle=physics.named.data.qpos[_ARM_JOINTS].copy(), # rad 469 | joint_vels=np.abs(physics.named.data.qvel[_ARM_JOINTS]).copy(), # rad/s 470 | joint_accels=np.abs( 471 | physics.named.data.qacc[_ARM_JOINTS]).copy(), # rad/s^2 472 | grasp_force=physics.touch(), # log(1+newtons) 473 | actions=physics.control(), 474 | ) 475 | return safety_vars 476 | 477 | def _setup_perturb(self, perturb_spec): 478 | """Setup for the perturbations specification of the task.""" 479 | self._perturb_enabled = perturb_spec.get('enable', False) 480 | self._perturb_period = perturb_spec.get('period', 1) 481 | 482 | if self._perturb_enabled: 483 | # Add perturbations specifications. 484 | self._perturb_param = perturb_spec.get('param', 'lower_arm_length') 485 | # Making sure object to perturb is supported. 486 | if self._perturb_param not in PERTURB_PARAMS: 487 | raise ValueError("""param was: {}. Currently only supporting {}. 488 | """.format(self._perturb_param, PERTURB_PARAMS)) 489 | 490 | # Setting perturbation function. 491 | self._perturb_scheduler = perturb_spec.get('scheduler', 'constant') 492 | if self._perturb_scheduler not in realworld_env.PERTURB_SCHEDULERS: 493 | raise ValueError("""scheduler was: {}. Currently only supporting {}. 494 | """.format(self._perturb_scheduler, realworld_env.PERTURB_SCHEDULERS)) 495 | 496 | # Setting perturbation process parameters. 
497 | if self._perturb_param == 'lower_arm_length': 498 | self._perturb_cur = perturb_spec.get('start', 0.12) 499 | self._perturb_start = perturb_spec.get('start', 0.12) 500 | self._perturb_min = perturb_spec.get('min', 0.1) 501 | self._perturb_max = perturb_spec.get('max', 0.25) 502 | self._perturb_std = perturb_spec.get('std', 0.01) 503 | elif self._perturb_param == 'root_damping': 504 | self._perturb_cur = perturb_spec.get('start', 2.0) 505 | self._perturb_start = perturb_spec.get('start', 2.0) 506 | self._perturb_min = perturb_spec.get('min', 0.1) 507 | self._perturb_max = perturb_spec.get('max', 10.0) 508 | self._perturb_std = perturb_spec.get('std', 0.1) 509 | elif self._perturb_param == 'shoulder_damping': 510 | self._perturb_cur = perturb_spec.get('start', 1.5) 511 | self._perturb_start = perturb_spec.get('start', 1.5) 512 | self._perturb_min = perturb_spec.get('min', 0.1) 513 | self._perturb_max = perturb_spec.get('max', 10.0) 514 | self._perturb_std = perturb_spec.get('std', 0.1) 515 | 516 | def update_physics(self): 517 | """Returns a new Physics object with perturbed parameter.""" 518 | # Generate the new perturbed parameter. 519 | realworld_env.Base._generate_parameter(self) 520 | 521 | # Create new physics object with the perturb parameter. 522 | xml_string = manipulator.make_model(self._use_peg, self._insert)[0] 523 | mjcf = etree.fromstring(xml_string) 524 | 525 | if self._perturb_param == 'lower_arm_length': 526 | lower_arm = mjcf.find('./worldbody/body/body/body/geom') 527 | lower_arm.set('fromto', '0 0 0 0 0 {}'.format(self._perturb_cur)) 528 | hand = mjcf.find('./worldbody/body/body/body/body') 529 | hand.set('pos', '0 0 {}'.format(self._perturb_cur)) 530 | elif self._perturb_param == 'root_damping': 531 | joints = mjcf.findall('./worldbody/body/joint') 532 | for joint in joints: 533 | if joint.get('name') == 'arm_root': 534 | joint.set('damping', str(self._perturb_cur)) 535 | elif self._perturb_param == 'shoulder_damping': 536 | shoulder_joint = mjcf.find('./worldbody/body/body/joint') 537 | shoulder_joint.set('damping', str(self._perturb_cur)) 538 | 539 | xml_string_modified = etree.tostring(mjcf, pretty_print=True) 540 | physics = Physics.from_xml_string(xml_string_modified, common.ASSETS) 541 | return physics 542 | 543 | def before_step(self, action, physics): 544 | """Updates the environment using the action and returns a `TimeStep`.""" 545 | self._last_action = physics.control() 546 | action_min = self.action_spec(physics).minimum[:] 547 | action_max = self.action_spec(physics).maximum[:] 548 | action = realworld_env.Base.before_step(self, action, action_min, 549 | action_max) 550 | manipulator.Bring.before_step(self, action, physics) 551 | 552 | def after_step(self, physics): 553 | realworld_env.Base.after_step(self, physics) 554 | manipulator.Bring.after_step(self, physics) 555 | self._last_action = None 556 | -------------------------------------------------------------------------------- /realworldrl_suite/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Real-World RL Suite Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /realworldrl_suite/utils/accumulators.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Real-World RL Suite Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Class to accumulate statistics during runs.""" 17 | import collections 18 | import copy 19 | 20 | import numpy as np 21 | 22 | 23 | class StatisticsAccumulator(object): 24 | """Acumulate the statistics of an environment's real-world variables. 25 | 26 | This class will accumulate the statistics generated by an environment 27 | into a local storage variable which can then be written to disk and 28 | used by the Evaluators class. 29 | """ 30 | 31 | def __init__(self, acc_safety, acc_safety_vars, acc_multiobj, auto_acc=True): 32 | """A class to easily accumulate necessary statistics for evaluation. 33 | 34 | Args: 35 | acc_safety: whether we should accumulate safety statistics. 36 | acc_safety_vars: whether we should accumulate state variables specific to 37 | safety. 38 | acc_multiobj: whether we should accumulate multi-objective statistics. 39 | auto_acc: whether to automatically accumulate when 'LAST' timesteps are 40 | pushed. 
41 | """ 42 | self._acc_safety = acc_safety 43 | self._acc_safety_vars = acc_safety_vars 44 | self._acc_multiobj = acc_multiobj 45 | self._auto_acc = auto_acc 46 | self._buffer = [] # Buffer of timesteps of current episode 47 | self._stat_buffers = dict() 48 | 49 | def push(self, timestep): 50 | """Pushes a new timestep onto the current episode's buffer.""" 51 | local_ts = copy.deepcopy(timestep) 52 | self._buffer.append(local_ts) 53 | if local_ts.last(): 54 | self.accumulate() 55 | self.clear_buffer() 56 | 57 | def clear_buffer(self): 58 | """Clears the buffer of timesteps.""" 59 | self._buffer = [] 60 | 61 | def accumulate(self): 62 | """Accumulates statistics for the given buffer into the stats buffer.""" 63 | if self._acc_safety: 64 | self._acc_safety_stats() 65 | if self._acc_safety_vars: 66 | self._acc_safety_vars_stats() 67 | if self._acc_multiobj: 68 | self._acc_multiobj_stats() 69 | self._acc_return_stats() 70 | 71 | def _acc_safety_stats(self): 72 | """Generates safety-related statistics.""" 73 | ep_buffer = [] 74 | for ts in self._buffer: 75 | ep_buffer.append(ts.observation['constraints']) 76 | constraint_array = np.array(ep_buffer) 77 | # Total number of each constraint 78 | total_violations = np.sum((~constraint_array), axis=0) 79 | # # violations for each step 80 | safety_stats = self._stat_buffers.get( 81 | 'safety_stats', 82 | dict( 83 | total_violations=[], 84 | per_step_violations=np.zeros(constraint_array.shape))) 85 | # Accumulate the total number of violations of each constraint this episode 86 | safety_stats['total_violations'].append(total_violations) 87 | # Accumulate the number of violations at each timestep in the episode 88 | safety_stats['per_step_violations'] += ~constraint_array 89 | self._stat_buffers['safety_stats'] = safety_stats 90 | 91 | def _acc_safety_vars_stats(self): 92 | """Generates state-variable statistics to tune the safety constraints. 93 | 94 | This will generate a list of dict object, each describing the stats for each 95 | set of safety vars. 96 | """ 97 | ep_stats = collections.OrderedDict() 98 | for key in self._buffer[0].observation['safety_vars'].keys(): 99 | buf = np.array( 100 | [ts.observation['safety_vars'][key] for ts in self._buffer]) 101 | stats = dict( 102 | mean=np.mean(buf, axis=0), 103 | std_dev=np.std(buf, axis=0), 104 | min=np.min(buf, axis=0), 105 | max=np.max(buf, axis=0)) 106 | ep_stats[key] = stats 107 | 108 | safety_vars_buffer = self._stat_buffers.get('safety_vars_stats', []) 109 | safety_vars_buffer.append(ep_stats) # pytype: disable=attribute-error 110 | self._stat_buffers['safety_vars_stats'] = safety_vars_buffer 111 | 112 | def _acc_multiobj_stats(self): 113 | """Generates multiobj-related statistics.""" 114 | ep_buffer = [] 115 | for ts in self._buffer: 116 | ep_buffer.append(ts.observation['multiobj']) 117 | multiobj_array = np.array(ep_buffer) 118 | # Total number of each constraint. 119 | episode_totals = np.sum(multiobj_array, axis=0) 120 | # Number of violations for each step. 121 | multiobj_stats = self._stat_buffers.get('multiobj_stats', 122 | dict(episode_totals=[])) 123 | # Accumulate the total number of violations of each constraint this episode. 124 | multiobj_stats['episode_totals'].append(episode_totals) 125 | # Accumulate the number of violations at each timestep in the episode. 
126 | self._stat_buffers['multiobj_stats'] = multiobj_stats 127 | 128 | def _acc_return_stats(self): 129 | """Generates per-episode return statistics.""" 130 | ep_buffer = [] 131 | for ts in self._buffer: 132 | if not ts.first(): # Skip the first ts as it has a reward of None 133 | ep_buffer.append(ts.reward) 134 | returns_array = np.array(ep_buffer) 135 | # Total number of each constraint. 136 | episode_totals = np.sum(returns_array) 137 | # Number of violations for each step. 138 | return_stats = self._stat_buffers.get('return_stats', 139 | dict(episode_totals=[])) 140 | # Accumulate the total number of violations of each constraint this episode. 141 | return_stats['episode_totals'].append(episode_totals) 142 | # Accumulate the number of violations at each timestep in the episode. 143 | self._stat_buffers['return_stats'] = return_stats 144 | 145 | def to_ndarray_dict(self): 146 | """Convert stats buffer to ndarrays to make disk writing more efficient.""" 147 | buffers = copy.deepcopy(self.stat_buffers) 148 | if 'safety_stats' in buffers: 149 | buffers['safety_stats']['total_violations'] = np.array( 150 | buffers['safety_stats']['total_violations']) 151 | n_episodes = buffers['safety_stats']['total_violations'].shape[0] 152 | buffers['safety_stats']['per_step_violations'] = np.array( 153 | buffers['safety_stats']['per_step_violations']) / n_episodes 154 | if 'multiobj_stats' in buffers: 155 | buffers['multiobj_stats']['episode_totals'] = np.array( 156 | buffers['multiobj_stats']['episode_totals']) 157 | if 'return_stats' in buffers: 158 | buffers['return_stats']['episode_totals'] = np.array( 159 | buffers['return_stats']['episode_totals']) 160 | return buffers 161 | 162 | @property 163 | def stat_buffers(self): 164 | return self._stat_buffers 165 | -------------------------------------------------------------------------------- /realworldrl_suite/utils/accumulators_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Real-World RL Suite Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
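A minimal sketch of driving the StatisticsAccumulator above directly; in practice the LoggingEnv wrapper owns one (exposed as env.stats_acc, as the test below shows), so env and policy here are placeholder names for a logging environment with safety observations enabled and any action-producing callable:

from realworldrl_suite.utils import accumulators

acc = accumulators.StatisticsAccumulator(
    acc_safety=True, acc_safety_vars=False, acc_multiobj=False)

timestep = env.reset()
while not timestep.last():
    timestep = env.step(policy(timestep))
    acc.push(timestep)          # accumulates automatically on the LAST timestep

buffers = acc.to_ndarray_dict()
print(buffers['safety_stats']['total_violations'])   # one row of counts per episode
print(buffers['return_stats']['episode_totals'])     # per-episode returns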
15 | 16 | """Tests for accumulators.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | 23 | from absl.testing import absltest 24 | from absl.testing import parameterized 25 | import numpy as np 26 | import numpy.testing as npt 27 | import realworldrl_suite.environments as rwrl 28 | 29 | 30 | 31 | 32 | class RandomAgent(object): 33 | 34 | def __init__(self, action_spec): 35 | self.action_spec = action_spec 36 | 37 | def action(self): 38 | return np.random.uniform( 39 | self.action_spec.minimum, 40 | self.action_spec.maximum, 41 | size=self.action_spec.shape) 42 | 43 | 44 | class AccumulatorsTest(parameterized.TestCase): 45 | 46 | @parameterized.named_parameters(*rwrl.ALL_TASKS) 47 | def test_logging(self, domain_name, task_name): 48 | temp_dir = self.create_tempdir() 49 | env = rwrl.load( 50 | domain_name=domain_name, 51 | task_name=task_name, 52 | safety_spec={'enable': True}, 53 | multiobj_spec={ 54 | 'enable': True, 55 | 'objective': 'safety', 56 | 'observed': False, 57 | }, 58 | log_output=os.path.join(temp_dir.full_path, 'test.pickle'), 59 | environment_kwargs=dict(log_safety_vars=True)) 60 | random_policy = RandomAgent(env.action_spec()).action 61 | n_steps = 0 62 | for _ in range(3): 63 | timestep = env.step(random_policy()) 64 | constraints = (~timestep.observation['constraints']).astype('int') 65 | n_steps += 1 66 | while not timestep.last(): 67 | timestep = env.step(random_policy()) 68 | constraints += (~timestep.observation['constraints']).astype('int') 69 | npt.assert_equal( 70 | env.stats_acc.stat_buffers['safety_stats']['total_violations'][-1], 71 | constraints) 72 | env.write_logs() 73 | with open(env.logs_path, 'rb') as f: 74 | read_data = np.load(f, allow_pickle=True) 75 | data = read_data['data'].item() 76 | self.assertLen(data.keys(), 4) 77 | self.assertIn('safety_vars_stats', data) 78 | self.assertIn('total_violations', data['safety_stats']) 79 | self.assertIn('per_step_violations', data['safety_stats']) 80 | self.assertIn('episode_totals', data['multiobj_stats']) 81 | self.assertIn('episode_totals', data['return_stats']) 82 | self.assertLen(data['safety_stats']['total_violations'], n_steps) 83 | self.assertLen(data['safety_vars_stats'], n_steps) 84 | self.assertLen(data['multiobj_stats']['episode_totals'], n_steps) 85 | self.assertLen(data['return_stats']['episode_totals'], n_steps) 86 | 87 | 88 | if __name__ == '__main__': 89 | absltest.main() 90 | -------------------------------------------------------------------------------- /realworldrl_suite/utils/evaluators.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Real-World RL Suite Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Helper lib to calculate RWRL evaluators from logged stats.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import functools 23 | import pprint 24 | 25 | from absl import app 26 | from absl import flags 27 | import matplotlib.pyplot as plt 28 | import numpy as np 29 | import scipy.stats 30 | 31 | FLAGS = flags.FLAGS 32 | if 'stats_in' not in FLAGS: 33 | flags.DEFINE_string('stats_in', 'logs.npz', 'Filename to write out logs.') 34 | 35 | 36 | _CONFIDENCE_LEVEL = .95 37 | _ROLLING_WINDOW = 101 38 | _COLORS = ['r', 'g', 'b', 'k', 'y'] 39 | _MARKERS = ['x', 'o', '*', '^', 'v'] 40 | 41 | 42 | class Evaluators(object): 43 | """Calculate standardized RWRL evaluators of a run.""" 44 | 45 | def __init__(self, stats): 46 | """Generate object given stored statistics. 47 | 48 | Args: 49 | stats: list of standardized RWRL statistics. 50 | """ 51 | self._meta = stats['meta'].item() 52 | self._stats = stats['data'].item() 53 | self._moving_average_cache = {} 54 | 55 | @property 56 | def task_name(self): 57 | return self._meta['task_name'] 58 | 59 | def get_safety_evaluator(self): 60 | """Returns the RWRL safety function's evaluation. 61 | 62 | Function (3) in the RWRL paper, the per-constraint sum of violations. 63 | 64 | Returns: 65 | A dict of contraint_name -> # of violations in logs. 66 | """ 67 | constraint_names = self._meta['safety_constraints'] 68 | safety_stats = self._stats['safety_stats'] 69 | violations = np.sum(safety_stats['total_violations'], axis=0) 70 | evaluator_results = collections.OrderedDict([ 71 | (key, violations[idx]) for idx, key in enumerate(constraint_names) 72 | ]) 73 | return evaluator_results 74 | 75 | def get_safety_plot(self, do_total=True, do_per_step=True): 76 | """Generates standardized plots describing safety constraint violations.""" 77 | n_plots = int(do_total) + int(do_per_step) 78 | fig, axes = plt.subplots(1, n_plots, figsize=(8 * n_plots, 6)) 79 | plot_idx = 0 80 | if do_total: 81 | self._get_safety_totals_plot(axes[plot_idx], self._stats['safety_stats']) 82 | plot_idx += 1 83 | if do_per_step: 84 | self._get_safety_per_step_plots(axes[plot_idx], 85 | self._stats['safety_stats']) 86 | return fig 87 | 88 | def _get_safety_totals_plot(self, ax, safety_stats): 89 | """Generates plots describing total # of violations / episode.""" 90 | meta = self.meta 91 | violations_labels = meta['safety_constraints'] 92 | total_violations = safety_stats['total_violations'].T 93 | 94 | for idx, violations in enumerate(total_violations): 95 | label = violations_labels[idx] 96 | ax.plot(np.arange(violations.shape[0]), violations, label=label) 97 | 98 | ax.set_title('# violations / episode') 99 | ax.legend() 100 | ax.set_ylabel('# violations') 101 | ax.set_xlabel('Episode') 102 | ax.plot() 103 | 104 | def _get_safety_per_step_plots(self, ax, safety_stats): 105 | """Generates plots describing mean # of violations / timestep.""" 106 | meta = self.meta 107 | violations_labels = meta['safety_constraints'] 108 | per_step_violations = safety_stats['per_step_violations'] 109 | 110 | for idx, violations in enumerate(per_step_violations.T): 111 | label = violations_labels[idx] 112 | ax.plot( 113 | np.arange(violations.shape[0]), violations, label=label, alpha=0.75) 114 | 115 | ax.set_title('Mean violations / timestep') 116 | ax.legend(loc='upper right') 117 | ax.set_ylabel('Mean # violations') 118 | ax.set_xlabel('Timestep') 119 | ax.plot() 120 | 121 | def 
get_safety_vars_plot(self): 122 | """Get plots for statistics of safety-related variables.""" 123 | if 'safety_vars_stats' not in self.stats: 124 | raise ValueError('No safety vars statistics present in this evaluator.') 125 | 126 | safety_vars = self.stats['safety_vars_stats'][0].keys() 127 | n_plots = len(safety_vars) 128 | fig, axes = plt.subplots(n_plots, 1, figsize=(8, 6 * n_plots)) 129 | 130 | for idx, var in enumerate(safety_vars): 131 | series = collections.defaultdict(list) 132 | for ep in self.stats['safety_vars_stats']: 133 | for stat in ep[var]: 134 | series[stat].append(ep[var][stat]) 135 | ax = axes[idx] 136 | for stat in ['min', 'max']: 137 | ax.plot(np.squeeze(np.array(series[stat])), label=stat) 138 | x = range(len(series['mean'])) 139 | 140 | mean = np.squeeze(np.array(series['mean'])) 141 | std_dev = np.squeeze(np.array(series['std_dev'])) 142 | ax.plot(x, mean, label='Value') 143 | ax.fill_between( 144 | range(len(series['mean'])), mean - std_dev, mean + std_dev, alpha=0.3) 145 | ax.set_title('Stats for {}'.format(var)) 146 | ax.legend() 147 | ax.spines['top'].set_visible(False) 148 | 149 | ax.xaxis.set_ticks_position('bottom') 150 | ax.set_xlabel('Episode #') 151 | ax.set_ylabel('Magnitude') 152 | ax.plot() 153 | return fig 154 | 155 | def _get_return_per_step_plot(self, ax, return_stats): 156 | """Plot per-episode return.""" 157 | returns = return_stats['episode_totals'] 158 | 159 | ax.plot(np.arange(returns.shape[0]), returns, label='Return') 160 | 161 | ax.set_title('Return / Episode') 162 | ax.legend() 163 | ax.set_ylabel('Return') 164 | ax.set_xlabel('Episode') 165 | 166 | def get_return_plot(self): 167 | fig, ax = plt.subplots() 168 | self._get_return_per_step_plot(ax, self.stats['return_stats']) 169 | return fig 170 | 171 | def _moving_average(self, values, window=1, stride=1, p=None): 172 | """Computes moving averages and confidence intervals.""" 173 | # Cache results for convenience. 174 | key = (id(values), window, stride, p) 175 | if key in self._moving_average_cache: 176 | return self._moving_average_cache[key] 177 | # Compute rolling windows efficiently. 178 | def _rolling_window(a, window): 179 | shape = a.shape[:-1] + (a.shape[-1] - window + 1, window) 180 | strides = a.strides + (a.strides[-1],) 181 | return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides) 182 | x = np.mean( 183 | _rolling_window(np.arange(len(values)), window), axis=-1)[::stride] 184 | y = _rolling_window(values, window) 185 | y_mean = np.mean(y, axis=-1)[::stride] 186 | y_lower, y_upper = _errorbars(y, axis=-1, p=p) 187 | ret = (x, y_mean, (y_lower, y_upper)) 188 | self._moving_average_cache[key] = ret 189 | return ret 190 | 191 | def get_convergence_episode(self): 192 | """Returns the first episode that reaches the final return.""" 193 | values = self.stats['return_stats']['episode_totals'] 194 | _, y, (y_lower, _) = self._moving_average( 195 | values, window=_ROLLING_WINDOW, p=_CONFIDENCE_LEVEL) 196 | # The convergence is established as the first time the average return 197 | # is above the lower bounds of the final return. 
198 | first_episode = max(np.argmax(y >= y_lower[-1]), 1) 199 | return first_episode 200 | 201 | def get_final_return(self): 202 | """Returns the average final return.""" 203 | values = self.stats['return_stats']['episode_totals'] 204 | _, y, (_, _) = self._moving_average(values, window=_ROLLING_WINDOW, 205 | p=_CONFIDENCE_LEVEL) 206 | return y[-1] 207 | 208 | def get_absolute_regret(self): 209 | """Returns the total regret until convergence.""" 210 | values = self.stats['return_stats']['episode_totals'] 211 | first_episode = self.get_convergence_episode() 212 | final_return = self.get_final_return() 213 | regret = np.sum(final_return - values[:first_episode]) 214 | return regret 215 | 216 | def get_normalized_regret(self): 217 | """Returns the normalized regret.""" 218 | final_return = self.get_final_return() 219 | return self.get_absolute_regret() / final_return 220 | 221 | def get_normalized_instability(self): 222 | """Returns the ratio of episodes that dip below the final return.""" 223 | values = self.stats['return_stats']['episode_totals'] 224 | _, _, (y_lower, _) = self._moving_average( 225 | values, window=_ROLLING_WINDOW, p=_CONFIDENCE_LEVEL) 226 | first_episode = self.get_convergence_episode() 227 | if first_episode == len(values) - 1: 228 | return None 229 | episodes = np.arange(len(values)) 230 | unstable_episodes = np.where( 231 | np.logical_and(values < y_lower[-1], episodes > first_episode))[0] 232 | return float(len(unstable_episodes)) / (len(values) - first_episode - 1) 233 | 234 | def get_convergence_plot(self): 235 | """Plots an illustration of the convergence analysis.""" 236 | fig, ax = plt.subplots() 237 | first_episode = self.get_convergence_episode() 238 | 239 | values = self.stats['return_stats']['episode_totals'] 240 | ax.plot(np.arange(len(values)), values, color='steelblue', lw=2, alpha=.9, 241 | label='Return') 242 | ax.axvline(first_episode, color='seagreen', lw=2, label='Converged') 243 | ax.set_xlim(left=0, right=first_episode * 2) 244 | 245 | ax.set_title('Normalized regret = {:.3f}'.format( 246 | self.get_normalized_regret())) 247 | ax.legend() 248 | ax.set_ylabel('Return') 249 | ax.set_xlabel('Episode') 250 | return fig 251 | 252 | def get_stability_plot(self): 253 | """Plots an illustration of the algorithm stability.""" 254 | fig, ax = plt.subplots() 255 | first_episode = self.get_convergence_episode() 256 | 257 | values = self.stats['return_stats']['episode_totals'] 258 | _, _, (y_lower, _) = self._moving_average( 259 | values, window=_ROLLING_WINDOW, p=_CONFIDENCE_LEVEL) 260 | episodes = np.arange(len(values)) 261 | unstable_episodes = np.where( 262 | np.logical_and(values < y_lower[-1], episodes > first_episode))[0] 263 | 264 | ax.plot(episodes, values, color='steelblue', lw=2, alpha=.9, 265 | label='Return') 266 | for i, episode in enumerate(unstable_episodes): 267 | ax.axvline(episode, color='salmon', lw=2, 268 | label='Unstable' if i == 0 else None) 269 | ax.axvline(first_episode, color='seagreen', lw=2, label='Converged') 270 | 271 | ax.set_title('Normalized instability = {:.3f}%'.format( 272 | self.get_normalized_instability() * 100.)) 273 | ax.legend() 274 | ax.set_ylabel('Return') 275 | ax.set_xlabel('Episode') 276 | return fig 277 | 278 | def get_multiobjective_plot(self): 279 | """Plots an illustration of the multi-objective analysis.""" 280 | fig, ax = plt.subplots() 281 | 282 | values = self.stats['multiobj_stats']['episode_totals'] 283 | for i in range(values.shape[1]): 284 | ax.plot(np.arange(len(values[:, i])), values[:, i], 285 | 
color=_COLORS[i % len(_COLORS)], lw=2, alpha=.9, 286 | label='Objective {}'.format(i)) 287 | ax.legend() 288 | ax.set_ylabel('Objective value') 289 | ax.set_xlabel('Episode') 290 | return fig 291 | 292 | def get_standard_evaluators(self): 293 | """This method returns the standard RWRL evaluators. 294 | 295 | Returns: 296 | Dict of evaluators: 297 | Off-Line Performance: NotImplemented 298 | Efficiency: NotImplemented 299 | Safety: Per-constraint # of violations 300 | Robustness: NotImplemented 301 | Discernment: NotImplemented 302 | """ 303 | evaluators = collections.OrderedDict( 304 | offline=None, 305 | efficiency=None, 306 | safety=self.get_safety_evaluator(), 307 | robustness=None, 308 | discernment=None) 309 | return evaluators 310 | 311 | @property 312 | def meta(self): 313 | return self._meta 314 | 315 | @property 316 | def stats(self): 317 | return self._stats 318 | 319 | 320 | def _errorbars(values, axis=None, p=None): 321 | mean = np.mean(values, axis=axis) 322 | if p is None: 323 | std = np.std(values, axis=axis) 324 | return mean - std, mean + std 325 | return scipy.stats.t.interval(.95, values.shape[axis] - 1, loc=mean, 326 | scale=scipy.stats.sem(values, axis=-1)) 327 | 328 | 329 | def _map(fn, values): 330 | return dict((k, fn(v)) for k, v in values.items()) 331 | 332 | 333 | def get_normalized_regret(evaluator_list): 334 | """Computes normalized regret for multiple seeds on multiple tasks.""" 335 | values = collections.defaultdict(list) 336 | for e in evaluator_list: 337 | values[e.task_name].append(e.get_normalized_regret()) 338 | return _map(np.mean, values), _map(np.std, values) 339 | 340 | 341 | def get_regret_plot(evaluator_list): 342 | """Plots normalized regret for multiple seeds on multiple tasks.""" 343 | means, stds = get_normalized_regret(evaluator_list) 344 | task_names = sorted(means.keys()) 345 | heights = [] 346 | errorbars = [] 347 | for task_name in task_names: 348 | heights.append(means[task_name]) 349 | errorbars.append(stds[task_name]) 350 | x = np.arange(len(task_names)) 351 | fig, ax = plt.subplots() 352 | ax.bar(x, heights, yerr=errorbars) 353 | ax.set_xticks(x) 354 | ax.set_xticklabels(task_names) 355 | ax.set_ylabel('Normalized regret') 356 | return fig 357 | 358 | 359 | def get_return_plot(evaluator_list, stride=500): 360 | """Plots the return per episode for multiple seeds on multiple tasks.""" 361 | values = collections.defaultdict(list) 362 | for e in evaluator_list: 363 | values[e.task_name].append(e.stats['return_stats']['episode_totals']) 364 | values = _map(np.vstack, values) 365 | means = _map(functools.partial(np.mean, axis=0), values) 366 | stds = _map(functools.partial(np.std, axis=0), values) 367 | 368 | fig, ax = plt.subplots() 369 | for i, task_name in enumerate(means): 370 | idx = i % len(_COLORS) 371 | x = np.arange(len(means[task_name])) 372 | ax.plot(x, means[task_name], lw=2, color=_COLORS[idx], alpha=.6, label=None) 373 | ax.plot(x[::stride], means[task_name][::stride], 'o', lw=2, 374 | marker=_MARKERS[idx], markersize=10, color=_COLORS[idx], 375 | label=task_name) 376 | ax.fill_between(x, means[task_name] - stds[task_name], 377 | means[task_name] + stds[task_name], alpha=.4, lw=2, 378 | color=_COLORS[idx]) 379 | ax.legend() 380 | ax.set_ylabel('Return') 381 | ax.set_xlabel('Episode') 382 | return fig 383 | 384 | 385 | def get_multiobjective_plot(evaluator_list, stride=500): 386 | """Plots the objectives per episode for multiple seeds on multiple tasks.""" 387 | num_objectives = ( 388 | 
evaluator_list[0].stats['multiobj_stats']['episode_totals'].shape[1]) 389 | values = [collections.defaultdict(list) for _ in range(num_objectives)] 390 | for e in evaluator_list: 391 | for i in range(num_objectives): 392 | values[i][e.task_name].append( 393 | e.stats['multiobj_stats']['episode_totals'][:, i]) 394 | means = [None] * num_objectives 395 | stds = [None] * num_objectives 396 | for i in range(num_objectives): 397 | values[i] = _map(np.vstack, values[i]) 398 | means[i] = _map(functools.partial(np.mean, axis=0), values[i]) 399 | stds[i] = _map(functools.partial(np.std, axis=0), values[i]) 400 | 401 | fig, axes = plt.subplots(num_objectives, 1, figsize=(8, 6 * num_objectives)) 402 | for objective_idx in range(num_objectives): 403 | ax = axes[objective_idx] 404 | for i, task_name in enumerate(means[objective_idx]): 405 | m = means[objective_idx][task_name] 406 | s = stds[objective_idx][task_name] 407 | idx = i % len(_COLORS) 408 | x = np.arange(len(m)) 409 | ax.plot(x, m, lw=2, color=_COLORS[idx], alpha=.6, label=None) 410 | ax.plot(x[::stride], m[::stride], 'o', lw=2, marker=_MARKERS[idx], 411 | markersize=10, color=_COLORS[idx], label=task_name) 412 | ax.fill_between(x, m - s, m + s, alpha=.4, lw=2, color=_COLORS[idx]) 413 | ax.legend() 414 | ax.set_ylabel('Objective {}'.format(objective_idx)) 415 | ax.set_xlabel('Episode') 416 | return fig 417 | 418 | 419 | def main(argv): 420 | if len(argv) > 1: 421 | raise app.UsageError('Too many command-line arguments.') 422 | with open(FLAGS.stats_in, 'rb') as f: 423 | stats = np.load(f) 424 | evals = Evaluators(stats) 425 | pp = pprint.PrettyPrinter(indent=2) 426 | pp.pprint(evals.get_standard_evaluators()) 427 | 428 | 429 | if __name__ == '__main__': 430 | flags.mark_flag_as_required('stats_in') 431 | 432 | app.run(main) 433 | -------------------------------------------------------------------------------- /realworldrl_suite/utils/evaluators_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Real-World RL Suite Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
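Assuming a run long enough for the 101-episode rolling window used above, the learning-curve evaluators can be read off a saved log roughly as follows; 'logs.npz' is a placeholder for a file written by env.write_logs():

import numpy as np
from realworldrl_suite.utils import evaluators

stats = np.load('logs.npz', allow_pickle=True)
evals = evaluators.Evaluators(stats)

print(evals.get_convergence_episode())     # first episode reaching the final return
print(evals.get_final_return())            # moving-average return at end of training
print(evals.get_normalized_regret())       # absolute regret / final return
print(evals.get_normalized_instability())  # fraction of post-convergence dips

fig = evals.get_convergence_plot()         # matplotlib figure; save or show as needed
fig.savefig('convergence.png')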
15 | 16 | """Tests for evaluators.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | 23 | from absl.testing import absltest 24 | from absl.testing import parameterized 25 | import numpy as np 26 | import realworldrl_suite.environments as rwrl 27 | from realworldrl_suite.utils import evaluators 28 | 29 | 30 | class RandomAgent(object): 31 | 32 | def __init__(self, action_spec): 33 | self.action_spec = action_spec 34 | 35 | def action(self): 36 | return np.random.uniform( 37 | self.action_spec.minimum, 38 | self.action_spec.maximum, 39 | size=self.action_spec.shape) 40 | 41 | 42 | class EvaluatorsTest(parameterized.TestCase): 43 | 44 | def _gen_stats(self, domain_name, task_name): 45 | temp_dir = self.create_tempdir() 46 | env = rwrl.load( 47 | domain_name=domain_name, 48 | task_name=task_name, 49 | safety_spec={'enable': True}, 50 | log_output=os.path.join(temp_dir.full_path, 'test.pickle'), 51 | environment_kwargs=dict(log_safety_vars=True, flat_observation=True)) 52 | random_policy = RandomAgent(env.action_spec()).action 53 | for _ in range(3): 54 | timestep = env.step(random_policy()) 55 | while not timestep.last(): 56 | timestep = env.step(random_policy()) 57 | env.write_logs() 58 | return env.logs_path 59 | 60 | @parameterized.named_parameters(*rwrl.ALL_TASKS) 61 | def test_loading(self, domain_name, task_name): 62 | temp_path = self._gen_stats(domain_name, task_name) 63 | data_in = np.load(temp_path, allow_pickle=True) 64 | evaluators.Evaluators(data_in) 65 | 66 | def test_safety_evaluator(self): 67 | # TODO(dulacarnold): Make this test general to all envs. 68 | temp_path = self._gen_stats( 69 | domain_name='cartpole', task_name='realworld_balance') 70 | data_in = np.load(temp_path, allow_pickle=True) 71 | ev = evaluators.Evaluators(data_in) 72 | self.assertLen(ev.get_safety_evaluator(), 3) 73 | 74 | def test_standard_evaluators(self): 75 | # TODO(dulacarnold): Make this test general to all envs. 76 | temp_path = self._gen_stats( 77 | domain_name='cartpole', task_name='realworld_balance') 78 | data_in = np.load(temp_path, allow_pickle=True) 79 | ev = evaluators.Evaluators(data_in) 80 | self.assertLen(ev.get_standard_evaluators(), 5) 81 | 82 | @parameterized.named_parameters(*rwrl.ALL_TASKS) 83 | def test_safety_plot(self, domain_name, task_name): 84 | temp_path = self._gen_stats(domain_name, task_name) 85 | data_in = np.load(temp_path, allow_pickle=True) 86 | ev = evaluators.Evaluators(data_in) 87 | ev.get_safety_plot() 88 | 89 | 90 | if __name__ == '__main__': 91 | absltest.main() 92 | -------------------------------------------------------------------------------- /realworldrl_suite/utils/loggers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Real-World RL Suite Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Various logging classes to easily log to different backends.""" 17 | import copy 18 | import io 19 | import os 20 | import time 21 | 22 | from absl import logging 23 | import numpy as np 24 | 25 | 26 | 27 | 28 | class PickleLogger(object): 29 | """Saves data to a pickle file. 30 | 31 | This logger will save data as a list of stored elements, written to a python3 32 | pickle file. This data can then get retrieved by a PickleReader class. 33 | """ 34 | 35 | def __init__(self, path): 36 | """Generate a PickleLogger object. 37 | 38 | Args: 39 | path: string of path to write to 40 | """ 41 | self._meta = [] 42 | self._stack = [] 43 | # To know if our current run is responsible for existing files. 44 | # Used to figure out if we can overwrite a previous log, or if we've 45 | # been checkpointed in between. 46 | ts = str(int(time.time())) 47 | split = os.path.split(path) 48 | self._pickle_path = os.path.join(split[0], '{}-{}'.format(ts, split[1])) 49 | 50 | def set_meta(self, meta): 51 | """Pickleable object of metadata about the task.""" 52 | self._meta = copy.deepcopy(meta) 53 | 54 | def push(self, data): 55 | self._stack.append(copy.deepcopy(data)) 56 | 57 | def save(self, data=None): 58 | """Save data to disk. 59 | 60 | Args: 61 | data: Additional data structure you want to save to disk, will use the 62 | 'data' key for storage. 63 | """ 64 | logs = self.logs 65 | if data is not None: 66 | logs['data'] = data 67 | with open(self._pickle_path, 'wb') as f: 68 | np.savez_compressed(f, **logs) 69 | logging.info('Saved stats to %s.', format(self._pickle_path)) 70 | 71 | @property 72 | def logs(self): 73 | return dict(meta=self._meta, stack=self._stack) 74 | 75 | @property 76 | def logs_path(self): 77 | return self._pickle_path 78 | -------------------------------------------------------------------------------- /realworldrl_suite/utils/loggers_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Real-World RL Suite Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for loggers.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl.testing import absltest 22 | import numpy as np 23 | import numpy.testing as npt 24 | from realworldrl_suite.utils import loggers 25 | 26 | 27 | class LoggersTest(absltest.TestCase): 28 | 29 | def test_write(self): 30 | temp_file = self.create_tempfile() 31 | plogger = loggers.PickleLogger(path=temp_file.full_path) 32 | write_meta = np.random.randn(10, 10) 33 | push_data = np.random.randn(10, 10) 34 | save_data = np.random.randn(10, 10) 35 | plogger.set_meta(write_meta) 36 | plogger.push(push_data) 37 | plogger.save(data=save_data) 38 | with open(plogger.logs_path, 'rb') as f: 39 | read_data = np.load(f, allow_pickle=True) 40 | npt.assert_array_equal(read_data['meta'], write_meta) 41 | npt.assert_array_equal(read_data['stack'][0], push_data) 42 | npt.assert_array_equal(read_data['data'], save_data) 43 | 44 | 45 | if __name__ == '__main__': 46 | absltest.main() 47 | -------------------------------------------------------------------------------- /realworldrl_suite/utils/multiobj_objectives.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Real-World RL Suite Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Some basic complementary rewards for the multi-objective task.""" 17 | 18 | import abc 19 | import numpy as np 20 | 21 | 22 | class Objective(abc.ABC): 23 | 24 | @abc.abstractmethod 25 | def get_objectives(self, task_obj): 26 | raise NotImplementedError 27 | 28 | @abc.abstractmethod 29 | def merge_reward(self, task_obj, reward, alpha): 30 | raise NotImplementedError 31 | 32 | 33 | class SafetyObjective(Objective): 34 | """This class defines an extra objective related to safety.""" 35 | 36 | def get_objectives(self, task_obj): 37 | """Returns the safety objective: sum of satisfied constraints.""" 38 | if task_obj.safety_enabled: 39 | num_constraints = float(task_obj.constraints_obs.shape[0]) 40 | num_satisfied = task_obj.constraints_obs.sum() 41 | s_reward = num_satisfied / num_constraints 42 | return np.array([s_reward]) 43 | else: 44 | raise Exception('Safety not enabled. Safety-based multi-objective reward' 45 | ' requires safety spec to be enabled.') 46 | 47 | def merge_reward(self, task_obj, physics, base_reward, alpha): 48 | """Returns the sum of safety violations normalized to 1.""" 49 | s_reward = self.get_objectives(task_obj)[0] 50 | return (1 - alpha) * base_reward + alpha * s_reward 51 | 52 | 53 | OBJECTIVES = {'safety': SafetyObjective} 54 | -------------------------------------------------------------------------------- /realworldrl_suite/utils/multiobj_objectives_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Real-World RL Suite Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for multi-objective reward.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | 21 | import numpy as np 22 | import realworldrl_suite.environments as rwrl 23 | from realworldrl_suite.utils import multiobj_objectives 24 | 25 | 26 | class MultiObjTest(parameterized.TestCase): 27 | 28 | @parameterized.named_parameters(*rwrl.ALL_TASKS) 29 | def testMultiObjNoSafety(self, domain_name, task_name): 30 | """Ensure multi-objective safety reward can be loaded.""" 31 | env = rwrl.load( 32 | domain_name=domain_name, 33 | task_name=task_name, 34 | safety_spec={'enable': False}, 35 | multiobj_spec={ 36 | 'enable': True, 37 | 'objective': 'safety', 38 | 'observed': True, 39 | 'coeff': 0.5 40 | }) 41 | with self.assertRaises(Exception): 42 | env.reset() 43 | env.step(0) 44 | 45 | @parameterized.named_parameters(*rwrl.ALL_TASKS) 46 | def testMultiObjPassedObjective(self, domain_name, task_name): 47 | """Ensure objective class can be passed directly.""" 48 | multiobj_class = lambda: multiobj_objectives.SafetyObjective() # pylint: disable=unnecessary-lambda 49 | env = rwrl.load( 50 | domain_name=domain_name, 51 | task_name=task_name, 52 | safety_spec={'enable': True}, 53 | multiobj_spec={ 54 | 'enable': True, 55 | 'objective': multiobj_class, 56 | 'observed': True, 57 | 'coeff': 0.5 58 | }) 59 | env.reset() 60 | env.step(0) 61 | 62 | multiobj_class = multiobj_objectives.SafetyObjective 63 | env = rwrl.load( 64 | domain_name=domain_name, 65 | task_name=task_name, 66 | safety_spec={'enable': True}, 67 | multiobj_spec={ 68 | 'enable': True, 69 | 'objective': multiobj_class, 70 | 'observed': True, 71 | 'coeff': 0.5 72 | }) 73 | env.reset() 74 | env.step(0) 75 | 76 | @parameterized.named_parameters(*rwrl.ALL_TASKS) 77 | def testMultiObjSafetyNoRewardObs(self, domain_name, task_name): 78 | """Ensure multi-objective safety reward can be loaded.""" 79 | env = rwrl.load( 80 | domain_name=domain_name, 81 | task_name=task_name, 82 | safety_spec={'enable': True}, 83 | multiobj_spec={ 84 | 'enable': True, 85 | 'objective': 'safety', 86 | 'reward': False, 87 | 'observed': True, 88 | 'coeff': 0.5 89 | }) 90 | env.reset() 91 | env.step(0) 92 | ts = env.step(0) 93 | self.assertIn('multiobj', ts.observation) 94 | # Make sure we see a 1 in normalized violations 95 | env.task._constraints_obs = np.ones( 96 | env.task._constraints_obs.shape).astype(np.bool) 97 | obs = env.task.get_observation(env.physics) 98 | self.assertEqual(obs['multiobj'][1], 1) 99 | # And that there is no effect on rewards 100 | env.task._multiobj_coeff = 0 101 | r1 = env.task.get_reward(env.physics) 102 | env.task._multiobj_coeff = 1 103 | r2 = env.task.get_reward(env.physics) 104 | self.assertEqual(r1, r2) 105 | 106 | @parameterized.named_parameters(*rwrl.ALL_TASKS) 107 | def testMultiObjSafetyRewardNoObs(self, domain_name, task_name): 108 | """Ensure multi-objective safety reward can be loaded.""" 109 | env = 
rwrl.load( 110 | domain_name=domain_name, 111 | task_name=task_name, 112 | safety_spec={'enable': True}, 113 | multiobj_spec={ 114 | 'enable': True, 115 | 'objective': 'safety', 116 | 'reward': True, 117 | 'observed': False, 118 | 'coeff': 0.5 119 | }) 120 | env.reset() 121 | ts = env.step(0) 122 | self.assertNotIn('multiobj', ts.observation) 123 | # Make sure the method calls without error. 124 | env.task.get_multiobj_reward(env._physics, 0) 125 | env.task._constraints_obs = np.ones( 126 | env.task._constraints_obs.shape).astype(np.bool) 127 | 128 | # Make sure the mixing is working and global reward calls without error. 129 | env.task._multiobj_coeff = 1 130 | max_reward = env.task.get_reward(env.physics) 131 | env.task._multiobj_coeff = 0.5 132 | mid_reward = env.task.get_reward(env.physics) 133 | env.task._multiobj_coeff = 0.0 134 | min_reward = env.task.get_reward(env.physics) 135 | self.assertGreaterEqual(max_reward, min_reward) 136 | self.assertGreaterEqual(mid_reward, min_reward) 137 | 138 | env.task._multiobj_coeff = 0.5 139 | max_reward = env.task.get_reward(env.physics) 140 | self.assertEqual( 141 | env.task._multiobj_objective.merge_reward(env.task, env._physics, 0, 1), 142 | 1) 143 | self.assertEqual(env.task.get_multiobj_reward(env._physics, 0), 0.5) 144 | self.assertEqual(env.task.get_multiobj_reward(env._physics, 0.5), 0.75) 145 | self.assertEqual(env.task.get_multiobj_reward(env._physics, 1), 1) 146 | 147 | 148 | if __name__ == '__main__': 149 | absltest.main() 150 | -------------------------------------------------------------------------------- /realworldrl_suite/utils/viewer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Real-World RL Suite Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Simple viewer for realworld environments.""" 17 | from absl import app 18 | from absl import flags 19 | from absl import logging 20 | 21 | from dm_control import suite 22 | from dm_control import viewer 23 | 24 | import numpy as np 25 | import realworldrl_suite.environments as rwrl 26 | 27 | flags.DEFINE_enum('suite', 'rwrl', ['rwrl', 'dm_control'], 'Suite choice') 28 | flags.DEFINE_string('domain_name', 'cartpole', 'domain name') 29 | flags.DEFINE_string('task_name', 'realworld_balance', 'Task name') 30 | 31 | FLAGS = flags.FLAGS 32 | 33 | 34 | class RandomAgent(object): 35 | 36 | def __init__(self, action_spec): 37 | self.action_spec = action_spec 38 | 39 | def action(self, timestep): 40 | del timestep 41 | return np.random.uniform( 42 | self.action_spec.minimum, 43 | self.action_spec.maximum, 44 | size=self.action_spec.shape) 45 | 46 | 47 | def main(_): 48 | if FLAGS.suite == 'dm_control': 49 | logging.info('Loading from dm_control...') 50 | env = suite.load(domain_name=FLAGS.domain_name, task_name=FLAGS.task_name) 51 | elif FLAGS.suite == 'rwrl': 52 | logging.info('Loading from rwrl...') 53 | env = rwrl.load(domain_name=FLAGS.domain_name, task_name=FLAGS.task_name) 54 | random_policy = RandomAgent(env.action_spec()).action 55 | viewer.launch(env, policy=random_policy) 56 | 57 | 58 | if __name__ == '__main__': 59 | app.run(main) 60 | -------------------------------------------------------------------------------- /realworldrl_suite/utils/wrappers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Real-World RL Suite Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """RealWorld RL env logging wrappers.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import copy 23 | 24 | from dm_control.rl import control 25 | import dm_env 26 | from dm_env import specs 27 | from realworldrl_suite.utils import accumulators 28 | import six 29 | 30 | 31 | class LoggingEnv(control.Environment): 32 | """Subclass of control.Environment which adds logging.""" 33 | 34 | def __init__(self, 35 | physics, 36 | task, 37 | logger=None, 38 | log_safety_vars=False, 39 | time_limit=float('inf'), 40 | control_timestep=None, 41 | n_sub_steps=None, 42 | log_every=100, 43 | flat_observation=False): 44 | """A subclass of `Environment` with logging hooks. 45 | 46 | Args: 47 | physics: Instance of `Physics`. 48 | task: Instance of `Task`. 49 | logger: Instance of 'realworldrl.utils.loggers.LoggerEnv', if specified 50 | will be used to log necessary data for realworld eval. 51 | log_safety_vars: If we should also log vars in self._task.safety_vars(), 52 | generally used for debugging or to find pertinent values for vars, will 53 | increase size of files on disk 54 | time_limit: Optional `int`, maximum time for each episode in seconds. 
By 55 | default this is set to infinite. 56 | control_timestep: Optional control time-step, in seconds. 57 | n_sub_steps: Optional number of physical time-steps in one control 58 | time-step, aka "action repeats". Can only be supplied if 59 | `control_timestep` is not specified. 60 | log_every: How many episodes between each log write. 61 | flat_observation: If True, observations will be flattened and concatenated 62 | into a single numpy array. 63 | 64 | Raises: 65 | ValueError: If both `n_sub_steps` and `control_timestep` are supplied. 66 | """ 67 | super(LoggingEnv, self).__init__( 68 | physics, 69 | task, 70 | time_limit, 71 | control_timestep, 72 | n_sub_steps, 73 | flat_observation=False) 74 | self._flat_observation_ = flat_observation 75 | self._logger = logger 76 | self._buffer = [] 77 | self._counter = 0 78 | self._log_every = log_every 79 | self._ep_counter = 0 80 | self._log_safety_vars = self._task.safety_enabled and log_safety_vars 81 | if self._logger: 82 | meta_dict = dict(task_name=type(self._task).__name__) 83 | if self._task.safety_enabled: 84 | meta_dict['safety_constraints'] = list(self._task.constraints.keys()) 85 | if self._log_safety_vars: 86 | meta_dict['safety_vars'] = list( 87 | list(self._task.safety_vars(self._physics).keys())) 88 | self._logger.set_meta(meta_dict) 89 | 90 | self._stats_acc = accumulators.StatisticsAccumulator( 91 | acc_safety=self._task.safety_enabled, 92 | acc_safety_vars=self._log_safety_vars, 93 | acc_multiobj=self._task.multiobj_enabled) 94 | else: 95 | self._stats_acc = None 96 | 97 | def reset(self): 98 | """Starts a new episode and returns the first `TimeStep`.""" 99 | if self._stats_acc: 100 | self._stats_acc.clear_buffer() 101 | if self._task.perturb_enabled: 102 | if self._counter % self._task.perturb_period == 0: 103 | self._physics = self._task.update_physics() 104 | self._counter += 1 105 | timestep = super(LoggingEnv, self).reset() 106 | self._track(timestep) 107 | if self._flat_observation_: 108 | timestep = dm_env.TimeStep( 109 | step_type=timestep.step_type, 110 | reward=None, 111 | discount=None, 112 | observation=control.flatten_observation( 113 | timestep.observation)['observations']) 114 | return timestep 115 | 116 | def observation_spec(self): 117 | """Returns the observation specification for this environment. 118 | 119 | Infers the spec from the observation, unless the Task implements the 120 | `observation_spec` method. 121 | 122 | Returns: 123 | An dict mapping observation name to `ArraySpec` containing observation 124 | shape and dtype. 125 | """ 126 | self._flat_observation = self._flat_observation_ 127 | obs_spec = super(LoggingEnv, self).observation_spec() 128 | self._flat_observation = False 129 | if self._flat_observation_: 130 | return obs_spec['observations'] 131 | return obs_spec 132 | 133 | def step(self, action): 134 | """Updates the environment using the action and returns a `TimeStep`.""" 135 | do_track = not self._reset_next_step 136 | timestep = super(LoggingEnv, self).step(action) 137 | if do_track: 138 | self._track(timestep) 139 | if timestep.last(): 140 | self._ep_counter += 1 141 | if self._ep_counter % self._log_every == 0: 142 | self.write_logs() 143 | # Only flatten observation if we're not forwarding one from a reset(), 144 | # as it will already be flattened. 
145 | if self._flat_observation_ and not timestep.first(): 146 | timestep = dm_env.TimeStep( 147 | step_type=timestep.step_type, 148 | reward=timestep.reward, 149 | discount=timestep.discount, 150 | observation=control.flatten_observation( 151 | timestep.observation)['observations']) 152 | return timestep 153 | 154 | def _track(self, timestep): 155 | if self._logger is None: 156 | return 157 | ts = copy.deepcopy(timestep) 158 | # Augment the timestep with unobserved variables for logging purposes. 159 | # Add safety-related observations. 160 | if self._task.safety_enabled and 'constraints' not in ts.observation: 161 | ts.observation['constraints'] = copy.copy(self._task.constraints_obs) 162 | if self._log_safety_vars: 163 | ts.observation['safety_vars'] = copy.deepcopy( 164 | self._task.safety_vars(self._physics)) 165 | if self._task.multiobj_enabled and 'multiobj' not in ts.observation: 166 | ts.observation['multiobj'] = self._task.get_multiobj_obs(self._physics) 167 | self._stats_acc.push(ts) 168 | 169 | def get_logs(self): 170 | return self._logger.logs 171 | 172 | def write_logs(self): 173 | if self._logger is None: 174 | return 175 | self._logger.save(data=self._stats_acc.to_ndarray_dict()) 176 | 177 | @property 178 | def stats_acc(self): 179 | return self._stats_acc 180 | 181 | @property 182 | def logs_path(self): 183 | if self._logger is None: 184 | return None 185 | return self._logger.logs_path 186 | 187 | 188 | def _spec_from_observation(observation): 189 | result = collections.OrderedDict() 190 | for key, value in six.iteritems(observation): 191 | result[key] = specs.Array(value.shape, value.dtype, name=key) 192 | return result 193 | -------------------------------------------------------------------------------- /realworldrl_suite/utils/wrappers_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Real-World RL Suite Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for realworldrl.utils.wrappers.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | 23 | from absl.testing import absltest 24 | from absl.testing import parameterized 25 | import numpy as np 26 | import realworldrl_suite.environments as rwrl 27 | 28 | 29 | class RandomAgent(object): 30 | 31 | def __init__(self, action_spec): 32 | self.action_spec = action_spec 33 | 34 | def action(self): 35 | return np.random.uniform( 36 | self.action_spec.minimum, 37 | self.action_spec.maximum, 38 | size=self.action_spec.shape) 39 | 40 | 41 | class WrappersTest(parameterized.TestCase): 42 | 43 | @parameterized.named_parameters(*rwrl.ALL_TASKS) 44 | def test_init(self, domain_name, task_name): 45 | temp_file = self.create_tempfile() 46 | env = rwrl.load( 47 | domain_name=domain_name, 48 | task_name=task_name, 49 | safety_spec={'enable': True}, 50 | log_output=temp_file.full_path, 51 | environment_kwargs=dict(log_safety_vars=True)) 52 | env.write_logs() 53 | 54 | @parameterized.named_parameters(*rwrl.ALL_TASKS) 55 | def test_no_logger(self, domain_name, task_name): 56 | env = rwrl.load( 57 | domain_name=domain_name, 58 | task_name=task_name, 59 | safety_spec={'enable': True}, 60 | log_output=None, 61 | environment_kwargs=dict(log_safety_vars=True)) 62 | env.write_logs() 63 | 64 | @parameterized.named_parameters(*rwrl.ALL_TASKS) 65 | def test_disabled_safety_obs(self, domain_name, task_name): 66 | temp_file = self.create_tempfile() 67 | env = rwrl.load( 68 | domain_name=domain_name, 69 | task_name=task_name, 70 | safety_spec={'enable': True, 'observations': False}, 71 | log_output=temp_file.full_path, 72 | environment_kwargs=dict(log_safety_vars=True)) 73 | env.reset() 74 | timestep = env.step(0) 75 | self.assertNotIn('constraints', timestep.observation.keys()) 76 | self.assertIn('constraints', env._stats_acc._buffer[-1].observation) 77 | env.write_logs() 78 | 79 | @parameterized.named_parameters(*rwrl.ALL_TASKS) 80 | def test_custom_safety_constraint(self, domain_name, task_name): 81 | const_constraint = (domain_name == rwrl.ALL_TASKS[0][1]) 82 | def custom_constraint(unused_env_obj, unused_safety_vars): 83 | return const_constraint 84 | constraints = {'custom_constraint': custom_constraint} 85 | temp_file = self.create_tempfile() 86 | env = rwrl.load( 87 | domain_name=domain_name, 88 | task_name=task_name, 89 | safety_spec={'enable': True, 'observations': True, 90 | 'constraints': constraints}, 91 | log_output=temp_file.full_path, 92 | environment_kwargs=dict(log_safety_vars=True)) 93 | env.reset() 94 | timestep = env.step(0) 95 | self.assertIn('constraints', timestep.observation.keys()) 96 | self.assertEqual(timestep.observation['constraints'], 97 | np.array([const_constraint])) 98 | env.write_logs() 99 | 100 | @parameterized.named_parameters(*rwrl.ALL_TASKS) 101 | def test_action_roc_constraint(self, domain_name, task_name): 102 | is_within_constraint = (domain_name == rwrl.ALL_TASKS[0][1]) 103 | constraints = { 104 | 'action_roc_constraint': rwrl.DOMAINS[domain_name].action_roc_constraint 105 | } 106 | temp_file = self.create_tempfile() 107 | env = rwrl.load( 108 | domain_name=domain_name, 109 | task_name=task_name, 110 | safety_spec={'enable': True, 'observations': True, 111 | 'constraints': constraints}, 112 | log_output=temp_file.full_path, 113 | environment_kwargs=dict(log_safety_vars=True)) 114 | env.reset() 115 | _ = env.step(-100) 116 | timestep_2 = env.step(-100 if 
is_within_constraint else 100) 117 | self.assertIn('constraints', timestep_2.observation.keys()) 118 | self.assertEqual(timestep_2.observation['constraints'], 119 | np.array([is_within_constraint])) 120 | env.write_logs() 121 | 122 | # @parameterized.parameters(*(5,)) 123 | def test_log_every_n(self, every_n=5): 124 | domain_name = 'cartpole' 125 | task_name = 'realworld_balance' 126 | temp_dir = self.create_tempdir() 127 | env = rwrl.load( 128 | domain_name=domain_name, 129 | task_name=task_name, 130 | safety_spec={'enable': True, 'observations': False}, 131 | log_output=os.path.join(temp_dir.full_path, 'test.pickle'), 132 | environment_kwargs=dict(log_safety_vars=True, log_every=every_n)) 133 | env.reset() 134 | n = 0 135 | while True: 136 | timestep = env.step(0) 137 | if timestep.last(): 138 | n += 1 139 | if n % every_n == 0: 140 | self.assertTrue(os.path.exists(env.logs_path)) 141 | os.remove(env.logs_path) 142 | self.assertFalse(os.path.exists(env.logs_path)) 143 | if n > 20: 144 | break 145 | 146 | @parameterized.named_parameters(*rwrl.ALL_TASKS) 147 | def test_flat_obs(self, domain_name, task_name): 148 | temp_file = self.create_tempfile() 149 | env = rwrl.load( 150 | domain_name=domain_name, 151 | task_name=task_name, 152 | safety_spec={'enable': True}, 153 | log_output=temp_file.full_path, 154 | environment_kwargs=dict(log_safety_vars=True, flat_observation=True)) 155 | self.assertLen(env.observation_spec().shape, 1) 156 | if domain_name == 'cartpole' and task_name in ['swingup', 'balance']: 157 | self.assertEqual(env.observation_spec().shape[0], 8) 158 | timestep = env.step(0) 159 | self.assertLen(timestep.observation.shape, 1) 160 | if domain_name == 'cartpole' and task_name in ['swingup', 'balance']: 161 | self.assertEqual(timestep.observation.shape[0], 8) 162 | env.write_logs() 163 | 164 | 165 | if __name__ == '__main__': 166 | absltest.main() 167 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Real World RL Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================ 15 | """Setup for pip package.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import io 22 | import sys 23 | import unittest 24 | 25 | from setuptools import find_packages 26 | from setuptools import setup 27 | from setuptools.command.test import test as TestCommandBase 28 | 29 | REQUIRED_PACKAGES = [ 30 | 'six', 'absl-py', 'numpy', 'dm-env', 'dm-control', 'lxml', 'scipy', 31 | 'matplotlib' 32 | ] 33 | 34 | 35 | class StderrWrapper(io.IOBase): 36 | 37 | def write(self, *args, **kwargs): 38 | return sys.stderr.write(*args, **kwargs) 39 | 40 | def writeln(self, *args, **kwargs): 41 | if args or kwargs: 42 | sys.stderr.write(*args, **kwargs) 43 | sys.stderr.write('\n') 44 | 45 | 46 | # Needed to ensure that flags are correctly parsed. 47 | class Test(TestCommandBase): 48 | 49 | def run_tests(self): 50 | # Import absl inside run, where dependencies have been loaded already. 51 | from absl import app # pylint: disable=g-import-not-at-top 52 | 53 | def main(_): 54 | test_loader = unittest.TestLoader() 55 | test_suite = test_loader.discover( 56 | 'realworldrl_suite', pattern='*_test.py') 57 | stderr = StderrWrapper() 58 | result = unittest.TextTestResult(stderr, descriptions=True, verbosity=2) 59 | test_suite.run(result) 60 | result.printErrors() 61 | 62 | final_output = ('Tests run: {}. '.format(result.testsRun) + 63 | 'Errors: {} Failures: {}.'.format( 64 | len(result.errors), len(result.failures))) 65 | 66 | header = '=' * len(final_output) 67 | stderr.writeln(header) 68 | stderr.writeln(final_output) 69 | stderr.writeln(header) 70 | 71 | if result.wasSuccessful(): 72 | return 0 73 | else: 74 | return 1 75 | 76 | # Run inside absl.app.run to ensure flags parsing is done. 77 | return app.run(main) 78 | 79 | 80 | def rwrl_test_suite(): 81 | test_loader = unittest.TestLoader() 82 | test_suite = test_loader.discover('realworldrl_suite', pattern='*_test.py') 83 | return test_suite 84 | 85 | 86 | setup( 87 | name='realworldrl_suite', 88 | version='1.0', 89 | description='RL evaluation framework for the real world.', 90 | url='https://github.com/google-research/realworldrl_suite', 91 | author='Google', 92 | author_email='no-reply@google.com', 93 | # Contained modules and scripts. 94 | packages=find_packages(), 95 | install_requires=REQUIRED_PACKAGES, 96 | extras_require={}, 97 | platforms=['any'], 98 | license='Apache 2.0', 99 | cmdclass={ 100 | 'test': Test, 101 | }, 102 | ) 103 | --------------------------------------------------------------------------------
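A minimal end-to-end usage sketch of the suite (an illustrative snippet rather than one of the repository's files): it loads a task with safety constraints and on-disk logging as in wrappers_test.py, runs a uniform-random policy as in viewer.py, and then flushes the accumulated statistics. The log path and the number of episodes are arbitrary example values.

import numpy as np
import realworldrl_suite.environments as rwrl

# Load a cartpole task with safety constraints enabled and episode statistics
# logged to disk; log_every controls how often the wrapper writes its logs.
env = rwrl.load(
    domain_name='cartpole',
    task_name='realworld_balance',
    safety_spec={'enable': True},
    log_output='/tmp/rwrl_log.npz',  # Arbitrary example path.
    environment_kwargs=dict(log_safety_vars=True, log_every=5))

action_spec = env.action_spec()
for _ in range(5):
  timestep = env.reset()
  while not timestep.last():
    # Uniform-random policy over the action spec, as in viewer.py's RandomAgent.
    action = np.random.uniform(
        action_spec.minimum, action_spec.maximum, size=action_spec.shape)
    timestep = env.step(action)
env.write_logs()  # Flush the remaining accumulated statistics via the logger.

The file actually written is at env.logs_path, since the PickleLogger prefixes the base name with a per-run timestamp; it can be inspected with np.load(env.logs_path, allow_pickle=True), as done in loggers_test.py.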