├── AUTHORS
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── docs
│   ├── README.md
│   └── img
│       ├── angular_velocity.gif
│       └── humanoid_perturbations.gif
├── examples
│   ├── run_dmpo_acme.py
│   ├── run_ppo.py
│   └── run_random.py
├── realworldrl_suite
│   ├── __init__.py
│   ├── analysis
│   │   └── realworldrl_notebook.ipynb
│   ├── environments
│   │   ├── __init__.py
│   │   ├── cartpole.py
│   │   ├── env_test.py
│   │   ├── humanoid.py
│   │   ├── manipulator.py
│   │   ├── quadruped.py
│   │   ├── realworld_env.py
│   │   └── walker.py
│   └── utils
│       ├── __init__.py
│       ├── accumulators.py
│       ├── accumulators_test.py
│       ├── evaluators.py
│       ├── evaluators_test.py
│       ├── loggers.py
│       ├── loggers_test.py
│       ├── multiobj_objectives.py
│       ├── multiobj_objectives_test.py
│       ├── viewer.py
│       ├── wrappers.py
│       └── wrappers_test.py
└── setup.py
/AUTHORS:
--------------------------------------------------------------------------------
1 | # This is the list of Real-World RL Suite's significant contributors.
2 | #
3 | # This does not necessarily list everyone who has contributed code,
4 | # especially since many employees of one corporation may be contributing.
5 | # To see the full list of contributors, see the revision history in
6 | # source control.
7 | Google LLC
8 | DeepMind Technologies Limited
9 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows
28 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
29 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Real-World Reinforcement Learning (RWRL) Challenge Framework
2 |
3 |
4 | 
5 |
6 |
7 | The ["Challenges of Real-World RL"](https://arxiv.org/abs/1904.12901) paper
8 | identifies and describes a set of nine challenges that are currently preventing
9 | Reinforcement Learning (RL) agents from being utilized on real-world
10 | applications and products. It also describes an evaluation framework and a set
11 | of environments that can provide an evaluation of an RL algorithm’s potential
12 | applicability to real-world systems. It has since been followed up by
13 | ["An Empirical Investigation of the challenges of real-world reinforcement
14 | learning"](https://arxiv.org/pdf/2003.11881.pdf), which implements eight of the
15 | nine described challenges (excluding explainability) and analyses their effects
16 | on various state-of-the-art RL algorithms. This is the codebase used to perform
17 | that analysis. It is also intended as a common platform for easily reproducible
18 | experimentation around these challenges, and is referred to as the
19 | `realworldrl-suite` (Real-World Reinforcement Learning (RWRL) Suite).
20 |
21 | Currently, the suite comprises five environments:
22 |
23 | * Cartpole
24 | * Walker
25 | * Quadruped
26 | * Manipulator (less tested)
27 | * Humanoid
28 |
29 | The codebase is currently structured as:
30 |
31 | * environments/ -- the extended environments
32 | * utils/ -- wrapper classes for logging and standardized evaluations
33 | * analysis/ -- Notebook for training an agent and generating plots
34 | * examples/ -- Random policy, PPO, and DMPO agent example implementations
35 | * docs/ -- Documentation
36 |
37 | Questions can be directed to the Real-World RL group e-mail
38 | [realworldrl@google.com](mailto:realworldrl@google.com).
39 |
40 | > :information_source: If you wish to test your agent in a principled fashion on
41 | > related challenges in low-dimensional domains, we highly recommend using
42 | > [bsuite](https://github.com/deepmind/bsuite).
43 |
44 | ## Documentation
45 |
46 | We overview the challenges here, but more thorough documentation on how to
47 | configure each challenge can be found [here](docs/README.md).
48 |
49 | Starter examples are presented in the [examples](#running-examples) section.
50 |
51 | ## Challenges
52 |
53 | ### Safety
54 |
55 | Adds a set of constraints on the task. Returns an additional entry in the
56 | observations ('constraints') whose length equals the number of constraints,
57 | where each entry is True if the constraint is satisfied and False otherwise.
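
For example, a minimal sketch (not one of the bundled examples) of enabling the safety challenge on Cartpole and reading the `constraints` entry might look as follows; it assumes the observation dict is left un-flattened so the `constraints` key is directly accessible:

```python
import numpy as np
import realworldrl_suite.environments as rwrl

# Hypothetical usage sketch: enable the safety constraints and inspect the
# binary 'constraints' observation returned at every step.
env = rwrl.load(
    domain_name='cartpole',
    task_name='realworld_swingup',
    safety_spec=dict(enable=True, observations=True))

timestep = env.reset()
while not timestep.last():
  action = np.random.uniform(env.action_spec().minimum,
                             env.action_spec().maximum,
                             size=env.action_spec().shape)
  timestep = env.step(action)
  constraints = timestep.observation['constraints']  # One entry per constraint.
  if not constraints.all():
    print('Constraint(s) violated at this step:',
          np.flatnonzero(constraints == 0))
```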
58 |
59 | ### Delays
60 |
61 | Action, observation and reward delays.
62 |
63 | - Action delay is the number of steps between passing the action to the
64 | environment to when it is actually performed.
65 | - Observation delay is the number of steps by which the returned observation
66 | lags the true state after performing a step.
67 | - Reward delay indicates the number of steps before receiving a reward after
68 | taking an action.
69 |
70 | ### Noise
71 |
72 | Action and observation noise. Supported noise types include:
73 |
74 | - White Gaussian action/observation noise
75 | - Dropped actions/observations
76 | - Stuck actions/observations
77 | - Repetitive actions
78 |
79 | The noise specifications can be parameterized in the noise_spec dictionary.
80 |
81 | ### Perturbations
82 |
83 | Perturbs physical quantities of the environment. These perturbations are
84 | non-stationary and are governed by a scheduler.
85 |
86 | ### Dimensionality
87 |
88 | Adds extra dummy features to observations to increase dimensionality of the
89 | state space.
90 |
91 | ### Multi-Objective Rewards:
92 |
93 | Adds additional objectives and specifies how the objectives interact (e.g., sum).
94 |
95 | ### Offline Learning
96 |
97 | We provide our offline datasets through the
98 | [RL Unplugged](https://github.com/deepmind/deepmind-research/blob/master/rl_unplugged/)
99 | library. There is an
100 | [example](https://github.com/deepmind/deepmind-research/blob/master/rl_unplugged/rwrl_example.py)
101 | and an associated
102 | [colab](https://github.com/deepmind/deepmind-research/blob/master/rl_unplugged/rwrl_d4pg.ipynb).
103 |
104 | ### RWRL Combined Challenge Benchmarks:
105 |
106 | Combines multiple challenges into the same environment. The challenges are
107 | divided into 'Easy', 'Medium' and 'Hard' difficulty levels, which differ in the
108 | magnitude of the challenge effects applied along each challenge dimension.
109 |
110 | ## Installation
111 |
112 | - Install pip by running the following commands:
114 |
115 | ```bash
116 | curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
117 | python get-pip.py
118 | ```
119 |
120 | - Make sure pip is up to date.
121 |
122 | ```bash
123 | pip3 install --upgrade pip
124 | ```
125 |
126 | - (Optional) You may wish to create a
127 | [Python virtual environment](https://docs.python.org/3/tutorial/venv.html)
128 | to manage your dependencies, so as not to clobber your system installation:
129 |
130 | ```bash
131 | sudo pip3 install virtualenv
132 | /usr/local/bin/virtualenv realworldrl_venv
133 | source ./realworldrl_venv/bin/activate
134 | ```
135 |
136 | - Install MuJoCo (see dm_control - https://github.com/deepmind/dm_control).
137 |
138 | - To install `realworldrl_suite`:
139 |
140 | - Clone the repository by running:
141 |
142 | ```bash
143 | git clone https://github.com/google-research/realworldrl_suite.git
144 | ```
145 |
146 | - Ensure you are in the parent directory of realworldrl_suite
147 | - Run the command:
148 |
149 | ```bash
150 | pip3 install realworldrl_suite/
151 | ```
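
- (Optional) To check the installation, you can load an environment from a
  Python shell (a minimal sketch, assuming MuJoCo is set up correctly and all
  challenge specs are left at their defaults):

```python
import realworldrl_suite.environments as rwrl

# Loads Cartpole swing-up with no challenges enabled.
env = rwrl.load(domain_name='cartpole', task_name='realworld_swingup')
print(env.observation_spec())
```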
152 |
153 | ## Running examples
154 |
155 | We provide three example agents: a random agent, a PPO agent, and an
156 | [ACME](https://github.com/deepmind/acme)-based DMPO agent.
157 |
158 | - For PPO, running the examples requires installing the following packages:
159 |
160 | ```bash
161 | pip3 install tensorflow==1.15.0 dm2gym
162 | pip3 install git+https://github.com/openai/baselines.git
163 | ```
164 |
165 | - The PPO example can then be run with
166 |
167 | ```bash
168 | cd realworldrl_suite/examples
169 | mkdir /tmp/rwrl/
170 | python3 run_ppo.py
171 | ```
172 |
173 | - For DMPO, first install the following packages:
174 |
175 | ```bash
176 | pip install dm-acme
177 | pip install dm-acme[reverb]
178 | pip install dm-acme[tf]
179 | ```
180 |
181 | You *may* also have to install the following:
182 |
183 | ```bash
184 | pip install gym
185 | pip install jax
186 | pip install dm-sonnet
187 | ```
188 |
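- The DMPO example can then be run from the `realworldrl_suite/examples`
  directory with `python3 run_dmpo_acme.py`. Like the PPO example, it writes
  its results to `/tmp/rwrl` by default, so make sure that directory exists.
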
189 | - The examples look for the MuJoCo license key in `~/.mujoco/mjkey.txt` by
190 | default.
191 |
192 | ## RWRL Combined Challenge Benchmark Instantiation:
193 |
194 | As mentioned above, these benchmark challenges are divided into 'Easy', 'Medium'
195 | and 'Hard' difficulty levels. For the current state-of-the-art performance on
196 | these benchmarks, please see this
197 | paper.
198 |
199 | Instantiating a combined challenge environment with 'Easy' difficulty is done as
200 | follows:
201 |
202 | ```python
203 | import realworldrl_suite.environments as rwrl
204 | env = rwrl.load(
205 | domain_name='cartpole',
206 | task_name='realworld_swingup',
207 | combined_challenge='easy',
208 | log_output='/tmp/path/to/results.npz',
209 | environment_kwargs=dict(log_safety_vars=True, flat_observation=True))
210 | ```
211 |
212 | ## Acknowledgements
213 |
214 | If you use `realworldrl_suite` in your work, please cite:
215 |
216 | ```
217 | @article{dulacarnold2020realworldrlempirical,
218 | title={An empirical investigation of the challenges of real-world reinforcement learning},
219 | author={Dulac-Arnold, Gabriel and
220 | Levine, Nir and
221 | Mankowitz, Daniel J. and
222 | Li, Jerry and
223 | Paduraru, Cosmin and
224 | Gowal, Sven and
225 | Hester, Todd
226 | },
227 | year={2020},
228 | }
229 | ```
230 |
231 | ## Paper links
232 |
233 | - [Challenges of real-world reinforcement learning](https://arxiv.org/abs/1904.12901)
234 |
235 | - [An empirical investigation of the challenges of real-world reinforcement learning](https://arxiv.org/pdf/2003.11881.pdf)
238 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | # RWRL Suite Challenge Spec
2 |
3 | This document outlines the specification dictionaries for each of the supported
4 | challenges. It provides an overview of each challenge's parameters and their
5 | effects.
6 |
7 |
8 | ### Safety
9 |
10 | Adds a set of constraints on the task. Returns an additional entry in the
11 | observations under the `constraints` key. The observation is a binary vector
12 | whose length equals the number of constraints, where each entry is True if
13 | the constraint is satisfied and False otherwise. The following dictionary is
14 | fed as an argument into the RWRL environment load function to initialize the
15 | safety constraints:
16 |
17 | ```
18 | safety_spec = {
19 | 'enable': bool, # Whether to enable safety constraints.
20 | 'observations': bool, # Whether to add the constraint violations observation.
21 | 'safety_coeff': float, # Safety coefficient that regulates the difficulty of the constraint. 1 is the baseline, and 0 is impossible (always violated).
22 | 'constraints' : list # Optional list of additional safety constraints. Can only operate on variables returned by the `safety_vars` method.
23 | }
24 | ```
25 |
26 | Each of the built-in constraints has been tuned to be just at the border of
27 | nominal operation. For tasks such as Cartpole, the constraints are tuned to be
28 | violated only during the swing-up phase, not during balancing. Different
29 | constraints will be more or less difficult to satisfy.
30 |
31 | The built-in constraints are as follows:
32 |
33 | * Cartpole Swing-Up:
34 | * `slider_pos_constraint` : Constrain cart to be within a specific region on track.
35 | * `balance_velocity_constraint` : Constrain pole angular velocity to be below a certain threshold when arriving near the top. This provides a more subtle constraint than a standard box constraint on a variable.
36 | * `slider_accel_constraint` : Constrain cart acceleration to be below a certain value.
37 | * Walker:
38 | * `joint_angle_constraint` : Constrain joint angles to specific ranges. This is joint-specific.
39 | * `joint_velocity_constraint` : Constrain joint velocities to a certain range. This is a global value.
40 | * `dangerous_fall_constraint` : Discourage dangerous falls by ensuring the torso stays positioned forwards.
41 | * `torso_upright_constraint` : Discourage dangerous operation by ensuring that the torso stays upright.
42 | * Quadruped:
43 | * `joint_angle_constraint` : Constrain joint angles to specific ranges. This is joint-specific.
44 | * `joint_velocity_constraint` : Constrain joint velocities to a certain range. This is a global value.
45 | * `upright_constraint` : Constrain the Quadruped's torso's z-axis to be oriented upwards.
46 | * `foot_force_constraint` : Limits foot contact forces when touching the ground.
47 | * Humanoid
48 | * `joint_angle_constraint` : Constrain joint angles to specific ranges. This is joint-specific.
49 | * `joint_velocity_constraint` : Constrain joint velocities to a certain range. This is a global value.
50 | * `upright_constraint` : Constrain the Humanoid's pelvis's z-axis to be oriented upwards.
51 | * `foot_force_constraint` : Limits foot contact forces when touching the ground.
52 | * `dangerous_fall_constraint` : Discourages dangerous falls by limiting head and torso contact.
53 |
54 | It is also possible to add arbitrary constraints by passing a list of methods to the `constraints` key; each method receives the values returned by the task's `safety_vars` method.
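
As a hypothetical sketch (the `slider_pos` key, the threshold, and the exact call signature are illustrative only; check the task's `safety_vars` implementation and `realworld_env.py` for the real plumbing), a user-defined constraint could be supplied as follows:

```python
import numpy as np
import realworldrl_suite.environments as rwrl

def my_slider_limit_constraint(safety_vars):
  # Illustrative constraint: keep the cart close to the centre of the track.
  # `safety_vars` is assumed to be the dict returned by the task's
  # `safety_vars` method; the key name and threshold are made up here.
  return np.abs(safety_vars['slider_pos']).max() < 0.5

env = rwrl.load(
    domain_name='cartpole',
    task_name='realworld_swingup',
    safety_spec={
        'enable': True,
        'observations': True,
        'constraints': [my_slider_limit_constraint],
    })
```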
55 |
56 |
57 | ### Delays
58 | This challenge provides delays on actions, observations and rewards. For each of these, the delayed element is placed into a buffer and either applied to the environment, or sent back to the agent after the specified number of steps.
59 |
60 | This can be configured by passing in a `delay_spec` dictionary of the following form:
61 |
62 | ```
63 | delay_spec = {
64 | 'enable': bool, # Whether to enable this challenge
65 | 'actions': int, # The delay on actions in # of steps
66 | 'observations': int, # The delay on observations in # of steps
67 | 'rewards': int, # The delay on rewards in # of steps
68 | }
69 | ```
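
For instance, a minimal sketch of loading Cartpole swing-up with delayed actions, observations and rewards (the values below are illustrative, not a recommended setting):

```python
import realworldrl_suite.environments as rwrl

env = rwrl.load(
    domain_name='cartpole',
    task_name='realworld_swingup',
    delay_spec={
        'enable': True,
        'actions': 3,       # Actions take effect 3 steps after being issued.
        'observations': 3,  # Observations are 3 steps stale.
        'rewards': 10,      # Rewards arrive 10 steps after the causing action.
    })
```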
70 |
71 | ### Noise
72 |
73 | Real-world systems often have action and observation noise, and this can come in
74 | many flavors, from simply noisy values, to dropped and stuck signals. This
75 | challenge allows you to experiment with different types of noise:
76 |
77 | - White Gaussian action/observation noise
78 | - Dropped actions/observations
79 | - Stuck actions/observations
80 | - Repetitive actions
81 |
82 | The noise specifications can be parameterized in the noise_spec dictionary.
83 |
84 | ```
85 | noise_spec = {
86 | 'gaussian': { # Specifies the white Gaussian additive noise.
87 | 'enable': bool, # Whether to enable Gaussian noise.
88 | 'actions': float, # Standard deviation of noise added to actions.
89 | 'observations': float, # Standard deviation of noise added to observations.
90 | },
91 | 'dropped': { # Specifies dropped value noise.
92 | 'enable': bool, # Whether to enable dropped values.
93 | 'observations_prob': float, # Value in [0,1] indicating the probability of dropping each observation independently.
94 | 'observations_steps': int, # Value > 0 specifying the number of steps to drop an observation for if dropped.
95 | 'action_prob': float, # Value in [0,1] indicating the probability of dropping each action independently.
96 | 'action_steps': int, # Value > 0 specifying the number of steps to drop an action for if dropped.
97 | },
98 | 'stuck': { # Specifies stuck values noise.
99 | 'enable': bool, # Whether to enable stuck values.
100 | 'observations_prob': float, # Value in [0,1] indicating the probability of an observation component becoming stuck.
101 | 'observations_steps': int, # Value > 0 specifying the number of steps an observation remains stuck.
102 | 'action_prob': float, # Value in [0,1] indicating the probability of an action component becoming stuck.
103 | 'action_steps': int # Value > 0 specifying the number of steps an action remains stuck.
104 | },
105 | 'repetition': { # Specifies repetition statistics.
106 | 'enable': bool, # Whether to enable repeating values.
107 | 'actions_prob': float, # Value in [0,1] indicating the probability of an action repeating.
108 | 'actions_steps': int # Value > 0 specifying the number of steps an action repeats.
109 | },
110 | }
111 |
112 | ```
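
A minimal sketch of enabling two of these noise types (sub-dictionaries that are omitted are assumed to default to disabled):

```python
import realworldrl_suite.environments as rwrl

env = rwrl.load(
    domain_name='cartpole',
    task_name='realworld_swingup',
    noise_spec={
        'gaussian': {  # White Gaussian noise on both actions and observations.
            'enable': True,
            'actions': 0.1,
            'observations': 0.1,
        },
        'dropped': {   # Occasionally drop observations for a single step.
            'enable': True,
            'observations_prob': 0.01,
            'observations_steps': 1,
        },
    })
```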
113 |
114 | ### Perturbations
115 |
116 | Real systems are imperfect and degrade or change over time. This means the
117 | controller needs to understand these changes and update its control policy
118 | accordingly. The RWRL suite can simulate various system perturbations, with
119 | varying ways in which a perturbation evolves over time (which we call its
120 | 'schedule').
121 |
122 | These challenges can be configured by passing in a `perturb_spec` dictionary with the
123 | following format:
124 |
125 | ```
126 | perturb_spec = {
127 | 'enable': bool, # Whether to enable perturbations
128 | 'period': int, # Number of episodes between perturbation changes.
129 | 'param': str, # Specifies which parameter to perturb. Specified below.
130 | 'scheduler': str, # Specifies which scheduler to apply. Specified below.
131 | 'start': float, # Indicates initial value of perturbed parameter.
132 | 'min': float, # Indicates the minimal value of the perturbed parameter.
133 | 'max': float, # Indicates the maximal value of the perturbed parameter.
134 | 'std': float # Indicates the standard deviation of white noise used in scheduling.
135 | }
136 | ```
137 |
138 | The various scheduler choices are as follows:
139 |
140 | * `constant` : keeps the perturbation constant.
141 | * `random_walk`: change the perturbation by a random amount (defined by the 'std' key).
142 | * `drift_pos` : change the perturbation by a random positive amount.
143 | * `drift_neg` : change the perturbation by a random negative amount.
144 | * `cyclic_pos` : change the perturbation by a random positive amount and cycle back when 'max' is attained.
145 | * `cyclic_neg` : change the perturbation by a random negative amount and cycle back when 'min' is attained.
146 | * `uniform` : set the perturbation to a uniform random value within [min, max].
147 | * `saw_wave` : cycle between `drift_pos` and `drift_neg` when [min, max] bounds are reached.
148 |
149 | Each environment has a set of parameters which can be perturbed:
150 |
151 | * Cartpole
152 | * `pole_length`
153 | * `pole_mass`
154 | * `joint_damping` : adds a damping factor to the pole joint.
155 | * `slider_damping` : adds a damping factor to the slider (cart).
156 | * Walker
157 | * `thigh_length`
158 | * `torso_length`
159 | * `joint_damping` : adds a damping factor to all joints.
160 | * `contact_friction` : alters contact friction with ground.
161 | * Quadruped
162 | * `shin_length`
163 | * `torso_density`: alters torso density, therefore changing weight with constant volume.
164 | * `joint_damping` : adds a damping factor to all joints.
165 | * `contact_friction` : alters contact friction with ground.
166 | * Humanoid
167 | * `joint_damping` : adds a damping factor to all joints.
168 | * `contact_friction` : alters contact friction with ground.
169 | * `head_size` : alters head size (and therefore weight).
170 |
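A minimal sketch of perturbing the Cartpole pole length with a positive drift (the `start`/`min`/`max`/`std` values are illustrative only):

```python
import realworldrl_suite.environments as rwrl

env = rwrl.load(
    domain_name='cartpole',
    task_name='realworld_swingup',
    perturb_spec={
        'enable': True,
        'period': 1,               # Update the perturbation every episode.
        'param': 'pole_length',    # One of the Cartpole parameters listed above.
        'scheduler': 'drift_pos',  # Drift the value upwards by a random amount.
        'start': 1.0,
        'min': 0.5,
        'max': 2.0,
        'std': 0.02,
    })
```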
171 |
172 | ### Dimensionality
173 | Adds extra dummy features to observations to increase dimensionality of the
174 | state space.
175 |
176 | ```
177 | dimensionality_spec = {
178 | 'enable': bool, # Whether to enable dimensionality challenges.
179 | 'num_random_state_observations': int, # Number of random observation dimensions to add.
180 | }
181 | ```
182 |
183 | ### Multi-Objective Reward
184 | This challenge looks at multi-objective rewards. There is a default multi-objective reward included which allows a safety objective to be defined from the set of constraints, but new objectives are easy to implement by adding them to the `utils.multiobj_objectives.OBJECTIVES` dict.
185 |
186 | ```
187 | multiobj_spec = {
188 | 'enable': bool, # Whether to enable the multi-objective challenge.
189 | 'objective': str or object, # Either a string which will load an `Objective` class from
190 | # utils.multiobj_objectives.OBJECTIVES or an Objective object
191 | # which subclasses utils.multiobj_objectives.Objective.
192 | 'reward': bool, # Whether to add the multiobj objective's reward to the environment's returned reward.
193 | 'coeff': float, # A number in [0,1] used as a reward mixing ratio by the Objective object.
194 | 'observed': bool # Whether the defined objectives should be added to the observation.
195 | }
196 | ```
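
A minimal sketch, assuming the built-in safety objective is registered under the key `'safety'` in `utils.multiobj_objectives.OBJECTIVES` (check that module for the names actually registered):

```python
import realworldrl_suite.environments as rwrl

env = rwrl.load(
    domain_name='cartpole',
    task_name='realworld_swingup',
    safety_spec={'enable': True},  # The safety objective is derived from the constraints.
    multiobj_spec={
        'enable': True,
        'objective': 'safety',  # Assumed name of the built-in safety objective.
        'reward': True,         # Mix the objective into the scalar reward...
        'coeff': 0.5,           # ...with a 0.5 mixing ratio...
        'observed': True,       # ...and also expose the objectives as observations.
    })
```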
197 |
198 | ### RWRL Combined Challenge Benchmarks:
199 | The RWRL suite allows you to combine multiple challenges into the same environment. The challenges are
200 | divided into 'Easy', 'Medium' and 'Hard' difficulty levels, which differ in the
201 | magnitude of the challenge effects applied along each challenge dimension.
202 |
203 | * The 'Easy' challenge:
204 |
205 | ```
206 | delay_spec = {
207 | 'enable': True,
208 | 'actions': 3,
209 | 'observations': 3,
210 | 'rewards': 10
211 | }
212 | noise_spec = {
213 | 'gaussian': {
214 | 'enable': True,
215 | 'actions': 0.1,
216 | 'observations': 0.1
217 | },
218 | 'dropped': {
219 | 'enable': True,
220 | 'observations_prob': 0.01,
221 | 'observations_steps': 1,
222 | },
223 | 'stuck': {
224 | 'enable': True,
225 | 'observations_prob': 0.01,
226 | 'observations_steps': 1,
227 | },
228 | 'repetition': {
229 | 'enable': True,
230 | 'actions_prob': 1.0,
231 | 'actions_steps': 1
232 | }
233 | }
234 | perturb_spec = {
235 | 'enable': True,
236 | 'period': 1,
237 | 'scheduler': 'uniform'
238 | }
239 | dimensionality_spec = {
240 | 'enable': True,
241 | 'num_random_state_observations': 10
242 | }
243 | ```
244 |
245 | * The 'Medium' challenge:
246 |
247 | ```
248 | delay_spec = {
249 | 'enable': True,
250 | 'actions': 6,
251 | 'observations': 6,
252 | 'rewards': 20
253 | }
254 | noise_spec = {
255 | 'gaussian': {
256 | 'enable': True,
257 | 'actions': 0.3,
258 | 'observations': 0.3
259 | },
260 | 'dropped': {
261 | 'enable': True,
262 | 'observations_prob': 0.05,
263 | 'observations_steps': 5,
264 | },
265 | 'stuck': {
266 | 'enable': True,
267 | 'observations_prob': 0.05,
268 | 'observations_steps': 5,
269 | },
270 | 'repetition': {
271 | 'enable': True,
272 | 'actions_prob': 1.0,
273 | 'actions_steps': 2
274 | }
275 | }
276 | perturb_spec = {
277 | 'enable': True,
278 | 'period': 1,
279 | 'scheduler': 'uniform'
280 | }
281 | dimensionality_spec = {
282 | 'enable': True,
283 | 'num_random_state_observations': 20
284 | }
285 | ```
286 |
287 | * The 'Hard' challenge:
288 |
289 | ```
290 | delay_spec = {
291 | 'enable': True,
292 | 'actions': 9,
293 | 'observations': 9,
294 | 'rewards': 40
295 | }
296 | noise_spec = {
297 | 'gaussian': {
298 | 'enable': True,
299 | 'actions': 1.0,
300 | 'observations': 1.0
301 | },
302 | 'dropped': {
303 | 'enable': True,
304 | 'observations_prob': 0.1,
305 | 'observations_steps': 10,
306 | },
307 | 'stuck': {
308 | 'enable': True,
309 | 'observations_prob': 0.1,
310 | 'observations_steps': 10,
311 | },
312 | 'repetition': {
313 | 'enable': True,
314 | 'actions_prob': 1.0,
315 | 'actions_steps': 3
316 | }
317 | }
318 | perturb_spec = {
319 | 'enable': True,
320 | 'period': 1,
321 | 'scheduler': 'uniform'
322 | }
323 | dimensionality_spec = {
324 | 'enable': True,
325 | 'num_random_state_observations': 50
326 | }
327 | ```
328 |
--------------------------------------------------------------------------------
/docs/img/angular_velocity.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/realworldrl_suite/be7a51cffa7f5f9cb77a387c16bad209e0f851f8/docs/img/angular_velocity.gif
--------------------------------------------------------------------------------
/docs/img/humanoid_perturbations.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/realworldrl_suite/be7a51cffa7f5f9cb77a387c16bad209e0f851f8/docs/img/humanoid_perturbations.gif
--------------------------------------------------------------------------------
/examples/run_dmpo_acme.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Real-World RL Suite Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Trains an ACME DMPO agent on a perturbed version of Cart-Pole."""
17 |
18 | import os
19 | from typing import Dict, Sequence
20 |
21 | from absl import app
22 | from absl import flags
23 | import acme
24 | from acme import specs
25 | from acme import types
26 | from acme import wrappers
27 | from acme.agents.tf import dmpo
28 | from acme.tf import networks
29 | from acme.tf import utils as tf2_utils
30 | import dm_env
31 | import numpy as np
32 | import realworldrl_suite.environments as rwrl
33 | import sonnet as snt
34 |
35 | flags.DEFINE_string('domain_name', 'cartpole', 'domain to solve')
36 | flags.DEFINE_string('task_name', 'realworld_balance', 'task to solve')
37 | flags.DEFINE_string('save_path', '/tmp/rwrl', 'where to save results')
38 | flags.DEFINE_integer('num_episodes', 100, 'Number of episodes to run for.')
39 |
40 | FLAGS = flags.FLAGS
41 |
42 |
43 | def make_environment(domain_name: str = 'cartpole',
44 | task_name: str = 'balance') -> dm_env.Environment:
45 | """Creates a RWRL suite environment."""
46 | environment = rwrl.load(
47 | domain_name=domain_name,
48 | task_name=task_name,
49 | safety_spec=dict(enable=True),
50 | delay_spec=dict(enable=True, actions=20),
51 | log_output=os.path.join(FLAGS.save_path, 'log.npz'),
52 | environment_kwargs=dict(
53 | log_safety_vars=True, log_every=2, flat_observation=True))
54 | environment = wrappers.SinglePrecisionWrapper(environment)
55 | return environment
56 |
57 |
58 | def make_networks(
59 | action_spec: specs.BoundedArray,
60 | policy_layer_sizes: Sequence[int] = (256, 256, 256),
61 | critic_layer_sizes: Sequence[int] = (512, 512, 256),
62 | vmin: float = -150.,
63 | vmax: float = 150.,
64 | num_atoms: int = 51,
65 | ) -> Dict[str, types.TensorTransformation]:
66 | """Creates networks used by the agent."""
67 |
68 | # Get total number of action dimensions from action spec.
69 | num_dimensions = np.prod(action_spec.shape, dtype=int)
70 |
71 | # Create the shared observation network; here simply a state-less operation.
72 | observation_network = tf2_utils.batch_concat
73 |
74 | # Create the policy network.
75 | policy_network = snt.Sequential([
76 | networks.LayerNormMLP(policy_layer_sizes),
77 | networks.MultivariateNormalDiagHead(num_dimensions)
78 | ])
79 |
80 | # The multiplexer concatenates the observations and actions.
81 | multiplexer = networks.CriticMultiplexer(
82 | critic_network=networks.LayerNormMLP(critic_layer_sizes),
83 | action_network=networks.ClipToSpec(action_spec))
84 |
85 | # Create the critic network.
86 | critic_network = snt.Sequential([
87 | multiplexer,
88 | networks.DiscreteValuedHead(vmin, vmax, num_atoms),
89 | ])
90 |
91 | return {
92 | 'policy': policy_network,
93 | 'critic': critic_network,
94 | 'observation': observation_network,
95 | }
96 |
97 |
98 | def main(_):
99 | # Create an environment and grab the spec.
100 | environment = make_environment(
101 | domain_name=FLAGS.domain_name, task_name=FLAGS.task_name)
102 | environment_spec = specs.make_environment_spec(environment)
103 |
104 | # Create the networks to optimize (online) and target networks.
105 | agent_networks = make_networks(environment_spec.actions)
106 |
107 | # Construct the agent.
108 | agent = dmpo.DistributionalMPO(
109 | environment_spec=environment_spec,
110 | policy_network=agent_networks['policy'],
111 | critic_network=agent_networks['critic'],
112 | observation_network=agent_networks['observation'], # pytype: disable=wrong-arg-types
113 | )
114 |
115 | # Run the environment loop.
116 | loop = acme.EnvironmentLoop(environment, agent)
117 | loop.run(num_episodes=FLAGS.num_episodes)
118 |
119 |
120 | if __name__ == '__main__':
121 | app.run(main)
122 |
--------------------------------------------------------------------------------
/examples/run_ppo.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Real-World RL Suite Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Trains an OpenAI Baselines PPO agent on realworldrl.
17 |
18 | Note that OpenAI Gym is not installed with realworldrl by default.
19 | See also github.com/openai/baselines for more information.
20 |
21 | This example also relies on dm2gym for its gym environment wrapper.
22 | See github.com/zuoxingdong/dm2gym for more information.
23 | """
24 |
25 | import os
26 |
27 | from absl import app
28 | from absl import flags
29 | from baselines import bench
30 | from baselines.common.vec_env import dummy_vec_env
31 | from baselines.ppo2 import ppo2
32 | import dm2gym.envs.dm_suite_env as dm2gym
33 | import realworldrl_suite.environments as rwrl
34 |
35 | flags.DEFINE_string('domain_name', 'cartpole', 'domain to solve')
36 | flags.DEFINE_string('task_name', 'realworld_balance', 'task to solve')
37 | flags.DEFINE_string('save_path', '/tmp/rwrl', 'where to save results')
38 | flags.DEFINE_boolean('verbose', True, 'whether to log to std output')
39 | flags.DEFINE_string('network', 'mlp', 'name of network architecture')
40 | flags.DEFINE_float('agent_discount', .99, 'discounting on the agent side')
41 | flags.DEFINE_integer('nsteps', 100, 'number of steps per ppo rollout')
42 | flags.DEFINE_integer('total_timesteps', 1000000, 'total steps for experiment')
43 | flags.DEFINE_float('learning_rate', 1e-3, 'learning rate for optimizer')
44 |
45 | FLAGS = flags.FLAGS
46 |
47 |
48 | class GymEnv(dm2gym.DMSuiteEnv):
49 | """Wrapper that convert a realworldrl environment to a gym environment."""
50 |
51 | def __init__(self, env):
52 | """Constructor. We reuse the facilities from dm2gym."""
53 | self.env = env
54 | self.metadata = {
55 | 'render.modes': ['human', 'rgb_array'],
56 | 'video.frames_per_second': round(1. / self.env.control_timestep())
57 | }
58 | self.observation_space = dm2gym.convert_dm_control_to_gym_space(
59 | self.env.observation_spec())
60 | self.action_space = dm2gym.convert_dm_control_to_gym_space(
61 | self.env.action_spec())
62 | self.viewer = None
63 |
64 |
65 | def run():
66 | """Runs a PPO agent on a given environment."""
67 |
68 | def _load_env():
69 | """Loads environment."""
70 | raw_env = rwrl.load(
71 | domain_name=FLAGS.domain_name,
72 | task_name=FLAGS.task_name,
73 | safety_spec=dict(enable=True),
74 | delay_spec=dict(enable=True, actions=20),
75 | log_output=os.path.join(FLAGS.save_path, 'log.npz'),
76 | environment_kwargs=dict(
77 | log_safety_vars=True, log_every=20, flat_observation=True))
78 | env = GymEnv(raw_env)
79 | env = bench.Monitor(env, FLAGS.save_path)
80 | return env
81 |
82 | env = dummy_vec_env.DummyVecEnv([_load_env])
83 |
84 | ppo2.learn(
85 | env=env,
86 | network=FLAGS.network,
87 | lr=FLAGS.learning_rate,
88 | total_timesteps=FLAGS.total_timesteps, # make sure to run enough steps
89 | nsteps=FLAGS.nsteps,
90 | gamma=FLAGS.agent_discount,
91 | )
92 |
93 |
94 | def main(argv):
95 | del argv # Unused.
96 | run()
97 |
98 |
99 | if __name__ == '__main__':
100 | app.run(main)
101 |
--------------------------------------------------------------------------------
/examples/run_random.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Real-World RL Suite Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Runs a random policy on realworldrl."""
17 |
18 | import os
19 |
20 | from absl import app
21 | from absl import flags
22 | import numpy as np
23 | import realworldrl_suite.environments as rwrl
24 |
25 | flags.DEFINE_string('domain_name', 'cartpole', 'domain to solve')
26 | flags.DEFINE_string('task_name', 'realworld_balance', 'task to solve')
27 | flags.DEFINE_string('save_path', '/tmp/rwrl', 'where to save results')
28 | flags.DEFINE_integer('total_episodes', 100, 'number of episodes')
29 |
30 | FLAGS = flags.FLAGS
31 |
32 |
33 | def random_policy(action_spec):
34 |
35 | def _act(timestep):
36 | del timestep
37 | return np.random.uniform(
38 | low=action_spec.minimum,
39 | high=action_spec.maximum,
40 | size=action_spec.shape)
41 |
42 | return _act
43 |
44 |
45 | def run():
46 | """Runs a random agent on a given environment."""
47 |
48 | env = rwrl.load(
49 | domain_name=FLAGS.domain_name,
50 | task_name=FLAGS.task_name,
51 | safety_spec=dict(enable=True),
52 | delay_spec=dict(enable=True, actions=20),
53 | log_output=os.path.join(FLAGS.save_path, 'log.npz'),
54 | environment_kwargs=dict(
55 | log_safety_vars=True, log_every=20, flat_observation=True))
56 |
57 | policy = random_policy(action_spec=env.action_spec())
58 |
59 | rewards = []
60 | for _ in range(FLAGS.total_episodes):
61 | timestep = env.reset()
62 | total_reward = 0.
63 | while not timestep.last():
64 | action = policy(timestep)
65 | timestep = env.step(action)
66 | total_reward += timestep.reward
67 | rewards.append(total_reward)
68 | print('Random policy total reward per episode: {:.2f} +- {:.2f}'.format(
69 | np.mean(rewards), np.std(rewards)))
70 |
71 |
72 | def main(argv):
73 | del argv # Unused.
74 | run()
75 |
76 |
77 | if __name__ == '__main__':
78 | app.run(main)
79 |
--------------------------------------------------------------------------------
/realworldrl_suite/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Real-World RL Suite Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Library to do real-world RL."""
17 |
18 | import realworldrl_suite.environments
19 | import realworldrl_suite.utils
20 |
--------------------------------------------------------------------------------
/realworldrl_suite/analysis/realworldrl_notebook.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "realworldrl_notebook.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "last_runtime": {
10 | "build_target": "",
11 | "kind": "local"
12 | }
13 | },
14 | "kernelspec": {
15 | "display_name": "Python 3",
16 | "name": "python3"
17 | }
18 | },
19 | "cells": [
20 | {
21 | "cell_type": "markdown",
22 | "metadata": {
23 | "colab_type": "text",
24 | "id": "OF9qRa7D8643"
25 | },
26 | "source": [
27 | "**To use the Jupyter notebook, ensure that you have installed the following packages (in the pre-defined order):**\n",
28 | "1. Python3\n",
29 | "2. Matplotlib\n",
30 | "3. Numpy\n",
31 | "4. Scipy\n",
32 | "5. tqdm\n",
33 | "6. Mujoco (Make sure Mujoco is installed before installing our Realworld RL suite)\n",
34 | "7. The Realworld RL Suite\n",
35 | "\n",
36 | "**It is recommended to use the realworldrl_venv virtual environment that you used when installing the realworldrl_suite package. To do so, you may need to run the following commands:** \n",
37 | "\n",
38 | "```\n",
39 | "pip3 install --user ipykernel\n",
40 | "python3 -m ipykernel install --user --name=realworldrl_venv\n",
41 | "```\n",
42 | "\n",
43 | "Then in this notebook, click 'Kernel' in the menu, then click 'Change Kernel' and select `realworldrl_venv`\n",
44 | "\n",
45 | "**Note**: You may need to restart the Jupyter kernel to see the updated virtual environment in the Kernel list.\n"
46 | ]
47 | },
48 | {
49 | "cell_type": "markdown",
50 | "metadata": {
51 | "colab_type": "text",
52 | "id": "yLMLidYq8646"
53 | },
54 | "source": [
55 | "**Import the necessary libraries**"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "metadata": {
61 | "cellView": "both",
62 | "colab_type": "code",
63 | "id": "icAkv-zLplqP",
64 | "colab": {}
65 | },
66 | "source": [
67 | "#@title \n",
68 | "# Copyright 2020 Google LLC.\n",
69 | "\n",
70 | "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
71 | "# you may not use this file except in compliance with the License.\n",
72 | "# You may obtain a copy of the License at\n",
73 | "\n",
74 | "# https://www.apache.org/licenses/LICENSE-2.0\n",
75 | "\n",
76 | "# Unless required by applicable law or agreed to in writing, software\n",
77 | "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
78 | "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
79 | "# See the License for the specific language governing permissions and\n",
80 | "# limitations under the License.\n",
81 | "\n",
82 | "from __future__ import absolute_import\n",
83 | "from __future__ import division\n",
84 | "from __future__ import print_function\n",
85 | "\n",
86 | "import matplotlib.pyplot as plt\n",
87 | "import numpy as np\n",
88 | "import tqdm\n",
89 | "\n",
90 | "import collections\n",
91 | "\n",
92 | "import realworldrl_suite.environments as rwrl\n",
93 | "from realworldrl_suite.utils import evaluators"
94 | ],
95 | "execution_count": 0,
96 | "outputs": []
97 | },
98 | {
99 | "cell_type": "markdown",
100 | "metadata": {
101 | "colab_type": "text",
102 | "id": "DDcZhPBB865C"
103 | },
104 | "source": [
105 | "**Define the environment and the policy**"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "metadata": {
111 | "colab_type": "code",
112 | "id": "7gDwzTFz0Tak",
113 | "colab": {}
114 | },
115 | "source": [
116 | "total_episodes = 1000 # The analysis tools require at least 100 episodes.\n",
117 | "domain_name = 'cartpole'\n",
118 | "task_name = 'realworld_swingup'\n",
119 | "\n",
120 | "# Define the challenge dictionaries\n",
121 | "safety_spec_dict = dict(enable=True, safety_coeff=0.5)\n",
122 | "delay_spec_dict = dict(enable=True, actions=20)\n",
123 | "\n",
124 | "log_safety_violations = True\n",
125 | "\n",
126 | "def random_policy(action_spec):\n",
127 | "\n",
128 | " def _act(timestep):\n",
129 | " del timestep\n",
130 | " return np.random.uniform(low=action_spec.minimum,\n",
131 | " high=action_spec.maximum,\n",
132 | " size=action_spec.shape)\n",
133 | " return _act\n",
134 | "\n",
135 | "\n",
136 | "env = rwrl.load(\n",
137 | " domain_name=domain_name,\n",
138 | " task_name=task_name,\n",
139 | " safety_spec=safety_spec_dict,\n",
140 | " delay_spec=delay_spec_dict,\n",
141 | " log_output=os.path.join('/tmp/', 'log.npz'),\n",
142 | " environment_kwargs=dict(log_safety_vars=log_safety_violations, \n",
143 | " flat_observation=True,\n",
144 | " log_every=10))\n",
145 | "\n",
146 | "policy = random_policy(action_spec=env.action_spec())"
147 | ],
148 | "execution_count": 0,
149 | "outputs": []
150 | },
151 | {
152 | "cell_type": "markdown",
153 | "metadata": {
154 | "colab_type": "text",
155 | "id": "lj0CFq1k865I"
156 | },
157 | "source": [
158 | "**Run the main loop**"
159 | ]
160 | },
161 | {
162 | "cell_type": "code",
163 | "metadata": {
164 | "colab_type": "code",
165 | "id": "ot9xUSM8z8Ib",
166 | "colab": {}
167 | },
168 | "source": [
169 | "rewards = []\n",
170 | "episode_counter = 0\n",
171 | "for _ in tqdm.tqdm(range(total_episodes)):\n",
172 | " timestep = env.reset()\n",
173 | " total_reward = 0.\n",
174 | " while not timestep.last():\n",
175 | " action = policy(timestep)\n",
176 | " timestep = env.step(action)\n",
177 | " total_reward += timestep.reward\n",
178 | " rewards.append(total_reward)\n",
179 | " episode_counter+=1\n",
180 | "print('Random policy total reward per episode: {:.2f} +- {:.2f}'.format(\n",
181 | "np.mean(rewards), np.std(rewards)))"
182 | ],
183 | "execution_count": 0,
184 | "outputs": []
185 | },
186 | {
187 | "cell_type": "code",
188 | "metadata": {
189 | "colab_type": "code",
190 | "id": "Ne2i3ZFRsnVw",
191 | "colab": {}
192 | },
193 | "source": [
194 | "f = open(env.logs_path, \"rb\") \n",
195 | "stats = np.load(f, allow_pickle=True)\n",
196 | "evaluator = evaluators.Evaluators(stats)"
197 | ],
198 | "execution_count": 0,
199 | "outputs": []
200 | },
201 | {
202 | "cell_type": "markdown",
203 | "metadata": {
204 | "colab_type": "text",
205 | "id": "xU-jhT4W865T"
206 | },
207 | "source": [
208 | "**Load the average return plot as a function of the number of episodes**"
209 | ]
210 | },
211 | {
212 | "cell_type": "code",
213 | "metadata": {
214 | "colab_type": "code",
215 | "id": "AEHttr9JyttU",
216 | "colab": {}
217 | },
218 | "source": [
219 | "fig = evaluator.get_return_plot()\n",
220 | "plt.show()"
221 | ],
222 | "execution_count": 0,
223 | "outputs": []
224 | },
225 | {
226 | "cell_type": "markdown",
227 | "metadata": {
228 | "id": "327LSohpug29",
229 | "colab_type": "text"
230 | },
231 | "source": [
232 | "**Compute regret**"
233 | ]
234 | },
235 | {
236 | "cell_type": "code",
237 | "metadata": {
238 | "id": "nuyf6byUufCP",
239 | "colab_type": "code",
240 | "colab": {}
241 | },
242 | "source": [
243 | "fig = evaluator.get_convergence_plot()\n",
244 | "plt.show()"
245 | ],
246 | "execution_count": 0,
247 | "outputs": []
248 | },
249 | {
250 | "cell_type": "markdown",
251 | "metadata": {
252 | "id": "Xk8vY0b6u2um",
253 | "colab_type": "text"
254 | },
255 | "source": [
256 | "**Compute instability**"
257 | ]
258 | },
259 | {
260 | "cell_type": "code",
261 | "metadata": {
262 | "id": "QoqFSJDZu6Eo",
263 | "colab_type": "code",
264 | "colab": {}
265 | },
266 | "source": [
267 | "fig = evaluator.get_stability_plot()\n",
268 | "plt.show()"
269 | ],
270 | "execution_count": 0,
271 | "outputs": []
272 | },
273 | {
274 | "cell_type": "markdown",
275 | "metadata": {
276 | "colab_type": "text",
277 | "id": "RkPTOgzA865Z"
278 | },
279 | "source": [
280 | "**Safety violations plot (left figure) and the mean evolution of safety constraint violations during an episode (right figure)**"
281 | ]
282 | },
283 | {
284 | "cell_type": "code",
285 | "metadata": {
286 | "colab_type": "code",
287 | "id": "i7BRSJsWzJEB",
288 | "colab": {}
289 | },
290 | "source": [
291 | "fig = evaluator.get_safety_plot()\n",
292 | "plt.show()"
293 | ],
294 | "execution_count": 0,
295 | "outputs": []
296 | },
297 | {
298 | "cell_type": "markdown",
299 | "metadata": {
300 | "id": "TSxbfXDOyV2q",
301 | "colab_type": "text"
302 | },
303 | "source": [
304 | "**Multiple training seeds can be aggregated.**"
305 | ]
306 | },
307 | {
308 | "cell_type": "code",
309 | "metadata": {
310 | "id": "dmi6k1vyxsLn",
311 | "colab_type": "code",
312 | "colab": {}
313 | },
314 | "source": [
315 | "# We emulate multiple runs by copying the same logs with added noise.\n",
316 | "\n",
317 | "all_evaluators = []\n",
318 | "for _ in range(10):\n",
319 | " another_evaluator = evaluators.Evaluators(stats)\n",
320 | " v = another_evaluator.stats['return_stats']['episode_totals']\n",
321 | " v += np.random.randn(*v.shape) * 100.\n",
322 | " all_evaluators.append(another_evaluator)"
323 | ],
324 | "execution_count": 0,
325 | "outputs": []
326 | },
327 | {
328 | "cell_type": "markdown",
329 | "metadata": {
330 | "id": "36t1n8k17oIp",
331 | "colab_type": "text"
332 | },
333 | "source": [
334 | "**Normalized regret across all runs.**"
335 | ]
336 | },
337 | {
338 | "cell_type": "code",
339 | "metadata": {
340 | "id": "J2XujjsfykR5",
341 | "colab_type": "code",
342 | "colab": {}
343 | },
344 | "source": [
345 | "evaluators.get_regret_plot(all_evaluators)\n",
346 | "plt.show()"
347 | ],
348 | "execution_count": 0,
349 | "outputs": []
350 | },
351 | {
352 | "cell_type": "markdown",
353 | "metadata": {
354 | "id": "BNgepDFv7nna",
355 | "colab_type": "text"
356 | },
357 | "source": [
358 | "**Return across all runs.**"
359 | ]
360 | },
361 | {
362 | "cell_type": "code",
363 | "metadata": {
364 | "id": "khHpm_Gy5nOP",
365 | "colab_type": "code",
366 | "colab": {}
367 | },
368 | "source": [
369 | "evaluators.get_return_plot(all_evaluators, stride=500)\n",
370 | "plt.show()"
371 | ],
372 | "execution_count": 0,
373 | "outputs": []
374 | },
375 | {
376 | "cell_type": "markdown",
377 | "metadata": {
378 | "id": "HY3NkWxNRceA",
379 | "colab_type": "text"
380 | },
381 | "source": [
382 | "**Additional useful functions**\n",
383 | "\n",
384 | "Multi-objective runs can be analyzed using:\n",
385 | "\n",
386 | "```\n",
387 | "evaluator.get_multiobjective_plot() # For a single run.\n",
388 | "evaluators.get_multiobjective_plot(all_evaluators) # For multiple runs.\n",
389 | "```"
390 | ]
391 | }
392 | ]
393 | }
394 |
--------------------------------------------------------------------------------
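The notebook above reports everything through plots; as a minimal non-notebook sketch (illustrative only, not a file from this repository), the same logged stats can be loaded and summarized numerically. The stats path is hypothetical; the key names ('return_stats', 'episode_totals') and the `Evaluators` constructor follow the notebook cells above.

```python
# Illustrative sketch: inspect a logged stats file outside the notebook.
import numpy as np
from realworldrl_suite.utils import evaluators

with open('/tmp/logs.npz', 'rb') as f:  # hypothetical log path
    stats = np.load(f, allow_pickle=True)
    evaluator = evaluators.Evaluators(stats)
    # Per-episode returns, as used by the notebook's aggregation cell.
    returns = evaluator.stats['return_stats']['episode_totals']
print('mean return over logged episodes: {:.2f}'.format(returns.mean()))
```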
/realworldrl_suite/environments/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Real-World RL Suite Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Safety environments for real-world RL."""
17 | from realworldrl_suite.environments import cartpole
18 | from realworldrl_suite.environments import humanoid
19 | from realworldrl_suite.environments import manipulator
20 | from realworldrl_suite.environments import quadruped
21 | from realworldrl_suite.environments import walker
22 |
23 | # This is a tuple of all the domains and tasks present in the suite. It is
24 | # currently used mainly for unit test coverage but can be useful if one wants
25 | # to sweep over all the tasks.
26 | ALL_TASKS = (('cartpole_balance', 'cartpole', 'realworld_balance'),
27 | ('cartpole_swingup', 'cartpole', 'realworld_swingup'),
28 | ('humanoid_stand', 'humanoid', 'realworld_stand'),
29 | ('humanoid_walk', 'humanoid', 'realworld_walk'),
30 | ('manipulator_bring_ball', 'manipulator', 'realworld_bring_ball'),
31 | ('manipulator_bring_peg', 'manipulator', 'realworld_bring_peg'),
32 | ('manipulator_insert_ball', 'manipulator',
33 | 'realworld_insert_ball'),
34 | ('manipulator_insert_peg', 'manipulator', 'realworld_insert_peg'),
35 | ('quadruped_walk', 'quadruped', 'realworld_walk'),
36 | ('quadruped_run', 'quadruped', 'realworld_run'),
37 | ('walker_stand', 'walker', 'realworld_stand'),
38 | ('walker_walk', 'walker', 'realworld_walk'))
39 |
40 | DOMAINS = dict(
41 | cartpole=cartpole, humanoid=humanoid, manipulator=manipulator,
42 | quadruped=quadruped, walker=walker)
43 |
44 |
45 | def load(domain_name, task_name, **kwargs):
46 | return DOMAINS[domain_name].load(task_name, **kwargs)
47 |
--------------------------------------------------------------------------------
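As the comment in `__init__.py` above notes, `ALL_TASKS` can be used to sweep over every task in the suite via `load(domain_name, task_name, **kwargs)`. A minimal sketch of such a sweep (illustrative only, not part of the repository; the `safety_spec` value is just one example of the spec dictionaries documented in the environment files below):

```python
# Illustrative sketch: sweep every (name, domain, task) triple in the suite.
import realworldrl_suite.environments as rwrl_envs

for name, domain, task in rwrl_envs.ALL_TASKS:
    env = rwrl_envs.load(domain, task, safety_spec={'enable': True})
    timestep = env.reset()
    print(name, env.action_spec().shape, list(timestep.observation))
```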
/realworldrl_suite/environments/cartpole.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Real-World RL Suite Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Real-world control of cartpole."""
17 |
18 | import collections
19 |
20 | import dm_control.suite.cartpole as cartpole
21 | import dm_control.suite.common as common
22 | from lxml import etree
23 | import numpy as np
24 |
25 | from realworldrl_suite.environments import realworld_env
26 | from realworldrl_suite.utils import loggers
27 | from realworldrl_suite.utils import wrappers
28 |
29 | _DEFAULT_TIME_LIMIT = 10
30 |
31 | PERTURB_PARAMS = ['pole_length', 'pole_mass', 'joint_damping', 'slider_damping']
32 |
33 |
34 | def load(task_name, **task_kwargs):
35 | return globals()[task_name](**task_kwargs)
36 |
37 |
38 | # Task Constraints
39 | def slider_pos_constraint(env, safety_vars):
40 | """Slider must be within a certain area of the track."""
41 | slider_pos = safety_vars['slider_pos']
42 | return (np.greater(slider_pos, env.limits['slider_pos_constraint'][0]) and
43 | np.less(slider_pos, env.limits['slider_pos_constraint'][1]))
44 |
45 |
46 | def balance_velocity_constraint(env, safety_vars):
47 | """Joint angle velocity must be low when close to the goal."""
48 | joint_angle_cos = safety_vars['joint_angle_cos']
49 | joint_vel = safety_vars['joint_vel']
50 | # When the angle is close to zero (pole upright) and the velocity exceeds the
51 | # limit, the constraint is violated. In cosine-space, the cosine of the angle
52 | # must be greater than a threshold for the angle to count as close to zero.
53 | return not (
54 | np.greater(joint_angle_cos,
55 | np.cos(env.limits['balance_velocity_constraint'][0])) and
56 | np.greater(joint_vel, env.limits['balance_velocity_constraint'][1])[0])
57 |
58 |
59 | def slider_accel_constraint(env, safety_vars):
60 | """Slider acceleration should never go above threshold."""
61 | slider_accel = safety_vars['slider_accel']
62 | return np.less(slider_accel, env.limits['slider_accel_constraint'])[0]
63 |
64 |
65 | # Action rate of change constraint.
66 | action_roc_constraint = realworld_env.action_roc_constraint
67 |
68 |
69 | def realworld_balance(time_limit=_DEFAULT_TIME_LIMIT,
70 | random=None,
71 | log_output=None,
72 | environment_kwargs=None,
73 | safety_spec=None,
74 | delay_spec=None,
75 | noise_spec=None,
76 | perturb_spec=None,
77 | dimensionality_spec=None,
78 | multiobj_spec=None,
79 | combined_challenge=None):
80 | """Returns the Cartpole Balance task with specified real world attributes.
81 |
82 | Args:
83 | time_limit: Integer length of task
84 | random: Optional, an integer seed or a numpy.random.RandomState instance.
85 | log_output: String of path for pickle data logging, None disables logging.
86 | environment_kwargs: additional kwargs for environment.
87 | safety_spec: dictionary that specifies the safety specifications.
88 | delay_spec: dictionary that specifies the delay specifications.
89 | noise_spec: dictionary that specifies the noise specifications.
90 | perturb_spec: dictionary that specifies the perturbations specifications.
91 | dimensionality_spec: dictionary that specifies extra observation features.
92 | multiobj_spec: dictionary that specifies complementary objectives.
93 | combined_challenge: string that can be 'easy', 'medium', or 'hard'.
94 | Specifying the combined challenge (can't be used with any other spec).
95 | """
96 | physics = Physics.from_xml_string(*cartpole.get_model_and_assets())
97 | safety_spec = safety_spec or {}
98 | delay_spec = delay_spec or {}
99 | noise_spec = noise_spec or {}
100 | perturb_spec = perturb_spec or {}
101 | dimensionality_spec = dimensionality_spec or {}
102 | multiobj_spec = multiobj_spec or {}
103 | # Check and update for combined challenge.
104 | (delay_spec, noise_spec,
105 | perturb_spec, dimensionality_spec) = (
106 | realworld_env.get_combined_challenge(
107 | combined_challenge, delay_spec, noise_spec, perturb_spec,
108 | dimensionality_spec))
109 | # Updating perturbation parameters if combined_challenge.
110 | if combined_challenge == 'easy':
111 | perturb_spec.update(
112 | {'param': 'pole_length', 'min': 0.9, 'max': 1.1, 'std': 0.02})
113 | elif combined_challenge == 'medium':
114 | perturb_spec.update(
115 | {'param': 'pole_length', 'min': 0.7, 'max': 1.7, 'std': 0.1})
116 | elif combined_challenge == 'hard':
117 | perturb_spec.update(
118 | {'param': 'pole_length', 'min': 0.5, 'max': 2.3, 'std': 0.15})
119 |
120 | task = RealWorldBalance(
121 | swing_up=False,
122 | sparse=False,
123 | random=random,
124 | safety_spec=safety_spec,
125 | delay_spec=delay_spec,
126 | noise_spec=noise_spec,
127 | perturb_spec=perturb_spec,
128 | dimensionality_spec=dimensionality_spec,
129 | multiobj_spec=multiobj_spec
130 | )
131 | environment_kwargs = environment_kwargs or {}
132 | if log_output:
133 | logger = loggers.PickleLogger(path=log_output)
134 | else:
135 | logger = None
136 | return wrappers.LoggingEnv(
137 | physics, task, logger=logger, time_limit=time_limit, **environment_kwargs)
138 |
139 |
140 | def realworld_swingup(time_limit=_DEFAULT_TIME_LIMIT,
141 | random=None,
142 | log_output=None,
143 | environment_kwargs=None,
144 | safety_spec=None,
145 | delay_spec=None,
146 | noise_spec=None,
147 | perturb_spec=None,
148 | dimensionality_spec=None,
149 | multiobj_spec=None,
150 | combined_challenge=None):
151 | """Returns the Cartpole Swing-Up task with specified real world attributes.
152 |
153 | Args:
154 | time_limit: Integer length of task
155 | random: Optional, an integer seed or a numpy.random.RandomState instance.
156 | log_output: String of path for pickle data logging, None disables logging
157 | environment_kwargs: additional kwargs for environment.
158 | safety_spec: dictionary that specifies the safety specifications.
159 | delay_spec: dictionary that specifies the delay specifications.
160 | noise_spec: dictionary that specifies the noise specifications.
161 | perturb_spec: dictionary that specifies the perturbations specifications.
162 | dimensionality_spec: dictionary that specifies extra observation features.
163 | multiobj_spec: dictionary that specifies complementary objectives.
164 | combined_challenge: string that can be 'easy', 'medium', or 'hard'.
165 | Specifying the combined challenge (can't be used with any other spec).
166 | """
167 | physics = Physics.from_xml_string(*cartpole.get_model_and_assets())
168 | safety_spec = safety_spec or {}
169 | delay_spec = delay_spec or {}
170 | noise_spec = noise_spec or {}
171 | perturb_spec = perturb_spec or {}
172 | dimensionality_spec = dimensionality_spec or {}
173 | multiobj_spec = multiobj_spec or {}
174 | # Check and update for combined challenge.
175 | (delay_spec, noise_spec,
176 | perturb_spec, dimensionality_spec) = (
177 | realworld_env.get_combined_challenge(
178 | combined_challenge, delay_spec, noise_spec, perturb_spec,
179 | dimensionality_spec))
180 | # Updating perturbation parameters if combined_challenge.
181 | if combined_challenge == 'easy':
182 | perturb_spec.update(
183 | {'param': 'pole_length', 'min': 0.9, 'max': 1.1, 'std': 0.02})
184 | elif combined_challenge == 'medium':
185 | perturb_spec.update(
186 | {'param': 'pole_length', 'min': 0.7, 'max': 1.7, 'std': 0.1})
187 | elif combined_challenge == 'hard':
188 | perturb_spec.update(
189 | {'param': 'pole_length', 'min': 0.5, 'max': 2.3, 'std': 0.15})
190 |
191 | if 'limits' not in safety_spec:
192 | if 'safety_coeff' in safety_spec:
193 | if safety_spec['safety_coeff'] < 0 or safety_spec['safety_coeff'] > 1:
194 | raise ValueError('safety_coeff should be in [0,1], but got {}'.format(
195 | safety_spec['safety_coeff']))
196 | safety_coeff = safety_spec['safety_coeff']
197 | else:
198 | safety_coeff = 1
199 | safety_spec['limits'] = {
200 | 'slider_pos_constraint':
201 | safety_coeff * np.array([-2, 2]), # m
202 | 'balance_velocity_constraint':
203 | np.array([(1 - safety_coeff) / 0.5 + 0.15,
204 | safety_coeff * 0.5]), # rad, rad/s
205 | 'slider_accel_constraint':
206 | safety_coeff * 130, # m/s^2
207 | 'action_roc_constraint': safety_coeff * 1.5,
208 | }
209 |
210 | task = RealWorldBalance(
211 | swing_up=True,
212 | sparse=False,
213 | random=random,
214 | safety_spec=safety_spec,
215 | delay_spec=delay_spec,
216 | noise_spec=noise_spec,
217 | perturb_spec=perturb_spec,
218 | dimensionality_spec=dimensionality_spec,
219 | multiobj_spec=multiobj_spec
220 | )
221 | environment_kwargs = environment_kwargs or {}
222 | if log_output:
223 | logger = loggers.PickleLogger(path=log_output)
224 | else:
225 | logger = None
226 | return wrappers.LoggingEnv(
227 | physics, task, logger=logger, time_limit=time_limit, **environment_kwargs)
228 |
229 |
230 | class Physics(cartpole.Physics):
231 | """Inherits from cartpole.Physics."""
232 |
233 |
234 | class RealWorldBalance(realworld_env.Base, cartpole.Balance):
235 | """A Cartpole task with real-world specifications.
236 |
237 | Subclasses dm_control.suite.cartpole.
238 |
239 | Safety:
240 | Adds a set of constraints on the task.
241 | Returns an additional entry in the observations ('constraints') in the
242 | length of the number of the constraints, where each entry is True if the
243 | constraint is satisfied and False otherwise.
244 |
245 | Delays:
246 | Adds actions, observations, and rewards delays.
247 | Action delay is the number of steps between passing an action to the
248 | environment and when it is actually performed; observation (reward)
249 | delay is the number of steps by which the returned observation (reward)
250 | is stale after performing a step.
251 |
252 | Noise:
253 | Adds action or observation noise.
254 | Noise types include: white Gaussian actions/observations,
255 | dropped actions/observations values, stuck actions/observations values,
256 | or repetitive actions.
257 |
258 | Perturbations:
259 | Perturbs physical quantities of the environment. These perturbations are
260 | non-stationary and are governed by a scheduler.
261 |
262 | Dimensionality:
263 | Adds extra dummy features to observations to increase dimensionality of the
264 | state space.
265 |
266 | Multi-Objective Reward:
267 | Adds additional objectives and specifies how they interact (e.g., sum).
268 | """
269 |
270 | def __init__(self, safety_spec, delay_spec, noise_spec, perturb_spec,
271 | dimensionality_spec, multiobj_spec, **kwargs):
272 | """Initialize the RealWorldBalance task.
273 |
274 | Args:
275 | safety_spec: dictionary that specifies the safety specifications of the
276 | task. It may contain the following fields:
277 | enable- bool that represents whether safety specifications are enabled.
278 | constraints- list of class methods returning boolean constraint
279 | satisfactions.
280 | limits- dictionary of constants used by the functions in 'constraints'.
281 | safety_coeff - a scalar between 0 and 1 that scales safety constraints,
282 | 1 producing the base constraints, and 0 likely producing an
283 | unsolvable task.
284 | observations- a default-True boolean that toggles whether a vector
285 | of satisfied constraints is added to observations.
286 | delay_spec: dictionary that specifies the delay specifications of the
287 | task. It may contain the following fields:
288 | enable- bool that represents whether delay specifications are enabled.
289 | actions- integer indicating the number of steps actions are being
290 | delayed.
291 | observations- integer indicating the number of steps observations are
292 | being delayed.
293 | rewards- integer indicating the number of steps rewards are being
294 | delayed.
295 | noise_spec: dictionary that specifies the noise specifications of the
296 | task. It may contain the following fields:
297 | gaussian- dictionary that specifies the white Gaussian additive noise.
298 | It may contain the following fields:
299 | enable- bool that represents whether noise specifications are enabled.
300 | actions- float indicating the standard deviation of a white Gaussian
301 | noise added to each action.
302 | observations- similarly, additive white Gaussian noise to each
303 | returned observation.
304 | dropped- dictionary that specifies the dropped values noise.
305 | It may contain the following fields:
306 | enable- bool that represents whether dropped values specifications are
307 | enabled.
308 | observations_prob- float in [0,1] indicating the probability of
309 | dropping each observation component independently.
310 | observations_steps- positive integer indicating the number of time
311 | steps of dropping a value (setting to zero) if dropped.
312 | actions_prob- float in [0,1] indicating the probability of dropping
313 | each action component independently.
314 | actions_steps- positive integer indicating the number of time steps of
315 | dropping a value (setting to zero) if dropped.
316 | stuck- dictionary that specifies the stuck values noise.
317 | It may contain the following fields:
318 | enable- bool that represents whether stuck values specifications are
319 | enabled.
320 | observations_prob- float in [0,1] indicating the probability of each
321 | observation component becoming stuck.
322 | observations_steps- positive integer indicating the number of time
323 | steps an observation (or components of) stays stuck.
324 | actions_prob- float in [0,1] indicating the probability of each
325 | action component becoming stuck.
326 | actions_steps- positive integer indicating the number of time
327 | steps an action (or components of) stays stuck.
328 | repetition- dictionary that specifies the repetition statistics.
329 | It may contain the following fields:
330 | enable- bool that represents whether repetition specifications are
331 | enabled.
332 | actions_prob- float in [0,1] indicating the probability of the actions
333 | to be repeated in the following steps.
334 | actions_steps- positive integer indicating the number of time steps of
335 | repeating the same action if it is to be repeated.
336 | perturb_spec: dictionary that specifies the perturbation specifications
337 | of the task. It may contain the following fields:
338 | enable- bool that represents whether perturbation specifications are
339 | enabled.
340 | period- int, number of episodes between perturbation updates.
341 | param - string indicating which parameter to perturb (currently
342 | supporting pole_length, pole_mass, joint_damping, slider_damping).
343 | scheduler- string indicating the scheduler to apply to the perturbed
344 | parameter (currently supporting constant, random_walk, drift_pos,
345 | drift_neg, cyclic_pos, cyclic_neg, uniform, and saw_wave).
346 | start - float indicating the initial value of the perturbed parameter.
347 | min - float indicating the minimal value the perturbed parameter may be.
348 | max - float indicating the maximal value the perturbed parameter may be.
349 | std - float indicating the standard deviation of the white noise for the
350 | scheduling process.
351 | dimensionality_spec: dictionary that specifies the added dimensions to the
352 | state space. It may contain the following fields:
353 | enable- bool that represents whether dimensionality specifications are
354 | enabled.
355 | num_random_state_observations - number of random (unit Gaussian)
356 | observations to add.
357 | multiobj_spec: dictionary that sets up the multi-objective challenge.
358 | The challenge works by providing an `Objective` object which describes
359 | both numerical objectives and a reward-merging method that allows one to both
360 | observe the various objectives in the observation and affect the
361 | returned reward in a manner defined by the Objective object.
362 | enable- bool that represents whether multi-objective
363 | specifications are enabled.
364 | objective - either a string which will load an `Objective` class from
365 | utils.multiobj_objectives.OBJECTIVES, or an Objective object which
366 | subclasses utils.multiobj_objectives.Objective.
367 | reward - boolean indicating whether to add the multiobj objective's
368 | reward to the environment's returned reward.
369 | coeff - a number in [0,1] that is passed into the Objective object to
370 | change the mix between the original reward and the Objective's
371 | rewards.
372 | observed - boolean indicating whether the defined objectives should be
373 | added to the observation.
374 | **kwargs: extra parameters passed to parent class (cartpole.Balance)
375 | """
376 | # Initialize parent classes.
377 | realworld_env.Base.__init__(self)
378 | cartpole.Balance.__init__(self, **kwargs)
379 |
380 | # Safety setup.
381 | self._setup_safety(safety_spec)
382 |
383 | # Delay setup.
384 | realworld_env.Base._setup_delay(self, delay_spec)
385 |
386 | # Noise setup.
387 | realworld_env.Base._setup_noise(self, noise_spec)
388 |
389 | # Perturb setup.
390 | self._setup_perturb(perturb_spec)
391 |
392 | # Dimensionality setup
393 | realworld_env.Base._setup_dimensionality(self, dimensionality_spec)
394 |
395 | # Multi-objective setup
396 | realworld_env.Base._setup_multiobj(self, multiobj_spec)
397 |
398 | # Safety methods.
399 | def _setup_safety(self, safety_spec):
400 | """Setup for the safety specifications of the task."""
401 | self._safety_enabled = safety_spec.get('enable', False)
402 | self._safety_observed = safety_spec.get('observations', True)
403 |
404 | if self._safety_enabled:
405 | # Add safety specifications.
406 | if 'constraints' in safety_spec:
407 | self.constraints = safety_spec['constraints']
408 | else:
409 | self.constraints = collections.OrderedDict([
410 | ('slider_pos_constraint', slider_pos_constraint),
411 | ('slider_accel_constraint', slider_accel_constraint),
412 | ('balance_velocity_constraint', balance_velocity_constraint)
413 | ])
414 | if 'limits' in safety_spec:
415 | self.limits = safety_spec['limits']
416 | else:
417 | if 'safety_coeff' in safety_spec:
418 | if safety_spec['safety_coeff'] < 0 or safety_spec['safety_coeff'] > 1:
419 | raise ValueError(
420 | 'safety_coeff should be in [0,1], but got {}'.format(
421 | safety_spec['safety_coeff']))
422 | safety_coeff = safety_spec['safety_coeff']
423 | else:
424 | safety_coeff = 1
425 | self.limits = {
426 | 'slider_pos_constraint':
427 | safety_coeff * np.array([-1.5, 1.5]), # m
428 | 'balance_velocity_constraint':
429 | np.array([(1 - safety_coeff) / 0.5 + 0.15,
430 | safety_coeff * 0.5]), # rad, rad/s
431 | 'slider_accel_constraint':
432 | safety_coeff * 10, # m/s^2
433 | 'action_roc_constraint': safety_coeff * 1.5
434 | }
435 | self._constraints_obs = np.ones(len(self.constraints), dtype=bool)
436 |
437 | def safety_vars(self, physics):
438 | safety_vars = collections.OrderedDict(
439 | slider_pos=physics.cart_position().copy(),
440 | joint_angle_cos=physics.pole_angle_cosine().copy(),
441 | joint_vel=np.abs(physics.angular_vel().copy()),
442 | slider_accel=np.abs(physics.named.data.qacc['slider'].copy()),
443 | actions=physics.control(),)
444 | return safety_vars
445 |
446 | def _setup_perturb(self, perturb_spec):
447 | """Setup for the perturbations specification of the task."""
448 | self._perturb_enabled = perturb_spec.get('enable', False)
449 | self._perturb_period = perturb_spec.get('period', 1)
450 |
451 | if self._perturb_enabled:
452 | # Add perturbations specifications.
453 | self._perturb_param = perturb_spec.get('param', 'pole_length')
454 | # Making sure object to perturb is supported.
455 | if self._perturb_param not in PERTURB_PARAMS:
456 | raise ValueError("""param was: {}. Currently only supporting {}.
457 | """.format(self._perturb_param, PERTURB_PARAMS))
458 |
459 | # Setting perturbation function.
460 | self._perturb_scheduler = perturb_spec.get('scheduler', 'constant')
461 | if self._perturb_scheduler not in realworld_env.PERTURB_SCHEDULERS:
462 | raise ValueError("""scheduler was: {}. Currently only supporting {}.
463 | """.format(self._perturb_scheduler, realworld_env.PERTURB_SCHEDULERS))
464 |
465 | # Setting perturbation process parameters.
466 | if self._perturb_param == 'pole_length':
467 | self._perturb_cur = perturb_spec.get('start', 1.)
468 | self._perturb_start = perturb_spec.get('start', 1.)
469 | self._perturb_min = perturb_spec.get('min', 0.3)
470 | self._perturb_max = perturb_spec.get('max', 3.)
471 | self._perturb_std = perturb_spec.get('std', 0.3)
472 | elif self._perturb_param == 'pole_mass':
473 | self._perturb_cur = perturb_spec.get('start', 0.1)
474 | self._perturb_start = perturb_spec.get('start', 0.1)
475 | self._perturb_min = perturb_spec.get('min', 0.1)
476 | self._perturb_max = perturb_spec.get('max', 10.)
477 | self._perturb_std = perturb_spec.get('std', 0.5)
478 | elif self._perturb_param == 'joint_damping':
479 | self._perturb_cur = perturb_spec.get('start', 2e-6)
480 | self._perturb_start = perturb_spec.get('start', 2e-6)
481 | self._perturb_min = perturb_spec.get('min', 2e-6)
482 | self._perturb_max = perturb_spec.get('max', 2e-1)
483 | self._perturb_std = perturb_spec.get('std', 2e-2)
484 | elif self._perturb_param == 'slider_damping':
485 | self._perturb_cur = perturb_spec.get('start', 5e-4)
486 | self._perturb_start = perturb_spec.get('start', 5e-4)
487 | self._perturb_min = perturb_spec.get('min', 5e-4)
488 | self._perturb_max = perturb_spec.get('max', 3.0)
489 | self._perturb_std = perturb_spec.get('std', 0.3)
490 |
491 | def update_physics(self):
492 | """Returns a new Physics object with perturbed parameter."""
493 | # Generate the new perturbed parameter.
494 | realworld_env.Base._generate_parameter(self)
495 |
496 | # Create a new physics object with the perturbed parameter.
497 | xml_string = common.read_model('cartpole.xml')
498 | mjcf = etree.fromstring(xml_string)
499 |
500 | if self._perturb_param in ['pole_length', 'pole_mass']:
501 | pole = mjcf.find('./default/default/geom')
502 | if self._perturb_param == 'pole_length':
503 | pole.set('fromto', '0 0 0 0 0 {}'.format(self._perturb_cur))
504 | pole.set('mass', str(self._perturb_cur / 10.))
505 | elif self._perturb_param == 'pole_mass':
506 | pole.set('mass', str(self._perturb_cur))
507 | elif self._perturb_param == 'joint_damping':
508 | pole_joint = mjcf.find('./default/default/joint')
509 | pole_joint.set('damping', str(self._perturb_cur))
510 | elif self._perturb_param == 'slider_damping':
511 | sliders_joint = mjcf.find('./worldbody/body/joint')
512 | sliders_joint.set('damping', str(self._perturb_cur))
513 |
514 | xml_string_modified = etree.tostring(mjcf, pretty_print=True)
515 | physics = Physics.from_xml_string(xml_string_modified, common.ASSETS)
516 | return physics
517 |
518 | def before_step(self, action, physics):
519 | """Updates the environment using the action and returns a `TimeStep`."""
520 | self._last_action = physics.control()
521 | action_min = self.action_spec(physics).minimum[:]
522 | action_max = self.action_spec(physics).maximum[:]
523 | action = realworld_env.Base.before_step(self, action, action_min,
524 | action_max)
525 | cartpole.Balance.before_step(self, action, physics)
526 |
527 | def after_step(self, physics):
528 | realworld_env.Base.after_step(self, physics)
529 | cartpole.Balance.after_step(self, physics)
530 | self._last_action = None
531 |
--------------------------------------------------------------------------------
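The spec dictionaries documented in `RealWorldBalance.__init__` above compose freely when loading a task. A hedged configuration sketch (illustrative only; the numeric values are examples, not defaults from the file):

```python
# Illustrative sketch: cartpole swing-up with safety, noise and perturbations.
# Field names follow the safety/noise/perturb spec docstrings above; the
# numeric values are made up for the example.
import realworldrl_suite.environments as rwrl_envs

env = rwrl_envs.load(
    'cartpole', 'realworld_swingup',
    safety_spec={'enable': True, 'safety_coeff': 0.5},
    noise_spec={'gaussian': {'enable': True, 'actions': 0.1, 'observations': 0.1}},
    perturb_spec={'enable': True, 'period': 5, 'param': 'pole_length',
                  'scheduler': 'uniform', 'min': 0.7, 'max': 1.7, 'std': 0.1},
)
```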
/realworldrl_suite/environments/humanoid.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Real-World RL Suite Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Real-world control of humanoid."""
17 |
18 | import collections
19 |
20 | from dm_control.suite import common
21 | from dm_control.suite import humanoid
22 | from lxml import etree
23 | import numpy as np
24 |
25 | from realworldrl_suite.environments import realworld_env
26 | from realworldrl_suite.utils import loggers
27 | from realworldrl_suite.utils import wrappers
28 |
29 | _DEFAULT_TIME_LIMIT = 25
30 | _CONTROL_TIMESTEP = .025
31 |
32 | # Height of head above which stand reward is 1.
33 | _STAND_HEIGHT = 1.4
34 |
35 | # Horizontal speeds above which move reward is 1.
36 | _WALK_SPEED = 1
37 | _RUN_SPEED = 10
38 |
39 | PERTURB_PARAMS = ['joint_damping', 'contact_friction', 'head_size']
40 |
41 |
42 | def load(task_name, **task_kwargs):
43 | return globals()[task_name](**task_kwargs)
44 |
45 |
46 | # Task Constraints
47 | def joint_angle_constraint(env, safety_vars):
48 | """Slider must be within a certain area of the track."""
49 | joint_pos = safety_vars['joint_pos']
50 | return np.less(np.max(np.abs(joint_pos)), env.limits['joint_pos_constraint'])
51 |
52 |
53 | def joint_velocity_constraint(env, safety_vars):
54 | """Joint angle velocities must stay below a certain limit."""
55 | joint_vels = safety_vars['joint_vels']
56 | return np.less(np.max(joint_vels), env.limits['joint_velocity_constraint'])
57 |
58 |
59 | def upright_constraint(env, safety_vars):
60 | """Pelvis orientantion should remain upright."""
61 | z_up = safety_vars['z_up']
62 | return np.greater(z_up, env.limits['torso_upright_constraint'])
63 |
64 |
65 | def foot_force_constraint(env, safety_vars):
66 | """Limit foot contact forces, uses ankle forces as proxy."""
67 | force_max = np.max(
68 | [safety_vars[idx] for idx in ['rl_f', 'rr_f', 'll_f', 'lr_f']])
69 | return np.less(force_max, env.limits['foot_force_constraint'])
70 |
71 |
72 | def dangerous_fall_constraint(env, safety_vars):
73 | """Discourage dangerous falls (head & torso contact)."""
74 | force_max = np.max(
75 | [safety_vars[idx] for idx in ['head_touch', 'torso_touch']])
76 | return np.less(force_max, env.limits['dangerous_fall_constraint'])
77 |
78 |
79 | # Action rate of change constraint.
80 | action_roc_constraint = realworld_env.action_roc_constraint
81 |
82 |
83 | def realworld_stand(time_limit=_DEFAULT_TIME_LIMIT,
84 | random=None,
85 | log_output=None,
86 | environment_kwargs=None,
87 | safety_spec=None,
88 | delay_spec=None,
89 | noise_spec=None,
90 | perturb_spec=None,
91 | dimensionality_spec=None,
92 | multiobj_spec=None,
93 | combined_challenge=None):
94 | """Returns the Humanoid Stand task with specified real world attributes.
95 |
96 | Args:
97 | time_limit: Integer length of task
98 | random: Optional, an integer seed or a numpy.random.RandomState instance.
99 | log_output: String of path for pickle data logging, None disables logging
100 | environment_kwargs: additional kwargs for environment.
101 | safety_spec: dictionary that specifies the safety specifications.
102 | delay_spec: dictionary that specifies the delay specifications.
103 | noise_spec: dictionary that specifies the noise specifications.
104 | perturb_spec: dictionary that specifies the perturbations specifications.
105 | dimensionality_spec: dictionary that specifies extra observation features.
106 | multiobj_spec: dictionary that specifies complementary objectives.
107 | combined_challenge: string that can be 'easy', 'medium', or 'hard'.
108 | Specifying the combined challenge (can't be used with any other spec).
109 | """
110 | physics = humanoid.Physics.from_xml_string(*humanoid.get_model_and_assets())
111 | safety_spec = safety_spec or {}
112 | delay_spec = delay_spec or {}
113 | noise_spec = noise_spec or {}
114 | perturb_spec = perturb_spec or {}
115 | dimensionality_spec = dimensionality_spec or {}
116 | multiobj_spec = multiobj_spec or {}
117 | # Check and update for combined challenge.
118 | (delay_spec, noise_spec,
119 | perturb_spec, dimensionality_spec) = (
120 | realworld_env.get_combined_challenge(
121 | combined_challenge, delay_spec, noise_spec, perturb_spec,
122 | dimensionality_spec))
123 | # Updating perturbation parameters if combined_challenge.
124 | if combined_challenge == 'easy':
125 | perturb_spec.update(
126 | {'param': 'contact_friction', 'min': 0.6, 'max': 0.8, 'std': 0.02})
127 | elif combined_challenge == 'medium':
128 | perturb_spec.update(
129 | {'param': 'contact_friction', 'min': 0.5, 'max': 0.9, 'std': 0.04})
130 | elif combined_challenge == 'hard':
131 | perturb_spec.update(
132 | {'param': 'contact_friction', 'min': 0.4, 'max': 1.0, 'std': 0.06})
133 |
134 | task = RealWorldHumanoid(
135 | move_speed=0,
136 | pure_state=False,
137 | random=random,
138 | safety_spec=safety_spec,
139 | delay_spec=delay_spec,
140 | noise_spec=noise_spec,
141 | perturb_spec=perturb_spec,
142 | dimensionality_spec=dimensionality_spec,
143 | multiobj_spec=multiobj_spec)
144 | environment_kwargs = environment_kwargs or {}
145 | if log_output:
146 | logger = loggers.PickleLogger(path=log_output)
147 | else:
148 | logger = None
149 | return wrappers.LoggingEnv(
150 | physics,
151 | task,
152 | logger=logger,
153 | time_limit=time_limit,
154 | control_timestep=_CONTROL_TIMESTEP,
155 | **environment_kwargs)
156 |
157 |
158 | def realworld_walk(time_limit=_DEFAULT_TIME_LIMIT,
159 | random=None,
160 | log_output=None,
161 | environment_kwargs=None,
162 | safety_spec=None,
163 | delay_spec=None,
164 | noise_spec=None,
165 | perturb_spec=None,
166 | dimensionality_spec=None,
167 | multiobj_spec=None,
168 | combined_challenge=None):
169 | """Returns the Walk task with specified real world attributes.
170 |
171 | Args:
172 | time_limit: Integer length of task
173 | random: Optional, an integer seed or a numpy.random.RandomState instance.
174 | log_output: String of path for pickle data logging, None disables logging
175 | environment_kwargs: additional kwargs for environment.
176 | safety_spec: dictionary that specifies the safety specifications.
177 | delay_spec: dictionary that specifies the delay specifications.
178 | noise_spec: dictionary that specifies the noise specifications.
179 | perturb_spec: dictionary that specifies the perturbations specifications.
180 | dimensionality_spec: dictionary that specifies extra observation features.
181 | multiobj_spec: dictionary that specifies complementary objectives.
182 | combined_challenge: string that can be 'easy', 'medium', or 'hard'.
183 | Specifying the combined challenge (can't be used with any other spec).
184 | """
185 | physics = humanoid.Physics.from_xml_string(*humanoid.get_model_and_assets())
186 | safety_spec = safety_spec or {}
187 | delay_spec = delay_spec or {}
188 | noise_spec = noise_spec or {}
189 | perturb_spec = perturb_spec or {}
190 | dimensionality_spec = dimensionality_spec or {}
191 | multiobj_spec = multiobj_spec or {}
192 | # Check and update for combined challenge.
193 | (delay_spec, noise_spec,
194 | perturb_spec, dimensionality_spec) = (
195 | realworld_env.get_combined_challenge(
196 | combined_challenge, delay_spec, noise_spec, perturb_spec,
197 | dimensionality_spec))
198 | # Updating perturbation parameters if combined_challenge.
199 | if combined_challenge == 'easy':
200 | perturb_spec.update(
201 | {'param': 'contact_friction', 'min': 0.6, 'max': 0.8, 'std': 0.02})
202 | elif combined_challenge == 'medium':
203 | perturb_spec.update(
204 | {'param': 'contact_friction', 'min': 0.5, 'max': 0.9, 'std': 0.04})
205 | elif combined_challenge == 'hard':
206 | perturb_spec.update(
207 | {'param': 'contact_friction', 'min': 0.4, 'max': 1.0, 'std': 0.06})
208 |
209 | if 'limits' not in safety_spec:
210 | if 'safety_coeff' in safety_spec:
211 | if safety_spec['safety_coeff'] < 0 or safety_spec['safety_coeff'] > 1:
212 | raise ValueError('safety_coeff should be in [0,1], but got {}'.format(
213 | safety_spec['safety_coeff']))
214 | safety_coeff = safety_spec['safety_coeff']
215 | else:
216 | safety_coeff = 1
217 |
218 | safety_spec['limits'] = {
219 | 'joint_pos_constraint': safety_coeff * np.pi, # rad
220 | 'joint_velocity_constraint': safety_coeff * 90, # rad/s
221 | 'foot_force_constraint': safety_coeff * 8000, # newtons
222 | 'dangerous_fall_constraint': safety_coeff * 100, # newtons
223 | 'torso_upright_constraint': (1 - safety_coeff), # vector magnitude
224 | 'action_roc_constraint': safety_coeff * 1.85,
225 | }
226 |
227 | task = RealWorldHumanoid(
228 | move_speed=1,
229 | pure_state=False,
230 | random=random,
231 | safety_spec=safety_spec,
232 | delay_spec=delay_spec,
233 | noise_spec=noise_spec,
234 | perturb_spec=perturb_spec,
235 | dimensionality_spec=dimensionality_spec,
236 | multiobj_spec=multiobj_spec)
237 | environment_kwargs = environment_kwargs or {}
238 | if log_output:
239 | logger = loggers.PickleLogger(path=log_output)
240 | else:
241 | logger = None
242 | return wrappers.LoggingEnv(
243 | physics,
244 | task,
245 | logger=logger,
246 | time_limit=time_limit,
247 | control_timestep=_CONTROL_TIMESTEP,
248 | **environment_kwargs)
249 |
250 |
251 | class RealWorldHumanoid(realworld_env.Base, humanoid.Humanoid):
252 | """A Humanoid task with real-world specifications.
253 |
254 | Subclasses dm_control.suite.humanoid.
255 |
256 | Safety:
257 | Adds a set of constraints on the task.
258 | Returns an additional entry in the observations ('constraints') in the
259 | length of the number of the constraints, where each entry is True if the
260 | constraint is satisfied and False otherwise.
261 |
262 | Delays:
263 | Adds actions, observations, and rewards delays.
264 | Action delay is the number of steps between passing an action to the
265 | environment and when it is actually performed; observation (reward)
266 | delay is the number of steps by which the returned observation (reward)
267 | is stale after performing a step.
268 |
269 | Noise:
270 | Adds action or observation noise.
271 | Noise types include: white Gaussian actions/observations,
272 | dropped actions/observations values, stuck actions/observations values,
273 | or repetitive actions.
274 |
275 | Perturbations:
276 | Perturbs physical quantities of the environment. These perturbations are
277 | non-stationary and are governed by a scheduler.
278 |
279 | Dimensionality:
280 | Adds extra dummy features to observations to increase dimensionality of the
281 | state space.
282 |
283 | Multi-Objective Reward:
284 | Adds additional objectives and specifies how they interact (e.g., sum).
285 | """
286 |
287 | def __init__(self, move_speed, pure_state, safety_spec, delay_spec,
288 | noise_spec, perturb_spec, dimensionality_spec, multiobj_spec,
289 | **kwargs):
290 | """Initialize the RealWorldHumanoid task.
291 |
292 | Args:
293 | move_speed: float. If this value is zero, reward is given simply for
294 | standing up. Otherwise this specifies a target horizontal velocity for
295 | the walking task.
296 | pure_state: bool. Whether the observations consist of the pure MuJoCo
297 | state or includes some useful features thereof.
298 | safety_spec: dictionary that specifies the safety specifications of the
299 | task. It may contain the following fields:
300 | enable- bool that represents whether safety specifications are enabled.
301 | constraints- list of class methods returning boolean constraint
302 | satisfactions.
303 | limits- dictionary of constants used by the functions in 'constraints'.
304 | safety_coeff - a scalar between 0 and 1 that scales safety constraints,
305 | 1 producing the base constraints, and 0 likely producing an
306 | unsolvable task.
307 | observations- a default-True boolean that toggles whether a vector
308 | of satisfied constraints is added to observations.
309 | delay_spec: dictionary that specifies the delay specifications of the
310 | task. It may contain the following fields:
311 | enable- bool that represents whether delay specifications are enabled.
312 | actions- integer indicating the number of steps actions are being
313 | delayed.
314 | observations- integer indicating the number of steps observations are
315 | being delayed.
316 | rewards- integer indicating the number of steps rewards are being
317 | delayed.
318 | noise_spec: dictionary that specifies the noise specifications of the
319 | task. It may contain the following fields:
320 | gaussian- dictionary that specifies the white Gaussian additive noise.
321 | It may contain the following fields:
322 | enable- bool that represents whether noise specifications are enabled.
323 | actions- float indicating the standard deviation of a white Gaussian
324 | noise added to each action.
325 | observations- similarly, additive white Gaussian noise to each
326 | returned observation.
327 | dropped- dictionary that specifies the dropped values noise.
328 | It may contain the following fields:
329 | enable- bool that represents whether dropped values specifications are
330 | enabled.
331 | observations_prob- float in [0,1] indicating the probability of
332 | dropping each observation component independently.
333 | observations_steps- positive integer indicating the number of time
334 | steps of dropping a value (setting to zero) if dropped.
335 | actions_prob- float in [0,1] indicating the probability of dropping
336 | each action component independently.
337 | actions_steps- positive integer indicating the number of time steps of
338 | dropping a value (setting to zero) if dropped.
339 | stuck- dictionary that specifies the stuck values noise.
340 | It may contain the following fields:
341 | enable- bool that represents whether stuck values specifications are
342 | enabled.
343 | observations_prob- float in [0,1] indicating the probability of each
344 | observation component becoming stuck.
345 | observations_steps- positive integer indicating the number of time
346 | steps an observation (or components of) stays stuck.
347 | actions_prob- float in [0,1] indicating the probability of each
348 | action component becoming stuck.
349 | actions_steps- positive integer indicating the number of time
350 | steps an action (or components of) stays stuck.
351 | repetition- dictionary that specifies the repetition statistics.
352 | It may contain the following fields:
353 | enable- bool that represents whether repetition specifications are
354 | enabled.
355 | actions_prob- float in [0,1] indicating the probability of the actions
356 | to be repeated in the following steps.
357 | actions_steps- positive integer indicating the number of time steps of
358 | repeating the same action if it is to be repeated.
359 | perturb_spec: dictionary that specifies the perturbation specifications
360 | of the task. It may contain the following fields:
361 | enable- bool that represents whether perturbation specifications are
362 | enabled.
363 | period- int, number of episodes between perturbation updates.
364 | param - string indicating which parameter to perturb (currently
365 | supporting joint_damping, contact_friction, head_size).
366 | scheduler- string indicating the scheduler to apply to the perturbed
367 | parameter (currently supporting constant, random_walk, drift_pos,
368 | drift_neg, cyclic_pos, cyclic_neg, uniform, and saw_wave).
369 | start - float indicating the initial value of the perturbed parameter.
370 | min - float indicating the minimal value the perturbed parameter may be.
371 | max - float indicating the maximal value the perturbed parameter may be.
372 | std - float indicating the standard deviation of the white noise for the
373 | scheduling process.
374 | dimensionality_spec: dictionary that specifies the added dimensions to the
375 | state space. It may contain the following fields:
376 | enable - bool that represents whether dimensionality specifications are
377 | enabled.
378 | num_random_state_observations - number of random (unit Gaussian)
379 | observations to add.
380 | multiobj_spec: dictionary that sets up the multi-objective challenge.
381 | The challenge works by providing an `Objective` object which describes
382 | both numerical objectives and a reward-merging method that allows one to both
383 | observe the various objectives in the observation and affect the
384 | returned reward in a manner defined by the Objective object.
385 | enable- bool that represents whether multi-objective
386 | specifications are enabled.
387 | objective - either a string which will load an `Objective` class from
388 | utils.multiobj_objectives.OBJECTIVES, or an Objective object which
389 | subclasses utils.multiobj_objectives.Objective.
390 | reward - boolean indicating whether to add the multiobj objective's
391 | reward to the environment's returned reward.
392 | coeff - a number in [0,1] that is passed into the Objective object to
393 | change the mix between the original reward and the Objective's
394 | rewards.
395 | observed - boolean indicating whether the defined objectives should be
396 | added to the observation.
397 | **kwargs: extra parameters passed to parent class (humanoid.Humanoid)
398 | """
399 | # Initialize parent classes.
400 | realworld_env.Base.__init__(self)
401 | humanoid.Humanoid.__init__(self, move_speed, pure_state, **kwargs)
402 |
403 | # Safety setup.
404 | self._setup_safety(safety_spec)
405 |
406 | # Delay setup.
407 | realworld_env.Base._setup_delay(self, delay_spec)
408 |
409 | # Noise setup.
410 | realworld_env.Base._setup_noise(self, noise_spec)
411 |
412 | # Perturb setup.
413 | self._setup_perturb(perturb_spec)
414 |
415 | # Dimensionality setup
416 | realworld_env.Base._setup_dimensionality(self, dimensionality_spec)
417 |
418 | # Multi-objective setup
419 | realworld_env.Base._setup_multiobj(self, multiobj_spec)
420 |
421 | # Safety methods.
422 | def _setup_safety(self, safety_spec):
423 | """Setup for the safety specifications of the task."""
424 | self._safety_enabled = safety_spec.get('enable', False)
425 | self._safety_observed = safety_spec.get('observations', True)
426 |
427 | if self._safety_enabled:
428 | # Add safety specifications.
429 | if 'constraints' in safety_spec:
430 | self.constraints = safety_spec['constraints']
431 | else:
432 | self.constraints = collections.OrderedDict([
433 | ('joint_angle_constraint', joint_angle_constraint),
434 | ('joint_velocity_constraint', joint_velocity_constraint),
435 | ('upright_constraint', upright_constraint),
436 | ('dangerous_fall_constraint', dangerous_fall_constraint),
437 | ('foot_force_constraint', foot_force_constraint)
438 | ])
439 | if 'limits' in safety_spec:
440 | self.limits = safety_spec['limits']
441 | else:
442 | if 'safety_coeff' in safety_spec:
443 | if safety_spec['safety_coeff'] < 0 or safety_spec['safety_coeff'] > 1:
444 | raise ValueError(
445 | 'safety_coeff should be in [0,1], but got {}'.format(
446 | safety_spec['safety_coeff']))
447 | safety_coeff = safety_spec['safety_coeff']
448 | else:
449 | safety_coeff = 1
450 | self.limits = {
451 | 'joint_pos_constraint': safety_coeff * np.pi, # rad
452 | 'joint_velocity_constraint': safety_coeff * 90, # rad/s
453 | 'foot_force_constraint': safety_coeff * 8000, # newtons
454 | 'dangerous_fall_constraint': safety_coeff * 100, # newtons
455 | 'torso_upright_constraint':
456 | (1 - safety_coeff), # vector magnitude
457 | 'action_roc_constraint': safety_coeff * 1.85,
458 | }
459 | self._constraints_obs = np.ones(len(self.constraints), dtype=bool)
460 |
461 | def safety_vars(self, physics):
462 | """Centralized retrieval of safety-related variables to simplify logging."""
463 | safety_vars = collections.OrderedDict(
464 | joint_pos=physics.named.data.qpos[7:].copy(),
465 | joint_vels=np.abs(physics.named.data.qvel[7:]).copy(),
466 | z_up=physics.torso_upright(),
467 | rl_f=np.linalg.norm(
468 | physics.named.data.sensordata['right_left_foot_touch'].copy()),
469 | rr_f=np.linalg.norm(
470 | physics.named.data.sensordata['right_right_foot_touch'].copy()),
471 | ll_f=np.linalg.norm(
472 | physics.named.data.sensordata['left_left_foot_touch'].copy()),
473 | lr_f=np.linalg.norm(
474 | physics.named.data.sensordata['left_right_foot_touch'].copy()),
475 | head_touch=np.linalg.norm(
476 | physics.named.data.sensordata['head_touch'].copy()),
477 | torso_touch=np.linalg.norm(
478 | physics.named.data.sensordata['torso_touch'].copy()),
479 | actions=physics.control(),
480 | )
481 | return safety_vars
482 |
483 | def _setup_perturb(self, perturb_spec):
484 | """Setup for the perturbations specification of the task."""
485 | self._perturb_enabled = perturb_spec.get('enable', False)
486 | self._perturb_period = perturb_spec.get('period', 1)
487 |
488 | if self._perturb_enabled:
489 | # Add perturbations specifications.
490 | self._perturb_param = perturb_spec.get('param', 'contact_friction')
491 | # Making sure object to perturb is supported.
492 | if self._perturb_param not in PERTURB_PARAMS:
493 | raise ValueError("""param was: {}. Currently only supporting {}.
494 | """.format(self._perturb_param, PERTURB_PARAMS))
495 |
496 | # Setting perturbation function.
497 | self._perturb_scheduler = perturb_spec.get('scheduler', 'constant')
498 | if self._perturb_scheduler not in realworld_env.PERTURB_SCHEDULERS:
499 | raise ValueError("""scheduler was: {}. Currently only supporting {}.
500 | """.format(self._perturb_scheduler, realworld_env.PERTURB_SCHEDULERS))
501 |
502 | # Setting perturbation process parameters.
503 | if self._perturb_param == 'contact_friction':
504 | self._perturb_cur = perturb_spec.get('start', 0.7)
505 | self._perturb_start = perturb_spec.get('start', 0.7)
506 | self._perturb_min = perturb_spec.get('min', 0.05)
507 | self._perturb_max = perturb_spec.get('max', 1.2)
508 | self._perturb_std = perturb_spec.get('std', 0.1)
509 | elif self._perturb_param == 'joint_damping':
510 | self._perturb_cur = perturb_spec.get('start', 0.2)
511 | self._perturb_start = perturb_spec.get('start', 0.2)
512 | self._perturb_min = perturb_spec.get('min', 0.01)
513 | self._perturb_max = perturb_spec.get('max', 2.5)
514 | self._perturb_std = perturb_spec.get('std', 0.2)
515 | elif self._perturb_param == 'head_size':
516 | self._perturb_cur = perturb_spec.get('start', 0.09)
517 | self._perturb_start = perturb_spec.get('start', 0.09)
518 | self._perturb_min = perturb_spec.get('min', 0.01)
519 | self._perturb_max = perturb_spec.get('max', 0.19)
520 | self._perturb_std = perturb_spec.get('std', 0.02)
521 |
522 | def update_physics(self):
523 | """Returns a new Physics object with perturbed parameter."""
524 | # Generate the new perturbed parameter.
525 | realworld_env.Base._generate_parameter(self)
526 |
527 | # Create a new physics object with the perturbed parameter.
528 | xml_string = common.read_model('humanoid.xml')
529 | mjcf = etree.fromstring(xml_string)
530 |
531 | if self._perturb_param == 'joint_damping':
532 | # Joint damping is a coefficient that provides a countering force
533 | # proportional to angular velocity.
534 | joint_damping = mjcf.find('./default/default/joint')
535 | joint_damping.set('damping', str(self._perturb_cur))
536 | elif self._perturb_param == 'contact_friction':
537 | # Need to set the friction coefficient on floor and body geoms:
538 | geom_contact = mjcf.find('./default/default/geom') # Body geom.
539 | geom_contact.set('friction', '{} .1 .1'.format(self._perturb_cur))
540 | floor_contact = mjcf.find('./worldbody/geom') # Floor geom.
541 | floor_contact.set('friction', '{} .1 .1'.format(self._perturb_cur))
542 | elif self._perturb_param == 'head_size':
543 | geom_head = mjcf.find('./worldbody/body/body/geom')
544 | geom_head.set('size', '{}'.format(self._perturb_cur))
545 | xml_string_modified = etree.tostring(mjcf, pretty_print=True)
546 | physics = Physics.from_xml_string(xml_string_modified, common.ASSETS)
547 |
548 | return physics
549 |
550 | def before_step(self, action, physics):
551 | """Updates the environment using the action and returns a `TimeStep`."""
552 | self._last_action = physics.control()
553 | action_min = self.action_spec(physics).minimum[:]
554 | action_max = self.action_spec(physics).maximum[:]
555 | action = realworld_env.Base.before_step(self, action, action_min,
556 | action_max)
557 | humanoid.Humanoid.before_step(self, action, physics)
558 |
559 | def after_step(self, physics):
560 | realworld_env.Base.after_step(self, physics)
561 | humanoid.Humanoid.after_step(self, physics)
562 | self._last_action = None
563 |
564 |
565 | class Physics(humanoid.Physics):
566 | """Inherits from humanoid.Physics."""
567 |
--------------------------------------------------------------------------------
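As documented in the loaders above, `combined_challenge` bundles the delay, noise, perturbation and dimensionality specs and cannot be combined with them. A minimal sketch of running one logged episode (illustrative only; the log path and the zero-action stand-in policy are hypothetical):

```python
# Illustrative sketch: humanoid walk with the combined 'easy' challenge.
import numpy as np
import realworldrl_suite.environments as rwrl_envs

env = rwrl_envs.load(
    'humanoid', 'realworld_walk',
    combined_challenge='easy',
    log_output='/tmp/humanoid_walk_logs.npz',  # hypothetical path
)
timestep = env.reset()
while not timestep.last():
    # Zero action as a stand-in policy, purely for illustration.
    timestep = env.step(np.zeros(env.action_spec().shape))
print('per-episode stats are logged to', env.logs_path)
```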
/realworldrl_suite/environments/manipulator.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Real-World RL Suite Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Real-World Planar Manipulator domain."""
17 |
18 | import collections
19 |
20 | from dm_control.suite import common
21 | from dm_control.suite import manipulator
22 | from lxml import etree
23 | import numpy as np
24 |
25 | from realworldrl_suite.environments import realworld_env
26 | from realworldrl_suite.utils import loggers
27 | from realworldrl_suite.utils import wrappers
28 |
29 | _CONTROL_TIMESTEP = .01 # (Seconds)
30 | _TIME_LIMIT = 10 # (Seconds)
31 | _ARM_JOINTS = ['arm_shoulder', 'arm_elbow'] # Used for constraints
32 |
33 | PERTURB_PARAMS = ['lower_arm_length', 'root_damping', 'shoulder_damping']
34 |
35 |
36 | def load(task_name, **task_kwargs):
37 | return globals()[task_name](**task_kwargs)
38 |
39 |
40 | # Task Constraints
41 | def joint_angle_constraint(env, safety_vars):
42 | """Joint angles must be within a certain area of the track."""
43 | joint_pos = safety_vars['joint_angle']
44 | return (np.greater(joint_pos,
45 | env.limits['joint_angle_constraint'][0]).all() and
46 | np.less(joint_pos, env.limits['joint_angle_constraint'][1]).all())
47 |
48 |
49 | def joint_velocity_constraint(env, safety_vars):
50 | """Joint angle velocities must stay below a certain limit."""
51 | joint_vels = safety_vars['joint_vels']
52 | return np.less(np.max(joint_vels), env.limits['joint_velocity_constraint'])
53 |
54 |
55 | def joint_accel_constraint(env, safety_vars):
56 | """Joint angle velocities must stay below a certain limit."""
57 | joint_accels = safety_vars['joint_accels']
58 | return np.less(np.max(joint_accels), env.limits['joint_accel_constraint'])
59 |
60 |
61 | def grasp_force_constraint(env, safety_vars):
62 | """Limit gripper contact forces."""
63 | return np.less(
64 | np.max(safety_vars['grasp_force']), env.limits['grasp_force_constraint'])
65 |
66 |
67 | # Action rate of change constraint.
68 | action_roc_constraint = realworld_env.action_roc_constraint
69 |
70 |
71 | def gen_task(use_peg,
72 | insert,
73 | fully_observable=True,
74 | time_limit=_TIME_LIMIT,
75 | random=None,
76 | log_output=None,
77 | environment_kwargs=None,
78 | safety_spec=None,
79 | delay_spec=None,
80 | noise_spec=None,
81 | perturb_spec=None,
82 | dimensionality_spec=None,
83 | multiobj_spec=None,
84 | combined_challenge=None):
85 | """Returns the Manipulator Bring task with specified real world attributes.
86 |
87 | Args:
88 | use_peg: A `bool`, whether to replace the ball prop with the peg prop.
89 | insert: A `bool`, whether to insert the prop in a receptacle.
90 | fully_observable: A `bool`, whether the observation should contain the
91 | position and velocity of the object being manipulated and the target
92 | location.
 93 |     time_limit: Integer length of the task, in seconds.
94 | random: Optional, either a `numpy.random.RandomState` instance, an integer
95 | seed for creating a new `RandomState`, or None to select a seed
96 | automatically (default).
97 | log_output: String of path for pickle data logging, None disables logging
98 | environment_kwargs: additional kwargs for environment
99 | safety_spec: dictionary that specifies the safety specifications.
100 | delay_spec: dictionary that specifies the delay.
101 | noise_spec: dictionary that specifies the noise specifications.
102 | perturb_spec: dictionary that specifies the perturbations specifications.
103 | dimensionality_spec: dictionary that specifies extra observation features.
104 | multiobj_spec: dictionary that specifies complementary objectives.
105 | combined_challenge: string that can be 'easy', 'medium', or 'hard'.
106 |       Specifies the combined challenge; it cannot be used together with any
        other spec.
107 | """
108 | physics = manipulator.Physics.from_xml_string(
109 | *manipulator.make_model(use_peg, insert))
110 | safety_spec = safety_spec or {}
111 | delay_spec = delay_spec or {}
112 | noise_spec = noise_spec or {}
113 | perturb_spec = perturb_spec or {}
114 | dimensionality_spec = dimensionality_spec or {}
115 | multiobj_spec = multiobj_spec or {}
116 | # Check and update for combined challenge.
117 | (delay_spec, noise_spec,
118 | perturb_spec, dimensionality_spec) = (
119 | realworld_env.get_combined_challenge(
120 | combined_challenge, delay_spec, noise_spec, perturb_spec,
121 | dimensionality_spec))
122 |
123 | task = RealWorldBring(
124 | use_peg=use_peg,
125 | insert=insert,
126 | fully_observable=fully_observable,
127 | safety_spec=safety_spec,
128 | delay_spec=delay_spec,
129 | noise_spec=noise_spec,
130 | perturb_spec=perturb_spec,
131 | dimensionality_spec=dimensionality_spec,
132 | multiobj_spec=multiobj_spec,
133 | random=random)
134 |
135 | environment_kwargs = environment_kwargs or {}
136 | if log_output:
137 | logger = loggers.PickleLogger(path=log_output)
138 | else:
139 | logger = None
140 | return wrappers.LoggingEnv(
141 | physics,
142 | task,
143 | logger=logger,
144 | control_timestep=_CONTROL_TIMESTEP,
145 | time_limit=time_limit,
146 | **environment_kwargs)
147 |
148 |
149 | def realworld_bring_ball(fully_observable=True,
150 | time_limit=_TIME_LIMIT,
151 | random=None,
152 | log_output=None,
153 | environment_kwargs=None,
154 | safety_spec=None,
155 | delay_spec=None,
156 | noise_spec=None,
157 | perturb_spec=None,
158 | dimensionality_spec=None,
159 | multiobj_spec=None,
160 | combined_challenge=None):
161 | """Returns manipulator bring task with the ball prop."""
162 | use_peg = False
163 | insert = False
164 | return gen_task(use_peg, insert, fully_observable, time_limit, random,
165 | log_output, environment_kwargs, safety_spec, delay_spec,
166 | noise_spec, perturb_spec, dimensionality_spec, multiobj_spec,
167 | combined_challenge)
168 |
169 |
170 | def realworld_bring_peg(fully_observable=True,
171 | time_limit=_TIME_LIMIT,
172 | random=None,
173 | log_output=None,
174 | environment_kwargs=None,
175 | safety_spec=None,
176 | delay_spec=None,
177 | noise_spec=None,
178 | perturb_spec=None,
179 | dimensionality_spec=None,
180 | multiobj_spec=None,
181 | combined_challenge=None):
182 | """Returns manipulator bring task with the peg prop."""
183 | use_peg = True
184 | insert = False
185 | return gen_task(use_peg, insert, fully_observable, time_limit, random,
186 | log_output, environment_kwargs, safety_spec, delay_spec,
187 | noise_spec, perturb_spec, dimensionality_spec, multiobj_spec,
188 | combined_challenge)
189 |
190 |
191 | def realworld_insert_ball(fully_observable=True,
192 | time_limit=_TIME_LIMIT,
193 | random=None,
194 | log_output=None,
195 | environment_kwargs=None,
196 | safety_spec=None,
197 | delay_spec=None,
198 | noise_spec=None,
199 | perturb_spec=None,
200 | dimensionality_spec=None,
201 | multiobj_spec=None,
202 | combined_challenge=None):
203 | """Returns manipulator insert task with the ball prop."""
204 | use_peg = False
205 | insert = True
206 | return gen_task(use_peg, insert, fully_observable, time_limit, random,
207 | log_output, environment_kwargs, safety_spec, delay_spec,
208 | noise_spec, perturb_spec, dimensionality_spec, multiobj_spec,
209 | combined_challenge)
210 |
211 |
212 | def realworld_insert_peg(fully_observable=True,
213 | time_limit=_TIME_LIMIT,
214 | random=None,
215 | log_output=None,
216 | environment_kwargs=None,
217 | safety_spec=None,
218 | delay_spec=None,
219 | noise_spec=None,
220 | perturb_spec=None,
221 | dimensionality_spec=None,
222 | multiobj_spec=None,
223 | combined_challenge=None):
224 | """Returns manipulator insert task with the peg prop."""
225 | use_peg = True
226 | insert = True
227 | return gen_task(use_peg, insert, fully_observable, time_limit, random,
228 | log_output, environment_kwargs, safety_spec, delay_spec,
229 | noise_spec, perturb_spec, dimensionality_spec, multiobj_spec,
230 | combined_challenge)
231 |
232 |
233 | class Physics(manipulator.Physics):
234 | """Inherits from manipulator.Physics."""
235 |
236 |
237 | class RealWorldBring(realworld_env.Base, manipulator.Bring):
238 | """A Manipulator task with real-world specifications.
239 |
240 | Subclasses dm_control.suite.manipulator.
241 |
242 | Safety:
243 | Adds a set of constraints on the task.
244 |     Returns an additional entry in the observations ('constraints') whose
245 |     length equals the number of constraints, where each entry is True if the
246 |     constraint is satisfied and False otherwise.
247 |
248 | Delays:
249 | Adds actions, observations, and rewards delays.
250 | Actions delay is the number of steps between passing the action to the
251 | environment to when it is actually performed, and observations (rewards)
252 | delay is the offset of freshness of the returned observation (reward) after
253 | performing a step.
254 |
255 | Noise:
256 | Adds action or observation noise.
257 |     Supported noise types include: white Gaussian actions/observations,
258 | dropped actions/observations values, stuck actions/observations values,
259 | or repetitive actions.
260 |
261 | Perturbations:
262 | Perturbs physical quantities of the environment. These perturbations are
263 | non-stationary and are governed by a scheduler.
264 |
265 | Dimensionality:
266 | Adds extra dummy features to observations to increase dimensionality of the
267 | state space.
268 |
269 | Multi-Objective Reward:
270 |     Adds additional objectives and specifies how they interact (e.g., sum).
271 | """
272 |
273 | def __init__(self,
274 | use_peg,
275 | insert,
276 | fully_observable,
277 | safety_spec,
278 | delay_spec,
279 | noise_spec,
280 | perturb_spec,
281 | dimensionality_spec,
282 | multiobj_spec,
283 | random=None,
284 | **kwargs):
285 | """Initialize the RealWorldBring task.
286 |
287 | Args:
288 | use_peg: A `bool`, whether to replace the ball prop with the peg prop.
289 | insert: A `bool`, whether to insert the prop in a receptacle.
290 | fully_observable: A `bool`, whether the observation should contain the
291 | position and velocity of the object being manipulated and the target
292 | location.
293 | safety_spec: dictionary that specifies the safety specifications of the
294 | task. It may contain the following fields:
295 | enable- bool that represents whether safety specifications are enabled.
296 | constraints- list of class methods returning boolean constraint
297 | satisfactions.
298 | limits- dictionary of constants used by the functions in 'constraints'.
299 |         safety_coeff - a scalar between 0 and 1 that scales safety constraints,
300 |           1 producing the base constraints, and 0 likely producing an
301 |           unsolvable task.
302 |         observations- a default-True boolean that toggles whether a vector
303 |           of satisfied constraints is added to observations.
304 | delay_spec: dictionary that specifies the delay specifications of the
305 | task. It may contain the following fields:
306 | enable- bool that represents whether delay specifications are enabled.
307 | actions- integer indicating the number of steps actions are being
308 | delayed.
309 | observations- integer indicating the number of steps observations are
310 | being delayed.
311 |         rewards- integer indicating the number of steps rewards are being
312 | delayed.
313 | noise_spec: dictionary that specifies the noise specifications of the
314 |         task. It may contain the following fields:
315 | gaussian- dictionary that specifies the white Gaussian additive noise.
316 | It may contain the following fields:
317 | enable- bool that represents whether noise specifications are enabled.
318 |         actions- float indicating the standard deviation of a white Gaussian
319 | noise added to each action.
320 | observations- similarly, additive white Gaussian noise to each
321 | returned observation.
322 | dropped- dictionary that specifies the dropped values noise.
323 | It may contain the following fields:
324 | enable- bool that represents whether dropped values specifications are
325 | enabled.
326 | observations_prob- float in [0,1] indicating the probability of
327 | dropping each observation component independently.
328 | observations_steps- positive integer indicating the number of time
329 | steps of dropping a value (setting to zero) if dropped.
330 | actions_prob- float in [0,1] indicating the probability of dropping
331 | each action component independently.
332 | actions_steps- positive integer indicating the number of time steps of
333 | dropping a value (setting to zero) if dropped.
334 | stuck- dictionary that specifies the stuck values noise.
335 | It may contain the following fields:
336 | enable- bool that represents whether stuck values specifications are
337 | enabled.
338 | observations_prob- float in [0,1] indicating the probability of each
339 | observation component becoming stuck.
340 | observations_steps- positive integer indicating the number of time
341 | steps an observation (or components of) stays stuck.
342 | actions_prob- float in [0,1] indicating the probability of each
343 | action component becoming stuck.
344 | actions_steps- positive integer indicating the number of time
345 | steps an action (or components of) stays stuck.
346 | repetition- dictionary that specifies the repetition statistics.
347 | It may contain the following fields:
348 | enable- bool that represents whether repetition specifications are
349 | enabled.
350 | actions_prob- float in [0,1] indicating the probability of the actions
351 | to be repeated in the following steps.
352 | actions_steps- positive integer indicating the number of time steps of
353 | repeating the same action if it to be repeated.
354 | perturb_spec: dictionary that specifies the perturbation specifications
355 | of the task. It may contain the following fields:
356 | enable- bool that represents whether perturbation specifications are
357 | enabled.
358 |         period- int, number of episodes between perturbation updates.
359 | param- string indicating which parameter to perturb (currently
360 | supporting lower_arm_length, root_damping, shoulder_damping).
361 |         scheduler- string indicating the scheduler to apply to the perturbed
362 | parameter (currently supporting constant, random_walk, drift_pos,
363 | drift_neg, cyclic_pos, cyclic_neg, uniform, and saw_wave).
364 | start - float indicating the initial value of the perturbed parameter.
365 | min - float indicating the minimal value the perturbed parameter may be.
366 | max - float indicating the maximal value the perturbed parameter may be.
367 | std - float indicating the standard deviation of the white noise for the
368 | scheduling process.
369 | dimensionality_spec: dictionary that specifies the added dimensions to the
370 | state space. It may contain the following fields:
371 | enable- bool that represents whether dimensionality specifications are
372 | enabled.
373 | num_random_state_observations - num of random (unit Gaussian)
374 | observations to add.
375 | multiobj_spec: dictionary that sets up the multi-objective challenge.
376 | The challenge works by providing an `Objective` object which describes
377 |         both numerical objectives and a reward-merging method that allows one to both
378 | observe the various objectives in the observation and affect the
379 | returned reward in a manner defined by the Objective object.
380 |         enable- bool that represents whether multi-objective
381 | specifications are enabled.
382 | objective - either a string which will load an `Objective` class from
383 | utils.multiobj_objectives.OBJECTIVES, or an Objective object which
384 | subclasses utils.multiobj_objectives.Objective.
385 | reward - boolean indicating whether to add the multiobj objective's
386 | reward to the environment's returned reward.
387 | coeff - a number in [0,1] that is passed into the Objective object to
388 | change the mix between the original reward and the Objective's
389 | rewards.
390 | observed - boolean indicating whether the defined objectives should be
391 | added to the observation.
392 | random: Optional, either a `numpy.random.RandomState` instance, an integer
393 | seed for creating a new `RandomState`, or None to select a seed
394 | automatically (default).
395 | **kwargs: extra parameters passed to parent class (manipulator.Bring)
396 | """
397 | # Initialize parent classes.
398 | realworld_env.Base.__init__(self)
399 | manipulator.Bring.__init__(
400 | self, use_peg, insert, fully_observable, random=random, **kwargs)
401 |
402 | # Safety setup.
403 | self._setup_safety(safety_spec)
404 |
405 | # Delay setup.
406 | realworld_env.Base._setup_delay(self, delay_spec)
407 |
408 | # Noise setup.
409 | realworld_env.Base._setup_noise(self, noise_spec)
410 |
411 | # Perturb setup.
412 | self._setup_perturb(perturb_spec)
413 |
414 | # Dimensionality setup
415 | realworld_env.Base._setup_dimensionality(self, dimensionality_spec)
416 |
417 | # Multi-objective setup
418 | realworld_env.Base._setup_multiobj(self, multiobj_spec)
419 |
420 | self._use_peg = use_peg
421 | self._insert = insert
422 |
423 | # Safety methods.
424 | def _setup_safety(self, safety_spec):
425 | """Setup for the safety specifications of the task."""
426 | self._safety_enabled = safety_spec.get('enable', False)
427 | self._safety_observed = safety_spec.get('observations', True)
428 |
429 | if self._safety_enabled:
430 | # Add safety specifications.
431 | if 'constraints' in safety_spec:
432 | self.constraints = safety_spec['constraints']
433 | else:
434 | self.constraints = collections.OrderedDict([
435 | ('joint_angle_constraint', joint_angle_constraint),
436 | ('joint_velocity_constraint', joint_velocity_constraint),
437 | ('joint_accel_constraint', joint_accel_constraint),
438 | ('grasp_force_constraint', grasp_force_constraint)
439 | ])
440 | if 'limits' in safety_spec:
441 | self.limits = safety_spec['limits']
442 | else:
443 | if 'safety_coeff' in safety_spec:
444 | if safety_spec['safety_coeff'] < 0 or safety_spec['safety_coeff'] > 1:
445 | raise ValueError(
446 | 'safety_coeff should be in [0,1], but got {}'.format(
447 | safety_spec['safety_coeff']))
448 | safety_coeff = safety_spec['safety_coeff']
449 | else:
450 | safety_coeff = 1
451 | self.limits = {
452 | 'joint_angle_constraint':
453 | safety_coeff * np.array([[-160, -140], [160, 140]]) *
454 | np.pi / 180., # rad
455 | 'joint_velocity_constraint':
456 | safety_coeff * 10, # rad/s
457 | 'joint_accel_constraint':
458 | safety_coeff * 3600, # rad/s^2
459 | 'grasp_force_constraint':
460 | safety_coeff * 5, # newtons
461 | 'action_roc_constraint': safety_coeff * 1.5
462 | }
463 | self._constraints_obs = np.ones(len(self.constraints), dtype=bool)
464 |
465 | def safety_vars(self, physics):
466 | """Centralized retrieval of safety-related variables to simplify logging."""
467 | safety_vars = collections.OrderedDict(
468 | joint_angle=physics.named.data.qpos[_ARM_JOINTS].copy(), # rad
469 | joint_vels=np.abs(physics.named.data.qvel[_ARM_JOINTS]).copy(), # rad/s
470 | joint_accels=np.abs(
471 | physics.named.data.qacc[_ARM_JOINTS]).copy(), # rad/s^2
472 | grasp_force=physics.touch(), # log(1+newtons)
473 | actions=physics.control(),
474 | )
475 | return safety_vars
476 |
477 | def _setup_perturb(self, perturb_spec):
478 | """Setup for the perturbations specification of the task."""
479 | self._perturb_enabled = perturb_spec.get('enable', False)
480 | self._perturb_period = perturb_spec.get('period', 1)
481 |
482 | if self._perturb_enabled:
483 | # Add perturbations specifications.
484 | self._perturb_param = perturb_spec.get('param', 'lower_arm_length')
485 | # Making sure object to perturb is supported.
486 | if self._perturb_param not in PERTURB_PARAMS:
487 | raise ValueError("""param was: {}. Currently only supporting {}.
488 | """.format(self._perturb_param, PERTURB_PARAMS))
489 |
490 | # Setting perturbation function.
491 | self._perturb_scheduler = perturb_spec.get('scheduler', 'constant')
492 | if self._perturb_scheduler not in realworld_env.PERTURB_SCHEDULERS:
493 | raise ValueError("""scheduler was: {}. Currently only supporting {}.
494 | """.format(self._perturb_scheduler, realworld_env.PERTURB_SCHEDULERS))
495 |
496 | # Setting perturbation process parameters.
497 | if self._perturb_param == 'lower_arm_length':
498 | self._perturb_cur = perturb_spec.get('start', 0.12)
499 | self._perturb_start = perturb_spec.get('start', 0.12)
500 | self._perturb_min = perturb_spec.get('min', 0.1)
501 | self._perturb_max = perturb_spec.get('max', 0.25)
502 | self._perturb_std = perturb_spec.get('std', 0.01)
503 | elif self._perturb_param == 'root_damping':
504 | self._perturb_cur = perturb_spec.get('start', 2.0)
505 | self._perturb_start = perturb_spec.get('start', 2.0)
506 | self._perturb_min = perturb_spec.get('min', 0.1)
507 | self._perturb_max = perturb_spec.get('max', 10.0)
508 | self._perturb_std = perturb_spec.get('std', 0.1)
509 | elif self._perturb_param == 'shoulder_damping':
510 | self._perturb_cur = perturb_spec.get('start', 1.5)
511 | self._perturb_start = perturb_spec.get('start', 1.5)
512 | self._perturb_min = perturb_spec.get('min', 0.1)
513 | self._perturb_max = perturb_spec.get('max', 10.0)
514 | self._perturb_std = perturb_spec.get('std', 0.1)
515 |
516 | def update_physics(self):
517 | """Returns a new Physics object with perturbed parameter."""
518 | # Generate the new perturbed parameter.
519 | realworld_env.Base._generate_parameter(self)
520 |
521 | # Create new physics object with the perturb parameter.
522 | xml_string = manipulator.make_model(self._use_peg, self._insert)[0]
523 | mjcf = etree.fromstring(xml_string)
524 |
525 | if self._perturb_param == 'lower_arm_length':
526 | lower_arm = mjcf.find('./worldbody/body/body/body/geom')
527 | lower_arm.set('fromto', '0 0 0 0 0 {}'.format(self._perturb_cur))
528 | hand = mjcf.find('./worldbody/body/body/body/body')
529 | hand.set('pos', '0 0 {}'.format(self._perturb_cur))
530 | elif self._perturb_param == 'root_damping':
531 | joints = mjcf.findall('./worldbody/body/joint')
532 | for joint in joints:
533 | if joint.get('name') == 'arm_root':
534 | joint.set('damping', str(self._perturb_cur))
535 | elif self._perturb_param == 'shoulder_damping':
536 | shoulder_joint = mjcf.find('./worldbody/body/body/joint')
537 | shoulder_joint.set('damping', str(self._perturb_cur))
538 |
539 | xml_string_modified = etree.tostring(mjcf, pretty_print=True)
540 | physics = Physics.from_xml_string(xml_string_modified, common.ASSETS)
541 | return physics
542 |
543 | def before_step(self, action, physics):
544 | """Updates the environment using the action and returns a `TimeStep`."""
545 | self._last_action = physics.control()
546 | action_min = self.action_spec(physics).minimum[:]
547 | action_max = self.action_spec(physics).maximum[:]
548 | action = realworld_env.Base.before_step(self, action, action_min,
549 | action_max)
550 | manipulator.Bring.before_step(self, action, physics)
551 |
552 | def after_step(self, physics):
553 | realworld_env.Base.after_step(self, physics)
554 | manipulator.Bring.after_step(self, physics)
555 | self._last_action = None
556 |
--------------------------------------------------------------------------------
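
A quick usage note: the spec dictionaries documented in RealWorldBring.__init__ above are plain Python dicts forwarded through the suite's `load` entry point. Below is a minimal sketch of how they might be wired together; it assumes the domain is registered as 'manipulator', and the safety/perturbation values shown are illustrative rather than prescriptive.

import numpy as np
import realworldrl_suite.environments as rwrl

# Load the bring-ball task with a tightened safety envelope and a slowly
# drifting lower-arm length (both specs default to disabled).
env = rwrl.load(
    domain_name='manipulator',
    task_name='realworld_bring_ball',
    safety_spec={'enable': True, 'safety_coeff': 0.5},
    perturb_spec={'enable': True,
                  'param': 'lower_arm_length',
                  'scheduler': 'drift_pos',
                  'period': 1})

spec = env.action_spec()
timestep = env.reset()
while not timestep.last():
  action = np.random.uniform(spec.minimum, spec.maximum, size=spec.shape)
  timestep = env.step(action)
  # With safety enabled, each observation carries a boolean vector with one
  # entry per constraint (True means the constraint is satisfied).
  violations = ~timestep.observation['constraints']
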
/realworldrl_suite/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Real-World RL Suite Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/realworldrl_suite/utils/accumulators.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Real-World RL Suite Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Class to accumulate statistics during runs."""
17 | import collections
18 | import copy
19 |
20 | import numpy as np
21 |
22 |
23 | class StatisticsAccumulator(object):
 24 |   """Accumulate the statistics of an environment's real-world variables.
25 |
26 | This class will accumulate the statistics generated by an environment
27 | into a local storage variable which can then be written to disk and
28 | used by the Evaluators class.
29 | """
30 |
31 | def __init__(self, acc_safety, acc_safety_vars, acc_multiobj, auto_acc=True):
32 | """A class to easily accumulate necessary statistics for evaluation.
33 |
34 | Args:
35 | acc_safety: whether we should accumulate safety statistics.
36 | acc_safety_vars: whether we should accumulate state variables specific to
37 | safety.
38 | acc_multiobj: whether we should accumulate multi-objective statistics.
39 | auto_acc: whether to automatically accumulate when 'LAST' timesteps are
40 | pushed.
41 | """
42 | self._acc_safety = acc_safety
43 | self._acc_safety_vars = acc_safety_vars
44 | self._acc_multiobj = acc_multiobj
45 | self._auto_acc = auto_acc
46 | self._buffer = [] # Buffer of timesteps of current episode
47 | self._stat_buffers = dict()
48 |
49 | def push(self, timestep):
50 | """Pushes a new timestep onto the current episode's buffer."""
51 | local_ts = copy.deepcopy(timestep)
52 | self._buffer.append(local_ts)
53 | if local_ts.last():
54 | self.accumulate()
55 | self.clear_buffer()
56 |
57 | def clear_buffer(self):
58 | """Clears the buffer of timesteps."""
59 | self._buffer = []
60 |
61 | def accumulate(self):
62 | """Accumulates statistics for the given buffer into the stats buffer."""
63 | if self._acc_safety:
64 | self._acc_safety_stats()
65 | if self._acc_safety_vars:
66 | self._acc_safety_vars_stats()
67 | if self._acc_multiobj:
68 | self._acc_multiobj_stats()
69 | self._acc_return_stats()
70 |
71 | def _acc_safety_stats(self):
72 | """Generates safety-related statistics."""
73 | ep_buffer = []
74 | for ts in self._buffer:
75 | ep_buffer.append(ts.observation['constraints'])
76 | constraint_array = np.array(ep_buffer)
 77 |     # Total number of violations of each constraint this episode.
 78 |     total_violations = np.sum((~constraint_array), axis=0)
 79 |     # Fetch (or initialize) the running per-step violation counts.
80 | safety_stats = self._stat_buffers.get(
81 | 'safety_stats',
82 | dict(
83 | total_violations=[],
84 | per_step_violations=np.zeros(constraint_array.shape)))
85 | # Accumulate the total number of violations of each constraint this episode
86 | safety_stats['total_violations'].append(total_violations)
87 | # Accumulate the number of violations at each timestep in the episode
88 | safety_stats['per_step_violations'] += ~constraint_array
89 | self._stat_buffers['safety_stats'] = safety_stats
90 |
91 | def _acc_safety_vars_stats(self):
92 | """Generates state-variable statistics to tune the safety constraints.
93 |
 94 |     This will generate a list of dict objects, one per episode, each describing
 95 |     the stats for each set of safety vars.
96 | """
97 | ep_stats = collections.OrderedDict()
98 | for key in self._buffer[0].observation['safety_vars'].keys():
99 | buf = np.array(
100 | [ts.observation['safety_vars'][key] for ts in self._buffer])
101 | stats = dict(
102 | mean=np.mean(buf, axis=0),
103 | std_dev=np.std(buf, axis=0),
104 | min=np.min(buf, axis=0),
105 | max=np.max(buf, axis=0))
106 | ep_stats[key] = stats
107 |
108 | safety_vars_buffer = self._stat_buffers.get('safety_vars_stats', [])
109 | safety_vars_buffer.append(ep_stats) # pytype: disable=attribute-error
110 | self._stat_buffers['safety_vars_stats'] = safety_vars_buffer
111 |
112 | def _acc_multiobj_stats(self):
113 | """Generates multiobj-related statistics."""
114 | ep_buffer = []
115 | for ts in self._buffer:
116 | ep_buffer.append(ts.observation['multiobj'])
117 | multiobj_array = np.array(ep_buffer)
118 |     # Per-objective totals for the episode.
119 |     episode_totals = np.sum(multiobj_array, axis=0)
120 |     # Fetch (or initialize) the multi-objective stats buffer.
121 |     multiobj_stats = self._stat_buffers.get('multiobj_stats',
122 |                                             dict(episode_totals=[]))
123 |     # Accumulate the per-objective totals for this episode.
124 |     multiobj_stats['episode_totals'].append(episode_totals)
125 |     # Store the updated buffer.
126 |     self._stat_buffers['multiobj_stats'] = multiobj_stats
127 |
128 | def _acc_return_stats(self):
129 | """Generates per-episode return statistics."""
130 | ep_buffer = []
131 | for ts in self._buffer:
132 | if not ts.first(): # Skip the first ts as it has a reward of None
133 | ep_buffer.append(ts.reward)
134 | returns_array = np.array(ep_buffer)
135 |     # Total return for the episode.
136 |     episode_totals = np.sum(returns_array)
137 |     # Fetch (or initialize) the return stats buffer.
138 |     return_stats = self._stat_buffers.get('return_stats',
139 |                                           dict(episode_totals=[]))
140 |     # Accumulate this episode's return.
141 |     return_stats['episode_totals'].append(episode_totals)
142 |     # Store the updated buffer.
143 |     self._stat_buffers['return_stats'] = return_stats
144 |
145 | def to_ndarray_dict(self):
146 | """Convert stats buffer to ndarrays to make disk writing more efficient."""
147 | buffers = copy.deepcopy(self.stat_buffers)
148 | if 'safety_stats' in buffers:
149 | buffers['safety_stats']['total_violations'] = np.array(
150 | buffers['safety_stats']['total_violations'])
151 | n_episodes = buffers['safety_stats']['total_violations'].shape[0]
152 | buffers['safety_stats']['per_step_violations'] = np.array(
153 | buffers['safety_stats']['per_step_violations']) / n_episodes
154 | if 'multiobj_stats' in buffers:
155 | buffers['multiobj_stats']['episode_totals'] = np.array(
156 | buffers['multiobj_stats']['episode_totals'])
157 | if 'return_stats' in buffers:
158 | buffers['return_stats']['episode_totals'] = np.array(
159 | buffers['return_stats']['episode_totals'])
160 | return buffers
161 |
162 | @property
163 | def stat_buffers(self):
164 | return self._stat_buffers
165 |
--------------------------------------------------------------------------------
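
The accumulator is normally driven by the logging wrapper (it appears as `env.stats_acc` in the tests below), but it can also be used standalone. A minimal sketch, assuming a cartpole task with safety enabled and a uniform-random action policy (both choices are illustrative):

import numpy as np
import realworldrl_suite.environments as rwrl
from realworldrl_suite.utils import accumulators

env = rwrl.load(domain_name='cartpole', task_name='realworld_balance',
                safety_spec={'enable': True})
spec = env.action_spec()
acc = accumulators.StatisticsAccumulator(
    acc_safety=True, acc_safety_vars=False, acc_multiobj=False)

timestep = env.reset()
acc.push(timestep)
while not timestep.last():
  action = np.random.uniform(spec.minimum, spec.maximum, size=spec.shape)
  timestep = env.step(action)
  # With auto_acc=True (the default), pushing a LAST timestep folds the whole
  # episode into the stats buffers and clears the per-episode buffer.
  acc.push(timestep)

stats = acc.to_ndarray_dict()
print(stats['safety_stats']['total_violations'])  # One row per episode.
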
/realworldrl_suite/utils/accumulators_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Real-World RL Suite Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for accumulators."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import os
22 |
23 | from absl.testing import absltest
24 | from absl.testing import parameterized
25 | import numpy as np
26 | import numpy.testing as npt
27 | import realworldrl_suite.environments as rwrl
28 |
29 |
30 |
31 |
32 | class RandomAgent(object):
33 |
34 | def __init__(self, action_spec):
35 | self.action_spec = action_spec
36 |
37 | def action(self):
38 | return np.random.uniform(
39 | self.action_spec.minimum,
40 | self.action_spec.maximum,
41 | size=self.action_spec.shape)
42 |
43 |
44 | class AccumulatorsTest(parameterized.TestCase):
45 |
46 | @parameterized.named_parameters(*rwrl.ALL_TASKS)
47 | def test_logging(self, domain_name, task_name):
48 | temp_dir = self.create_tempdir()
49 | env = rwrl.load(
50 | domain_name=domain_name,
51 | task_name=task_name,
52 | safety_spec={'enable': True},
53 | multiobj_spec={
54 | 'enable': True,
55 | 'objective': 'safety',
56 | 'observed': False,
57 | },
58 | log_output=os.path.join(temp_dir.full_path, 'test.pickle'),
59 | environment_kwargs=dict(log_safety_vars=True))
60 | random_policy = RandomAgent(env.action_spec()).action
61 | n_steps = 0
62 | for _ in range(3):
63 | timestep = env.step(random_policy())
64 | constraints = (~timestep.observation['constraints']).astype('int')
65 | n_steps += 1
66 | while not timestep.last():
67 | timestep = env.step(random_policy())
68 | constraints += (~timestep.observation['constraints']).astype('int')
69 | npt.assert_equal(
70 | env.stats_acc.stat_buffers['safety_stats']['total_violations'][-1],
71 | constraints)
72 | env.write_logs()
73 | with open(env.logs_path, 'rb') as f:
74 | read_data = np.load(f, allow_pickle=True)
75 | data = read_data['data'].item()
76 | self.assertLen(data.keys(), 4)
77 | self.assertIn('safety_vars_stats', data)
78 | self.assertIn('total_violations', data['safety_stats'])
79 | self.assertIn('per_step_violations', data['safety_stats'])
80 | self.assertIn('episode_totals', data['multiobj_stats'])
81 | self.assertIn('episode_totals', data['return_stats'])
82 | self.assertLen(data['safety_stats']['total_violations'], n_steps)
83 | self.assertLen(data['safety_vars_stats'], n_steps)
84 | self.assertLen(data['multiobj_stats']['episode_totals'], n_steps)
85 | self.assertLen(data['return_stats']['episode_totals'], n_steps)
86 |
87 |
88 | if __name__ == '__main__':
89 | absltest.main()
90 |
--------------------------------------------------------------------------------
/realworldrl_suite/utils/evaluators.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Real-World RL Suite Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Helper lib to calculate RWRL evaluators from logged stats."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import functools
23 | import pprint
24 |
25 | from absl import app
26 | from absl import flags
27 | import matplotlib.pyplot as plt
28 | import numpy as np
29 | import scipy.stats
30 |
31 | FLAGS = flags.FLAGS
32 | if 'stats_in' not in FLAGS:
 33 |   flags.DEFINE_string('stats_in', 'logs.npz', 'Filename to read logged stats from.')
34 |
35 |
36 | _CONFIDENCE_LEVEL = .95
37 | _ROLLING_WINDOW = 101
38 | _COLORS = ['r', 'g', 'b', 'k', 'y']
39 | _MARKERS = ['x', 'o', '*', '^', 'v']
40 |
41 |
42 | class Evaluators(object):
43 | """Calculate standardized RWRL evaluators of a run."""
44 |
45 | def __init__(self, stats):
46 | """Generate object given stored statistics.
47 |
48 | Args:
49 | stats: list of standardized RWRL statistics.
50 | """
51 | self._meta = stats['meta'].item()
52 | self._stats = stats['data'].item()
53 | self._moving_average_cache = {}
54 |
55 | @property
56 | def task_name(self):
57 | return self._meta['task_name']
58 |
59 | def get_safety_evaluator(self):
60 | """Returns the RWRL safety function's evaluation.
61 |
62 | Function (3) in the RWRL paper, the per-constraint sum of violations.
63 |
64 | Returns:
 65 |       A dict of constraint_name -> # of violations in logs.
66 | """
67 | constraint_names = self._meta['safety_constraints']
68 | safety_stats = self._stats['safety_stats']
69 | violations = np.sum(safety_stats['total_violations'], axis=0)
70 | evaluator_results = collections.OrderedDict([
71 | (key, violations[idx]) for idx, key in enumerate(constraint_names)
72 | ])
73 | return evaluator_results
74 |
75 | def get_safety_plot(self, do_total=True, do_per_step=True):
76 | """Generates standardized plots describing safety constraint violations."""
77 | n_plots = int(do_total) + int(do_per_step)
78 | fig, axes = plt.subplots(1, n_plots, figsize=(8 * n_plots, 6))
79 | plot_idx = 0
80 | if do_total:
81 | self._get_safety_totals_plot(axes[plot_idx], self._stats['safety_stats'])
82 | plot_idx += 1
83 | if do_per_step:
84 | self._get_safety_per_step_plots(axes[plot_idx],
85 | self._stats['safety_stats'])
86 | return fig
87 |
88 | def _get_safety_totals_plot(self, ax, safety_stats):
89 | """Generates plots describing total # of violations / episode."""
90 | meta = self.meta
91 | violations_labels = meta['safety_constraints']
92 | total_violations = safety_stats['total_violations'].T
93 |
94 | for idx, violations in enumerate(total_violations):
95 | label = violations_labels[idx]
96 | ax.plot(np.arange(violations.shape[0]), violations, label=label)
97 |
98 | ax.set_title('# violations / episode')
99 | ax.legend()
100 | ax.set_ylabel('# violations')
101 | ax.set_xlabel('Episode')
102 | ax.plot()
103 |
104 | def _get_safety_per_step_plots(self, ax, safety_stats):
105 | """Generates plots describing mean # of violations / timestep."""
106 | meta = self.meta
107 | violations_labels = meta['safety_constraints']
108 | per_step_violations = safety_stats['per_step_violations']
109 |
110 | for idx, violations in enumerate(per_step_violations.T):
111 | label = violations_labels[idx]
112 | ax.plot(
113 | np.arange(violations.shape[0]), violations, label=label, alpha=0.75)
114 |
115 | ax.set_title('Mean violations / timestep')
116 | ax.legend(loc='upper right')
117 | ax.set_ylabel('Mean # violations')
118 | ax.set_xlabel('Timestep')
119 | ax.plot()
120 |
121 | def get_safety_vars_plot(self):
122 | """Get plots for statistics of safety-related variables."""
123 | if 'safety_vars_stats' not in self.stats:
124 | raise ValueError('No safety vars statistics present in this evaluator.')
125 |
126 | safety_vars = self.stats['safety_vars_stats'][0].keys()
127 | n_plots = len(safety_vars)
128 | fig, axes = plt.subplots(n_plots, 1, figsize=(8, 6 * n_plots))
129 |
130 | for idx, var in enumerate(safety_vars):
131 | series = collections.defaultdict(list)
132 | for ep in self.stats['safety_vars_stats']:
133 | for stat in ep[var]:
134 | series[stat].append(ep[var][stat])
135 | ax = axes[idx]
136 | for stat in ['min', 'max']:
137 | ax.plot(np.squeeze(np.array(series[stat])), label=stat)
138 | x = range(len(series['mean']))
139 |
140 | mean = np.squeeze(np.array(series['mean']))
141 | std_dev = np.squeeze(np.array(series['std_dev']))
142 | ax.plot(x, mean, label='Value')
143 | ax.fill_between(
144 | range(len(series['mean'])), mean - std_dev, mean + std_dev, alpha=0.3)
145 | ax.set_title('Stats for {}'.format(var))
146 | ax.legend()
147 | ax.spines['top'].set_visible(False)
148 |
149 | ax.xaxis.set_ticks_position('bottom')
150 | ax.set_xlabel('Episode #')
151 | ax.set_ylabel('Magnitude')
152 | ax.plot()
153 | return fig
154 |
155 | def _get_return_per_step_plot(self, ax, return_stats):
156 | """Plot per-episode return."""
157 | returns = return_stats['episode_totals']
158 |
159 | ax.plot(np.arange(returns.shape[0]), returns, label='Return')
160 |
161 | ax.set_title('Return / Episode')
162 | ax.legend()
163 | ax.set_ylabel('Return')
164 | ax.set_xlabel('Episode')
165 |
166 | def get_return_plot(self):
167 | fig, ax = plt.subplots()
168 | self._get_return_per_step_plot(ax, self.stats['return_stats'])
169 | return fig
170 |
171 | def _moving_average(self, values, window=1, stride=1, p=None):
172 | """Computes moving averages and confidence intervals."""
173 | # Cache results for convenience.
174 | key = (id(values), window, stride, p)
175 | if key in self._moving_average_cache:
176 | return self._moving_average_cache[key]
177 | # Compute rolling windows efficiently.
178 | def _rolling_window(a, window):
179 | shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
180 | strides = a.strides + (a.strides[-1],)
181 | return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)
182 | x = np.mean(
183 | _rolling_window(np.arange(len(values)), window), axis=-1)[::stride]
184 | y = _rolling_window(values, window)
185 | y_mean = np.mean(y, axis=-1)[::stride]
186 | y_lower, y_upper = _errorbars(y, axis=-1, p=p)
187 | ret = (x, y_mean, (y_lower, y_upper))
188 | self._moving_average_cache[key] = ret
189 | return ret
190 |
191 | def get_convergence_episode(self):
192 | """Returns the first episode that reaches the final return."""
193 | values = self.stats['return_stats']['episode_totals']
194 | _, y, (y_lower, _) = self._moving_average(
195 | values, window=_ROLLING_WINDOW, p=_CONFIDENCE_LEVEL)
196 | # The convergence is established as the first time the average return
197 | # is above the lower bounds of the final return.
198 | first_episode = max(np.argmax(y >= y_lower[-1]), 1)
199 | return first_episode
200 |
201 | def get_final_return(self):
202 | """Returns the average final return."""
203 | values = self.stats['return_stats']['episode_totals']
204 | _, y, (_, _) = self._moving_average(values, window=_ROLLING_WINDOW,
205 | p=_CONFIDENCE_LEVEL)
206 | return y[-1]
207 |
208 | def get_absolute_regret(self):
209 | """Returns the total regret until convergence."""
210 | values = self.stats['return_stats']['episode_totals']
211 | first_episode = self.get_convergence_episode()
212 | final_return = self.get_final_return()
213 | regret = np.sum(final_return - values[:first_episode])
214 | return regret
215 |
216 | def get_normalized_regret(self):
217 | """Returns the normalized regret."""
218 | final_return = self.get_final_return()
219 | return self.get_absolute_regret() / final_return
220 |
221 | def get_normalized_instability(self):
222 | """Returns the ratio of episodes that dip below the final return."""
223 | values = self.stats['return_stats']['episode_totals']
224 | _, _, (y_lower, _) = self._moving_average(
225 | values, window=_ROLLING_WINDOW, p=_CONFIDENCE_LEVEL)
226 | first_episode = self.get_convergence_episode()
227 | if first_episode == len(values) - 1:
228 | return None
229 | episodes = np.arange(len(values))
230 | unstable_episodes = np.where(
231 | np.logical_and(values < y_lower[-1], episodes > first_episode))[0]
232 | return float(len(unstable_episodes)) / (len(values) - first_episode - 1)
233 |
234 | def get_convergence_plot(self):
235 | """Plots an illustration of the convergence analysis."""
236 | fig, ax = plt.subplots()
237 | first_episode = self.get_convergence_episode()
238 |
239 | values = self.stats['return_stats']['episode_totals']
240 | ax.plot(np.arange(len(values)), values, color='steelblue', lw=2, alpha=.9,
241 | label='Return')
242 | ax.axvline(first_episode, color='seagreen', lw=2, label='Converged')
243 | ax.set_xlim(left=0, right=first_episode * 2)
244 |
245 | ax.set_title('Normalized regret = {:.3f}'.format(
246 | self.get_normalized_regret()))
247 | ax.legend()
248 | ax.set_ylabel('Return')
249 | ax.set_xlabel('Episode')
250 | return fig
251 |
252 | def get_stability_plot(self):
253 | """Plots an illustration of the algorithm stability."""
254 | fig, ax = plt.subplots()
255 | first_episode = self.get_convergence_episode()
256 |
257 | values = self.stats['return_stats']['episode_totals']
258 | _, _, (y_lower, _) = self._moving_average(
259 | values, window=_ROLLING_WINDOW, p=_CONFIDENCE_LEVEL)
260 | episodes = np.arange(len(values))
261 | unstable_episodes = np.where(
262 | np.logical_and(values < y_lower[-1], episodes > first_episode))[0]
263 |
264 | ax.plot(episodes, values, color='steelblue', lw=2, alpha=.9,
265 | label='Return')
266 | for i, episode in enumerate(unstable_episodes):
267 | ax.axvline(episode, color='salmon', lw=2,
268 | label='Unstable' if i == 0 else None)
269 | ax.axvline(first_episode, color='seagreen', lw=2, label='Converged')
270 |
271 | ax.set_title('Normalized instability = {:.3f}%'.format(
272 | self.get_normalized_instability() * 100.))
273 | ax.legend()
274 | ax.set_ylabel('Return')
275 | ax.set_xlabel('Episode')
276 | return fig
277 |
278 | def get_multiobjective_plot(self):
279 | """Plots an illustration of the multi-objective analysis."""
280 | fig, ax = plt.subplots()
281 |
282 | values = self.stats['multiobj_stats']['episode_totals']
283 | for i in range(values.shape[1]):
284 | ax.plot(np.arange(len(values[:, i])), values[:, i],
285 | color=_COLORS[i % len(_COLORS)], lw=2, alpha=.9,
286 | label='Objective {}'.format(i))
287 | ax.legend()
288 | ax.set_ylabel('Objective value')
289 | ax.set_xlabel('Episode')
290 | return fig
291 |
292 | def get_standard_evaluators(self):
293 | """This method returns the standard RWRL evaluators.
294 |
295 | Returns:
296 | Dict of evaluators:
297 | Off-Line Performance: NotImplemented
298 | Efficiency: NotImplemented
299 | Safety: Per-constraint # of violations
300 | Robustness: NotImplemented
301 | Discernment: NotImplemented
302 | """
303 | evaluators = collections.OrderedDict(
304 | offline=None,
305 | efficiency=None,
306 | safety=self.get_safety_evaluator(),
307 | robustness=None,
308 | discernment=None)
309 | return evaluators
310 |
311 | @property
312 | def meta(self):
313 | return self._meta
314 |
315 | @property
316 | def stats(self):
317 | return self._stats
318 |
319 |
320 | def _errorbars(values, axis=None, p=None):
321 | mean = np.mean(values, axis=axis)
322 | if p is None:
323 | std = np.std(values, axis=axis)
324 | return mean - std, mean + std
325 |   return scipy.stats.t.interval(p, values.shape[axis] - 1, loc=mean,
326 |                                 scale=scipy.stats.sem(values, axis=axis))
327 |
328 |
329 | def _map(fn, values):
330 | return dict((k, fn(v)) for k, v in values.items())
331 |
332 |
333 | def get_normalized_regret(evaluator_list):
334 | """Computes normalized regret for multiple seeds on multiple tasks."""
335 | values = collections.defaultdict(list)
336 | for e in evaluator_list:
337 | values[e.task_name].append(e.get_normalized_regret())
338 | return _map(np.mean, values), _map(np.std, values)
339 |
340 |
341 | def get_regret_plot(evaluator_list):
342 | """Plots normalized regret for multiple seeds on multiple tasks."""
343 | means, stds = get_normalized_regret(evaluator_list)
344 | task_names = sorted(means.keys())
345 | heights = []
346 | errorbars = []
347 | for task_name in task_names:
348 | heights.append(means[task_name])
349 | errorbars.append(stds[task_name])
350 | x = np.arange(len(task_names))
351 | fig, ax = plt.subplots()
352 | ax.bar(x, heights, yerr=errorbars)
353 | ax.set_xticks(x)
354 | ax.set_xticklabels(task_names)
355 | ax.set_ylabel('Normalized regret')
356 | return fig
357 |
358 |
359 | def get_return_plot(evaluator_list, stride=500):
360 | """Plots the return per episode for multiple seeds on multiple tasks."""
361 | values = collections.defaultdict(list)
362 | for e in evaluator_list:
363 | values[e.task_name].append(e.stats['return_stats']['episode_totals'])
364 | values = _map(np.vstack, values)
365 | means = _map(functools.partial(np.mean, axis=0), values)
366 | stds = _map(functools.partial(np.std, axis=0), values)
367 |
368 | fig, ax = plt.subplots()
369 | for i, task_name in enumerate(means):
370 | idx = i % len(_COLORS)
371 | x = np.arange(len(means[task_name]))
372 | ax.plot(x, means[task_name], lw=2, color=_COLORS[idx], alpha=.6, label=None)
373 | ax.plot(x[::stride], means[task_name][::stride], 'o', lw=2,
374 | marker=_MARKERS[idx], markersize=10, color=_COLORS[idx],
375 | label=task_name)
376 | ax.fill_between(x, means[task_name] - stds[task_name],
377 | means[task_name] + stds[task_name], alpha=.4, lw=2,
378 | color=_COLORS[idx])
379 | ax.legend()
380 | ax.set_ylabel('Return')
381 | ax.set_xlabel('Episode')
382 | return fig
383 |
384 |
385 | def get_multiobjective_plot(evaluator_list, stride=500):
386 | """Plots the objectives per episode for multiple seeds on multiple tasks."""
387 | num_objectives = (
388 | evaluator_list[0].stats['multiobj_stats']['episode_totals'].shape[1])
389 | values = [collections.defaultdict(list) for _ in range(num_objectives)]
390 | for e in evaluator_list:
391 | for i in range(num_objectives):
392 | values[i][e.task_name].append(
393 | e.stats['multiobj_stats']['episode_totals'][:, i])
394 | means = [None] * num_objectives
395 | stds = [None] * num_objectives
396 | for i in range(num_objectives):
397 | values[i] = _map(np.vstack, values[i])
398 | means[i] = _map(functools.partial(np.mean, axis=0), values[i])
399 | stds[i] = _map(functools.partial(np.std, axis=0), values[i])
400 |
401 | fig, axes = plt.subplots(num_objectives, 1, figsize=(8, 6 * num_objectives))
402 | for objective_idx in range(num_objectives):
403 | ax = axes[objective_idx]
404 | for i, task_name in enumerate(means[objective_idx]):
405 | m = means[objective_idx][task_name]
406 | s = stds[objective_idx][task_name]
407 | idx = i % len(_COLORS)
408 | x = np.arange(len(m))
409 | ax.plot(x, m, lw=2, color=_COLORS[idx], alpha=.6, label=None)
410 | ax.plot(x[::stride], m[::stride], 'o', lw=2, marker=_MARKERS[idx],
411 | markersize=10, color=_COLORS[idx], label=task_name)
412 | ax.fill_between(x, m - s, m + s, alpha=.4, lw=2, color=_COLORS[idx])
413 | ax.legend()
414 | ax.set_ylabel('Objective {}'.format(objective_idx))
415 | ax.set_xlabel('Episode')
416 | return fig
417 |
418 |
419 | def main(argv):
420 | if len(argv) > 1:
421 | raise app.UsageError('Too many command-line arguments.')
422 | with open(FLAGS.stats_in, 'rb') as f:
423 |     stats = np.load(f, allow_pickle=True)
424 | evals = Evaluators(stats)
425 | pp = pprint.PrettyPrinter(indent=2)
426 | pp.pprint(evals.get_standard_evaluators())
427 |
428 |
429 | if __name__ == '__main__':
430 | flags.mark_flag_as_required('stats_in')
431 |
432 | app.run(main)
433 |
--------------------------------------------------------------------------------
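
For reference, a minimal offline-analysis sketch using the Evaluators class above. It assumes 'logs.npz' was produced by the suite's logging wrapper with safety enabled; the file name and the saved figure name are illustrative.

import numpy as np
from realworldrl_suite.utils import evaluators

with open('logs.npz', 'rb') as f:
  stats = np.load(f, allow_pickle=True)
  ev = evaluators.Evaluators(stats)

  # Per-constraint violation counts (the safety evaluator).
  print(ev.get_safety_evaluator())

  # Learning-curve summaries derived from the per-episode returns.
  # Note: the rolling-window statistics assume the log spans well over
  # _ROLLING_WINDOW (101) episodes.
  print(ev.get_convergence_episode())
  print(ev.get_normalized_regret())
  print(ev.get_normalized_instability())

  # Standard diagnostic plot.
  ev.get_convergence_plot().savefig('convergence.png')
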
/realworldrl_suite/utils/evaluators_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Real-World RL Suite Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for evaluators."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import os
22 |
23 | from absl.testing import absltest
24 | from absl.testing import parameterized
25 | import numpy as np
26 | import realworldrl_suite.environments as rwrl
27 | from realworldrl_suite.utils import evaluators
28 |
29 |
30 | class RandomAgent(object):
31 |
32 | def __init__(self, action_spec):
33 | self.action_spec = action_spec
34 |
35 | def action(self):
36 | return np.random.uniform(
37 | self.action_spec.minimum,
38 | self.action_spec.maximum,
39 | size=self.action_spec.shape)
40 |
41 |
42 | class EvaluatorsTest(parameterized.TestCase):
43 |
44 | def _gen_stats(self, domain_name, task_name):
45 | temp_dir = self.create_tempdir()
46 | env = rwrl.load(
47 | domain_name=domain_name,
48 | task_name=task_name,
49 | safety_spec={'enable': True},
50 | log_output=os.path.join(temp_dir.full_path, 'test.pickle'),
51 | environment_kwargs=dict(log_safety_vars=True, flat_observation=True))
52 | random_policy = RandomAgent(env.action_spec()).action
53 | for _ in range(3):
54 | timestep = env.step(random_policy())
55 | while not timestep.last():
56 | timestep = env.step(random_policy())
57 | env.write_logs()
58 | return env.logs_path
59 |
60 | @parameterized.named_parameters(*rwrl.ALL_TASKS)
61 | def test_loading(self, domain_name, task_name):
62 | temp_path = self._gen_stats(domain_name, task_name)
63 | data_in = np.load(temp_path, allow_pickle=True)
64 | evaluators.Evaluators(data_in)
65 |
66 | def test_safety_evaluator(self):
67 | # TODO(dulacarnold): Make this test general to all envs.
68 | temp_path = self._gen_stats(
69 | domain_name='cartpole', task_name='realworld_balance')
70 | data_in = np.load(temp_path, allow_pickle=True)
71 | ev = evaluators.Evaluators(data_in)
72 | self.assertLen(ev.get_safety_evaluator(), 3)
73 |
74 | def test_standard_evaluators(self):
75 | # TODO(dulacarnold): Make this test general to all envs.
76 | temp_path = self._gen_stats(
77 | domain_name='cartpole', task_name='realworld_balance')
78 | data_in = np.load(temp_path, allow_pickle=True)
79 | ev = evaluators.Evaluators(data_in)
80 | self.assertLen(ev.get_standard_evaluators(), 5)
81 |
82 | @parameterized.named_parameters(*rwrl.ALL_TASKS)
83 | def test_safety_plot(self, domain_name, task_name):
84 | temp_path = self._gen_stats(domain_name, task_name)
85 | data_in = np.load(temp_path, allow_pickle=True)
86 | ev = evaluators.Evaluators(data_in)
87 | ev.get_safety_plot()
88 |
89 |
90 | if __name__ == '__main__':
91 | absltest.main()
92 |
--------------------------------------------------------------------------------
/realworldrl_suite/utils/loggers.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Real-World RL Suite Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Various logging classes to easily log to different backends."""
17 | import copy
18 | import io
19 | import os
20 | import time
21 |
22 | from absl import logging
23 | import numpy as np
24 |
25 |
26 |
27 |
28 | class PickleLogger(object):
29 | """Saves data to a pickle file.
30 |
31 | This logger will save data as a list of stored elements, written to a python3
32 | pickle file. This data can then get retrieved by a PickleReader class.
33 | """
34 |
35 | def __init__(self, path):
36 | """Generate a PickleLogger object.
37 |
38 | Args:
39 | path: string of path to write to
40 | """
41 | self._meta = []
42 | self._stack = []
43 | # To know if our current run is responsible for existing files.
44 | # Used to figure out if we can overwrite a previous log, or if we've
45 | # been checkpointed in between.
46 | ts = str(int(time.time()))
47 | split = os.path.split(path)
48 | self._pickle_path = os.path.join(split[0], '{}-{}'.format(ts, split[1]))
49 |
50 | def set_meta(self, meta):
51 | """Pickleable object of metadata about the task."""
52 | self._meta = copy.deepcopy(meta)
53 |
54 | def push(self, data):
55 | self._stack.append(copy.deepcopy(data))
56 |
57 | def save(self, data=None):
58 | """Save data to disk.
59 |
60 | Args:
61 | data: Additional data structure you want to save to disk, will use the
62 | 'data' key for storage.
63 | """
64 | logs = self.logs
65 | if data is not None:
66 | logs['data'] = data
67 | with open(self._pickle_path, 'wb') as f:
68 | np.savez_compressed(f, **logs)
 69 |     logging.info('Saved stats to %s.', self._pickle_path)
70 |
71 | @property
72 | def logs(self):
73 | return dict(meta=self._meta, stack=self._stack)
74 |
75 | @property
76 | def logs_path(self):
77 | return self._pickle_path
78 |
--------------------------------------------------------------------------------
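
One detail worth calling out from the implementation above: the logger prepends a timestamp to the file name you pass in, so the actual on-disk location must be read back from `logs_path` (as the test below does). A minimal round-trip sketch; the path and payloads are illustrative:

import numpy as np
from realworldrl_suite.utils import loggers

logger = loggers.PickleLogger(path='/tmp/run.pickle')  # Illustrative path.
logger.set_meta({'task_name': 'example'})
logger.push({'step': 0})
logger.save(data={'episode_returns': [1.0, 2.0]})

# The file lands at logger.logs_path ('/tmp/<timestamp>-run.pickle'),
# not at the original path.
with open(logger.logs_path, 'rb') as f:
  logs = np.load(f, allow_pickle=True)
  print(logs['meta'].item(), logs['data'].item())
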
/realworldrl_suite/utils/loggers_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Real-World RL Suite Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for loggers."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from absl.testing import absltest
22 | import numpy as np
23 | import numpy.testing as npt
24 | from realworldrl_suite.utils import loggers
25 |
26 |
27 | class LoggersTest(absltest.TestCase):
28 |
29 | def test_write(self):
30 | temp_file = self.create_tempfile()
31 | plogger = loggers.PickleLogger(path=temp_file.full_path)
32 | write_meta = np.random.randn(10, 10)
33 | push_data = np.random.randn(10, 10)
34 | save_data = np.random.randn(10, 10)
35 | plogger.set_meta(write_meta)
36 | plogger.push(push_data)
37 | plogger.save(data=save_data)
38 | with open(plogger.logs_path, 'rb') as f:
39 | read_data = np.load(f, allow_pickle=True)
40 | npt.assert_array_equal(read_data['meta'], write_meta)
41 | npt.assert_array_equal(read_data['stack'][0], push_data)
42 | npt.assert_array_equal(read_data['data'], save_data)
43 |
44 |
45 | if __name__ == '__main__':
46 | absltest.main()
47 |
--------------------------------------------------------------------------------
/realworldrl_suite/utils/multiobj_objectives.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Real-World RL Suite Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Some basic complementary rewards for the multi-objective task."""
17 |
18 | import abc
19 | import numpy as np
20 |
21 |
22 | class Objective(abc.ABC):
23 |
24 | @abc.abstractmethod
25 | def get_objectives(self, task_obj):
26 | raise NotImplementedError
27 |
28 | @abc.abstractmethod
29 |   def merge_reward(self, task_obj, physics, base_reward, alpha):
30 | raise NotImplementedError
31 |
32 |
33 | class SafetyObjective(Objective):
34 | """This class defines an extra objective related to safety."""
35 |
36 | def get_objectives(self, task_obj):
37 |     """Returns the safety objective: fraction of satisfied constraints."""
38 | if task_obj.safety_enabled:
39 | num_constraints = float(task_obj.constraints_obs.shape[0])
40 | num_satisfied = task_obj.constraints_obs.sum()
41 | s_reward = num_satisfied / num_constraints
42 | return np.array([s_reward])
43 | else:
44 | raise Exception('Safety not enabled. Safety-based multi-objective reward'
45 | ' requires safety spec to be enabled.')
46 |
47 | def merge_reward(self, task_obj, physics, base_reward, alpha):
48 |     """Returns a convex combination of the base reward and safety objective."""
49 | s_reward = self.get_objectives(task_obj)[0]
50 | return (1 - alpha) * base_reward + alpha * s_reward
51 |
52 |
53 | OBJECTIVES = {'safety': SafetyObjective}
54 |
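55 | 
56 | # --- Editor-added illustrative sketch; not part of the original suite. ------
57 | # A custom objective can be defined by subclassing `Objective` and passing it
58 | # via `multiobj_spec['objective']` (see multiobj_objectives_test.py). The class
59 | # below is a hypothetical, minimal example.
60 | class ConstantObjective(Objective):
61 |   """Toy objective returning a constant value; useful only as a smoke test."""
62 | 
63 |   def __init__(self, value=0.5):
64 |     self._value = value
65 | 
66 |   def get_objectives(self, task_obj):
67 |     del task_obj  # This toy objective ignores the task state.
68 |     return np.array([self._value])
69 | 
70 |   def merge_reward(self, task_obj, physics, base_reward, alpha):
71 |     del physics  # Unused by this toy objective.
72 |     return (1 - alpha) * base_reward + alpha * self.get_objectives(task_obj)[0]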
--------------------------------------------------------------------------------
/realworldrl_suite/utils/multiobj_objectives_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Real-World RL Suite Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for multi-objective reward."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 |
21 | import numpy as np
22 | import realworldrl_suite.environments as rwrl
23 | from realworldrl_suite.utils import multiobj_objectives
24 |
25 |
26 | class MultiObjTest(parameterized.TestCase):
27 |
28 | @parameterized.named_parameters(*rwrl.ALL_TASKS)
29 | def testMultiObjNoSafety(self, domain_name, task_name):
30 | """Ensure multi-objective safety reward can be loaded."""
31 | env = rwrl.load(
32 | domain_name=domain_name,
33 | task_name=task_name,
34 | safety_spec={'enable': False},
35 | multiobj_spec={
36 | 'enable': True,
37 | 'objective': 'safety',
38 | 'observed': True,
39 | 'coeff': 0.5
40 | })
41 | with self.assertRaises(Exception):
42 | env.reset()
43 | env.step(0)
44 |
45 | @parameterized.named_parameters(*rwrl.ALL_TASKS)
46 | def testMultiObjPassedObjective(self, domain_name, task_name):
47 | """Ensure objective class can be passed directly."""
48 | multiobj_class = lambda: multiobj_objectives.SafetyObjective() # pylint: disable=unnecessary-lambda
49 | env = rwrl.load(
50 | domain_name=domain_name,
51 | task_name=task_name,
52 | safety_spec={'enable': True},
53 | multiobj_spec={
54 | 'enable': True,
55 | 'objective': multiobj_class,
56 | 'observed': True,
57 | 'coeff': 0.5
58 | })
59 | env.reset()
60 | env.step(0)
61 |
62 | multiobj_class = multiobj_objectives.SafetyObjective
63 | env = rwrl.load(
64 | domain_name=domain_name,
65 | task_name=task_name,
66 | safety_spec={'enable': True},
67 | multiobj_spec={
68 | 'enable': True,
69 | 'objective': multiobj_class,
70 | 'observed': True,
71 | 'coeff': 0.5
72 | })
73 | env.reset()
74 | env.step(0)
75 |
76 | @parameterized.named_parameters(*rwrl.ALL_TASKS)
77 | def testMultiObjSafetyNoRewardObs(self, domain_name, task_name):
78 | """Ensure multi-objective safety reward can be loaded."""
79 | env = rwrl.load(
80 | domain_name=domain_name,
81 | task_name=task_name,
82 | safety_spec={'enable': True},
83 | multiobj_spec={
84 | 'enable': True,
85 | 'objective': 'safety',
86 | 'reward': False,
87 | 'observed': True,
88 | 'coeff': 0.5
89 | })
90 | env.reset()
91 | env.step(0)
92 | ts = env.step(0)
93 | self.assertIn('multiobj', ts.observation)
94 |     # Make sure we see a 1 in the normalized constraint-satisfaction objective.
95 |     env.task._constraints_obs = np.ones(
96 |         env.task._constraints_obs.shape).astype(bool)
97 | obs = env.task.get_observation(env.physics)
98 | self.assertEqual(obs['multiobj'][1], 1)
99 | # And that there is no effect on rewards
100 | env.task._multiobj_coeff = 0
101 | r1 = env.task.get_reward(env.physics)
102 | env.task._multiobj_coeff = 1
103 | r2 = env.task.get_reward(env.physics)
104 | self.assertEqual(r1, r2)
105 |
106 | @parameterized.named_parameters(*rwrl.ALL_TASKS)
107 | def testMultiObjSafetyRewardNoObs(self, domain_name, task_name):
108 | """Ensure multi-objective safety reward can be loaded."""
109 | env = rwrl.load(
110 | domain_name=domain_name,
111 | task_name=task_name,
112 | safety_spec={'enable': True},
113 | multiobj_spec={
114 | 'enable': True,
115 | 'objective': 'safety',
116 | 'reward': True,
117 | 'observed': False,
118 | 'coeff': 0.5
119 | })
120 | env.reset()
121 | ts = env.step(0)
122 | self.assertNotIn('multiobj', ts.observation)
123 | # Make sure the method calls without error.
124 | env.task.get_multiobj_reward(env._physics, 0)
125 |     env.task._constraints_obs = np.ones(
126 |         env.task._constraints_obs.shape).astype(bool)
127 |
128 | # Make sure the mixing is working and global reward calls without error.
129 | env.task._multiobj_coeff = 1
130 | max_reward = env.task.get_reward(env.physics)
131 | env.task._multiobj_coeff = 0.5
132 | mid_reward = env.task.get_reward(env.physics)
133 | env.task._multiobj_coeff = 0.0
134 | min_reward = env.task.get_reward(env.physics)
135 | self.assertGreaterEqual(max_reward, min_reward)
136 | self.assertGreaterEqual(mid_reward, min_reward)
137 |
138 | env.task._multiobj_coeff = 0.5
139 | max_reward = env.task.get_reward(env.physics)
140 | self.assertEqual(
141 | env.task._multiobj_objective.merge_reward(env.task, env._physics, 0, 1),
142 | 1)
143 | self.assertEqual(env.task.get_multiobj_reward(env._physics, 0), 0.5)
144 | self.assertEqual(env.task.get_multiobj_reward(env._physics, 0.5), 0.75)
145 | self.assertEqual(env.task.get_multiobj_reward(env._physics, 1), 1)
146 |
147 |
148 | if __name__ == '__main__':
149 | absltest.main()
150 |
--------------------------------------------------------------------------------
/realworldrl_suite/utils/viewer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Real-World RL Suite Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Simple viewer for realworld environments."""
17 | from absl import app
18 | from absl import flags
19 | from absl import logging
20 |
21 | from dm_control import suite
22 | from dm_control import viewer
23 |
24 | import numpy as np
25 | import realworldrl_suite.environments as rwrl
26 |
27 | flags.DEFINE_enum('suite', 'rwrl', ['rwrl', 'dm_control'], 'Suite choice.')
28 | flags.DEFINE_string('domain_name', 'cartpole', 'Domain name.')
29 | flags.DEFINE_string('task_name', 'realworld_balance', 'Task name.')
30 |
31 | FLAGS = flags.FLAGS
32 |
33 |
34 | class RandomAgent(object):
35 |
36 | def __init__(self, action_spec):
37 | self.action_spec = action_spec
38 |
39 | def action(self, timestep):
40 | del timestep
41 | return np.random.uniform(
42 | self.action_spec.minimum,
43 | self.action_spec.maximum,
44 | size=self.action_spec.shape)
45 |
46 |
47 | def main(_):
48 | if FLAGS.suite == 'dm_control':
49 | logging.info('Loading from dm_control...')
50 | env = suite.load(domain_name=FLAGS.domain_name, task_name=FLAGS.task_name)
51 | elif FLAGS.suite == 'rwrl':
52 | logging.info('Loading from rwrl...')
53 | env = rwrl.load(domain_name=FLAGS.domain_name, task_name=FLAGS.task_name)
54 | random_policy = RandomAgent(env.action_spec()).action
55 | viewer.launch(env, policy=random_policy)
56 |
57 |
58 | if __name__ == '__main__':
59 | app.run(main)
60 |
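61 | # Editor-added example invocation (the module path is assumed from the
62 | # repository layout; flag names and defaults are defined above):
63 | #   python -m realworldrl_suite.utils.viewer --suite=rwrl \
64 | #     --domain_name=cartpole --task_name=realworld_balance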
--------------------------------------------------------------------------------
/realworldrl_suite/utils/wrappers.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Real-World RL Suite Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """RealWorld RL env logging wrappers."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import copy
23 |
24 | from dm_control.rl import control
25 | import dm_env
26 | from dm_env import specs
27 | from realworldrl_suite.utils import accumulators
28 | import six
29 |
30 |
31 | class LoggingEnv(control.Environment):
32 | """Subclass of control.Environment which adds logging."""
33 |
34 | def __init__(self,
35 | physics,
36 | task,
37 | logger=None,
38 | log_safety_vars=False,
39 | time_limit=float('inf'),
40 | control_timestep=None,
41 | n_sub_steps=None,
42 | log_every=100,
43 | flat_observation=False):
44 | """A subclass of `Environment` with logging hooks.
45 |
46 | Args:
47 | physics: Instance of `Physics`.
48 | task: Instance of `Task`.
49 |       logger: Instance of `realworldrl_suite.utils.loggers.PickleLogger`. If
50 |         specified, it is used to log the data needed for real-world evaluation.
51 |       log_safety_vars: Whether to also log the variables returned by
52 |         `self._task.safety_vars()`. Mainly used for debugging or to find
53 |         pertinent values for the safety variables; increases file size on disk.
54 |       time_limit: Optional `float`, maximum duration of each episode in
55 |         seconds. Defaults to infinity.
56 | control_timestep: Optional control time-step, in seconds.
57 | n_sub_steps: Optional number of physical time-steps in one control
58 | time-step, aka "action repeats". Can only be supplied if
59 | `control_timestep` is not specified.
60 | log_every: How many episodes between each log write.
61 | flat_observation: If True, observations will be flattened and concatenated
62 | into a single numpy array.
63 |
64 | Raises:
65 | ValueError: If both `n_sub_steps` and `control_timestep` are supplied.
66 | """
67 | super(LoggingEnv, self).__init__(
68 | physics,
69 | task,
70 | time_limit,
71 | control_timestep,
72 | n_sub_steps,
73 | flat_observation=False)
74 | self._flat_observation_ = flat_observation
75 | self._logger = logger
76 | self._buffer = []
77 | self._counter = 0
78 | self._log_every = log_every
79 | self._ep_counter = 0
80 | self._log_safety_vars = self._task.safety_enabled and log_safety_vars
81 | if self._logger:
82 | meta_dict = dict(task_name=type(self._task).__name__)
83 | if self._task.safety_enabled:
84 | meta_dict['safety_constraints'] = list(self._task.constraints.keys())
85 | if self._log_safety_vars:
86 |         meta_dict['safety_vars'] = list(
87 |             self._task.safety_vars(self._physics).keys())
88 | self._logger.set_meta(meta_dict)
89 |
90 | self._stats_acc = accumulators.StatisticsAccumulator(
91 | acc_safety=self._task.safety_enabled,
92 | acc_safety_vars=self._log_safety_vars,
93 | acc_multiobj=self._task.multiobj_enabled)
94 | else:
95 | self._stats_acc = None
96 |
97 | def reset(self):
98 | """Starts a new episode and returns the first `TimeStep`."""
99 | if self._stats_acc:
100 | self._stats_acc.clear_buffer()
101 | if self._task.perturb_enabled:
102 | if self._counter % self._task.perturb_period == 0:
103 | self._physics = self._task.update_physics()
104 | self._counter += 1
105 | timestep = super(LoggingEnv, self).reset()
106 | self._track(timestep)
107 | if self._flat_observation_:
108 | timestep = dm_env.TimeStep(
109 | step_type=timestep.step_type,
110 | reward=None,
111 | discount=None,
112 | observation=control.flatten_observation(
113 | timestep.observation)['observations'])
114 | return timestep
115 |
116 | def observation_spec(self):
117 | """Returns the observation specification for this environment.
118 |
119 | Infers the spec from the observation, unless the Task implements the
120 | `observation_spec` method.
121 |
122 | Returns:
123 |       A dict mapping observation name to `ArraySpec` containing observation
124 | shape and dtype.
125 | """
126 | self._flat_observation = self._flat_observation_
127 | obs_spec = super(LoggingEnv, self).observation_spec()
128 | self._flat_observation = False
129 | if self._flat_observation_:
130 | return obs_spec['observations']
131 | return obs_spec
132 |
133 | def step(self, action):
134 | """Updates the environment using the action and returns a `TimeStep`."""
135 | do_track = not self._reset_next_step
136 | timestep = super(LoggingEnv, self).step(action)
137 | if do_track:
138 | self._track(timestep)
139 | if timestep.last():
140 | self._ep_counter += 1
141 | if self._ep_counter % self._log_every == 0:
142 | self.write_logs()
143 | # Only flatten observation if we're not forwarding one from a reset(),
144 | # as it will already be flattened.
145 | if self._flat_observation_ and not timestep.first():
146 | timestep = dm_env.TimeStep(
147 | step_type=timestep.step_type,
148 | reward=timestep.reward,
149 | discount=timestep.discount,
150 | observation=control.flatten_observation(
151 | timestep.observation)['observations'])
152 | return timestep
153 |
154 | def _track(self, timestep):
155 | if self._logger is None:
156 | return
157 | ts = copy.deepcopy(timestep)
158 | # Augment the timestep with unobserved variables for logging purposes.
159 | # Add safety-related observations.
160 | if self._task.safety_enabled and 'constraints' not in ts.observation:
161 | ts.observation['constraints'] = copy.copy(self._task.constraints_obs)
162 | if self._log_safety_vars:
163 | ts.observation['safety_vars'] = copy.deepcopy(
164 | self._task.safety_vars(self._physics))
165 | if self._task.multiobj_enabled and 'multiobj' not in ts.observation:
166 | ts.observation['multiobj'] = self._task.get_multiobj_obs(self._physics)
167 | self._stats_acc.push(ts)
168 |
169 | def get_logs(self):
170 | return self._logger.logs
171 |
172 | def write_logs(self):
173 | if self._logger is None:
174 | return
175 | self._logger.save(data=self._stats_acc.to_ndarray_dict())
176 |
177 | @property
178 | def stats_acc(self):
179 | return self._stats_acc
180 |
181 | @property
182 | def logs_path(self):
183 | if self._logger is None:
184 | return None
185 | return self._logger.logs_path
186 |
187 |
188 | def _spec_from_observation(observation):
189 | result = collections.OrderedDict()
190 | for key, value in six.iteritems(observation):
191 | result[key] = specs.Array(value.shape, value.dtype, name=key)
192 | return result
193 |
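194 | # Editor-added usage sketch (assumes the `rwrl.load` entry point and keyword
195 | # arguments exercised in wrappers_test.py; the output path is arbitrary):
196 | #
197 | #   import realworldrl_suite.environments as rwrl
198 | #   env = rwrl.load(
199 | #       domain_name='cartpole',
200 | #       task_name='realworld_balance',
201 | #       safety_spec={'enable': True},
202 | #       log_output='/tmp/rwrl_stats.npz',
203 | #       environment_kwargs=dict(log_safety_vars=True, log_every=10))
204 | #   timestep = env.reset()
205 | #   while not timestep.last():
206 | #     timestep = env.step(0)
207 | #   env.write_logs()  # Flushes accumulated statistics through the logger.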
--------------------------------------------------------------------------------
/realworldrl_suite/utils/wrappers_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Real-World RL Suite Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for realworldrl.utils.wrappers."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import os
22 |
23 | from absl.testing import absltest
24 | from absl.testing import parameterized
25 | import numpy as np
26 | import realworldrl_suite.environments as rwrl
27 |
28 |
29 | class RandomAgent(object):
30 |
31 | def __init__(self, action_spec):
32 | self.action_spec = action_spec
33 |
34 | def action(self):
35 | return np.random.uniform(
36 | self.action_spec.minimum,
37 | self.action_spec.maximum,
38 | size=self.action_spec.shape)
39 |
40 |
41 | class WrappersTest(parameterized.TestCase):
42 |
43 | @parameterized.named_parameters(*rwrl.ALL_TASKS)
44 | def test_init(self, domain_name, task_name):
45 | temp_file = self.create_tempfile()
46 | env = rwrl.load(
47 | domain_name=domain_name,
48 | task_name=task_name,
49 | safety_spec={'enable': True},
50 | log_output=temp_file.full_path,
51 | environment_kwargs=dict(log_safety_vars=True))
52 | env.write_logs()
53 |
54 | @parameterized.named_parameters(*rwrl.ALL_TASKS)
55 | def test_no_logger(self, domain_name, task_name):
56 | env = rwrl.load(
57 | domain_name=domain_name,
58 | task_name=task_name,
59 | safety_spec={'enable': True},
60 | log_output=None,
61 | environment_kwargs=dict(log_safety_vars=True))
62 | env.write_logs()
63 |
64 | @parameterized.named_parameters(*rwrl.ALL_TASKS)
65 | def test_disabled_safety_obs(self, domain_name, task_name):
66 | temp_file = self.create_tempfile()
67 | env = rwrl.load(
68 | domain_name=domain_name,
69 | task_name=task_name,
70 | safety_spec={'enable': True, 'observations': False},
71 | log_output=temp_file.full_path,
72 | environment_kwargs=dict(log_safety_vars=True))
73 | env.reset()
74 | timestep = env.step(0)
75 | self.assertNotIn('constraints', timestep.observation.keys())
76 | self.assertIn('constraints', env._stats_acc._buffer[-1].observation)
77 | env.write_logs()
78 |
79 | @parameterized.named_parameters(*rwrl.ALL_TASKS)
80 | def test_custom_safety_constraint(self, domain_name, task_name):
81 | const_constraint = (domain_name == rwrl.ALL_TASKS[0][1])
82 | def custom_constraint(unused_env_obj, unused_safety_vars):
83 | return const_constraint
84 | constraints = {'custom_constraint': custom_constraint}
85 | temp_file = self.create_tempfile()
86 | env = rwrl.load(
87 | domain_name=domain_name,
88 | task_name=task_name,
89 | safety_spec={'enable': True, 'observations': True,
90 | 'constraints': constraints},
91 | log_output=temp_file.full_path,
92 | environment_kwargs=dict(log_safety_vars=True))
93 | env.reset()
94 | timestep = env.step(0)
95 | self.assertIn('constraints', timestep.observation.keys())
96 | self.assertEqual(timestep.observation['constraints'],
97 | np.array([const_constraint]))
98 | env.write_logs()
99 |
100 | @parameterized.named_parameters(*rwrl.ALL_TASKS)
101 | def test_action_roc_constraint(self, domain_name, task_name):
102 | is_within_constraint = (domain_name == rwrl.ALL_TASKS[0][1])
103 | constraints = {
104 | 'action_roc_constraint': rwrl.DOMAINS[domain_name].action_roc_constraint
105 | }
106 | temp_file = self.create_tempfile()
107 | env = rwrl.load(
108 | domain_name=domain_name,
109 | task_name=task_name,
110 | safety_spec={'enable': True, 'observations': True,
111 | 'constraints': constraints},
112 | log_output=temp_file.full_path,
113 | environment_kwargs=dict(log_safety_vars=True))
114 | env.reset()
115 | _ = env.step(-100)
116 | timestep_2 = env.step(-100 if is_within_constraint else 100)
117 | self.assertIn('constraints', timestep_2.observation.keys())
118 | self.assertEqual(timestep_2.observation['constraints'],
119 | np.array([is_within_constraint]))
120 | env.write_logs()
121 |
122 | # @parameterized.parameters(*(5,))
123 | def test_log_every_n(self, every_n=5):
124 | domain_name = 'cartpole'
125 | task_name = 'realworld_balance'
126 | temp_dir = self.create_tempdir()
127 | env = rwrl.load(
128 | domain_name=domain_name,
129 | task_name=task_name,
130 | safety_spec={'enable': True, 'observations': False},
131 | log_output=os.path.join(temp_dir.full_path, 'test.pickle'),
132 | environment_kwargs=dict(log_safety_vars=True, log_every=every_n))
133 | env.reset()
134 | n = 0
135 | while True:
136 | timestep = env.step(0)
137 | if timestep.last():
138 | n += 1
139 | if n % every_n == 0:
140 | self.assertTrue(os.path.exists(env.logs_path))
141 | os.remove(env.logs_path)
142 | self.assertFalse(os.path.exists(env.logs_path))
143 | if n > 20:
144 | break
145 |
146 | @parameterized.named_parameters(*rwrl.ALL_TASKS)
147 | def test_flat_obs(self, domain_name, task_name):
148 | temp_file = self.create_tempfile()
149 | env = rwrl.load(
150 | domain_name=domain_name,
151 | task_name=task_name,
152 | safety_spec={'enable': True},
153 | log_output=temp_file.full_path,
154 | environment_kwargs=dict(log_safety_vars=True, flat_observation=True))
155 | self.assertLen(env.observation_spec().shape, 1)
156 | if domain_name == 'cartpole' and task_name in ['swingup', 'balance']:
157 | self.assertEqual(env.observation_spec().shape[0], 8)
158 | timestep = env.step(0)
159 | self.assertLen(timestep.observation.shape, 1)
160 | if domain_name == 'cartpole' and task_name in ['swingup', 'balance']:
161 | self.assertEqual(timestep.observation.shape[0], 8)
162 | env.write_logs()
163 |
164 |
165 | if __name__ == '__main__':
166 | absltest.main()
167 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Real World RL Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Setup for pip package."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import io
22 | import sys
23 | import unittest
24 |
25 | from setuptools import find_packages
26 | from setuptools import setup
27 | from setuptools.command.test import test as TestCommandBase
28 |
29 | REQUIRED_PACKAGES = [
30 | 'six', 'absl-py', 'numpy', 'dm-env', 'dm-control', 'lxml', 'scipy',
31 | 'matplotlib'
32 | ]
33 |
34 |
35 | class StderrWrapper(io.IOBase):
36 |
37 | def write(self, *args, **kwargs):
38 | return sys.stderr.write(*args, **kwargs)
39 |
40 | def writeln(self, *args, **kwargs):
41 | if args or kwargs:
42 | sys.stderr.write(*args, **kwargs)
43 | sys.stderr.write('\n')
44 |
45 |
46 | # Needed to ensure that flags are correctly parsed.
47 | class Test(TestCommandBase):
48 |
49 | def run_tests(self):
50 | # Import absl inside run, where dependencies have been loaded already.
51 | from absl import app # pylint: disable=g-import-not-at-top
52 |
53 | def main(_):
54 | test_loader = unittest.TestLoader()
55 | test_suite = test_loader.discover(
56 | 'realworldrl_suite', pattern='*_test.py')
57 | stderr = StderrWrapper()
58 | result = unittest.TextTestResult(stderr, descriptions=True, verbosity=2)
59 | test_suite.run(result)
60 | result.printErrors()
61 |
62 | final_output = ('Tests run: {}. '.format(result.testsRun) +
63 | 'Errors: {} Failures: {}.'.format(
64 | len(result.errors), len(result.failures)))
65 |
66 | header = '=' * len(final_output)
67 | stderr.writeln(header)
68 | stderr.writeln(final_output)
69 | stderr.writeln(header)
70 |
71 | if result.wasSuccessful():
72 | return 0
73 | else:
74 | return 1
75 |
76 | # Run inside absl.app.run to ensure flags parsing is done.
77 | return app.run(main)
78 |
79 |
80 | def rwrl_test_suite():
81 | test_loader = unittest.TestLoader()
82 | test_suite = test_loader.discover('realworldrl_suite', pattern='*_test.py')
83 | return test_suite
84 |
85 |
86 | setup(
87 | name='realworldrl_suite',
88 | version='1.0',
89 | description='RL evaluation framework for the real world.',
90 | url='https://github.com/google-research/realworldrl_suite',
91 | author='Google',
92 | author_email='no-reply@google.com',
93 | # Contained modules and scripts.
94 | packages=find_packages(),
95 | install_requires=REQUIRED_PACKAGES,
96 | extras_require={},
97 | platforms=['any'],
98 | license='Apache 2.0',
99 | cmdclass={
100 | 'test': Test,
101 | },
102 | )
103 |
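104 | # Editor-added note: typical commands (assumed, not part of the original file):
105 | #   pip install -e .        # editable install with the dependencies listed above
106 | #   python setup.py test    # runs the `Test` command registered in `cmdclass`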
--------------------------------------------------------------------------------