├── .gitignore
├── AUTHORS
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── configs
│   ├── ant_xy_offpolicy.txt
│   ├── ant_xy_onpolicy.txt
│   ├── dkitty_randomized_xy_offpolicy.txt
│   ├── humanoid_offpolicy.txt
│   ├── humanoid_onpolicy.txt
│   └── template_config.txt
├── env.yml
├── envs
│   ├── assets
│   │   ├── ant.xml
│   │   ├── ant_footsensor.xml
│   │   ├── half_cheetah.xml
│   │   ├── humanoid.xml
│   │   └── point.xml
│   ├── dclaw.py
│   ├── dkitty_redesign.py
│   ├── gym_mujoco
│   │   ├── ant.py
│   │   ├── half_cheetah.py
│   │   ├── humanoid.py
│   │   └── point_mass.py
│   ├── hand_block.py
│   ├── skill_wrapper.py
│   └── video_wrapper.py
├── lib
│   ├── py_tf_policy.py
│   └── py_uniform_replay_buffer.py
└── unsupervised_skill_learning
    ├── dads_agent.py
    ├── dads_off.py
    ├── skill_discriminator.py
    └── skill_dynamics.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Generated files
2 | *.egg-info/
3 | .idea*
4 | *__pycache__*
5 | .ipynb_checkpoints*
6 | *.pyc
7 | *.DS_Store
8 | *.mp4
9 | *.json
10 | output/
11 | saved_models/
12 | env_test.py
13 | dkitty_eval.sh
14 | experiments/
15 | dads_token.txt
16 | 
--------------------------------------------------------------------------------
/AUTHORS:
--------------------------------------------------------------------------------
1 | # This is the list of authors for copyright purposes.
2 | Google LLC
3 | Archit Sharma
4 | Shixiang Gu
5 | Sergey Levine
6 | Vikash Kumar
7 | Karol Hausman
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 | 
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 | 
6 | ## Contributor License Agreement
7 | 
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 | 
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 | 
18 | ## Code reviews
19 | 
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 | 
25 | ## Community Guidelines
26 | 
27 | This project follows
28 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | 
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 | 
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 | 
8 | 1. Definitions.
9 | 
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 | 
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at
195 | 
196 | http://www.apache.org/licenses/LICENSE-2.0
197 | 
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dynamics-Aware Discovery of Skills (DADS)
2 | This repository is the open-source implementation of Dynamics-Aware Unsupervised Discovery of Skills ([project page][website], [arXiv][paper]). We propose a skill-discovery method which can learn skills for different agents without any rewards, while simultaneously learning a dynamics model for the skills, which can be leveraged for model-based control on the downstream task. This work was published at the International Conference on Learning Representations ([ICLR][iclr]), 2020.
3 | 
4 | We have also included an improved off-policy version of DADS, coined off-DADS. The details have been released in [Emergent Real-World Robotic Skills via Unsupervised Off-Policy Reinforcement Learning][rss_arxiv].
5 | 
6 | In case of problems, contact Archit Sharma.
7 | 
8 | ## Table of Contents
9 | 
10 | * [Setup](#setup)
11 | * [Usage](#usage)
12 | * [Citation](#citation)
13 | * [Disclaimer](#disclaimer)
14 | 
15 | ## Setup
16 | 
17 | #### (1) Set up MuJoCo
18 | Download and set up [mujoco][mujoco] in `~/.mujoco`. Set the `LD_LIBRARY_PATH` in your `~/.bashrc`:
19 | ```
20 | LD_LIBRARY_PATH='~/.mujoco/mjpro150/bin':$LD_LIBRARY_PATH
21 | ```
22 | 
23 | #### (2) Set up the environment
24 | Clone the repository and set up the [conda][conda] environment to run the DADS code:
25 | ```
26 | cd <path_to_dads>
27 | conda env create -f env.yml
28 | conda activate dads-env
29 | ```
30 | 
31 | ## Usage
32 | We give a high-level explanation of how to use the code. More details pertaining to hyperparameters can be found in `configs/template_config.txt`, `dads_off.py` and Appendix A of the [paper][paper].
33 | 
34 | Every training run requires an experiment logging directory and a configuration file, which can be created starting from `configs/template_config.txt`. There are two phases: (a) training, where the new skills are learnt along with their skill-dynamics models, and (b) evaluation, where the learnt skills are evaluated on the task associated with the environment.
35 | 
36 | For training, ensure `--run_train=1` is set in the configuration file. For on-policy optimization, set `--clear_buffer_every_iter=1` and ensure the replay buffer size is bigger than the number of steps collected in every iteration. For off-policy optimization (off-DADS), set `--clear_buffer_every_iter=0`. Set the environment name (ensure the environment is listed in `get_environment()` in `dads_off.py`). To change the observation for skill-dynamics (for example, to learn in x-y space), set `--reduced_observation` and correspondingly configure `process_observation()` in `dads_off.py`. The skill space can be configured to be discrete or continuous. The optimization parameters can be tweaked, and reasonable default values have been set in the provided configuration files (more details in the [paper][paper]).
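
The environment hooks mentioned above live in `unsupervised_skill_learning/dads_off.py`. The snippet below is only a rough sketch of the kind of branches involved when adding a custom environment and a reduced observation; the function signatures, the defaults and the `MyCustomEnv-v0` name are illustrative assumptions, not the actual implementation:

```
# Illustrative sketch only; see get_environment() and process_observation()
# in unsupervised_skill_learning/dads_off.py for the real versions.
import numpy as np

from envs.gym_mujoco import ant


def get_environment(env_name='Ant-v1'):
  # Map the --environment flag to a gym-style environment instance.
  if env_name == 'Ant-v1':
    return ant.AntEnv(expose_all_qpos=True)
  elif env_name == 'MyCustomEnv-v0':  # hypothetical entry for a new environment
    raise NotImplementedError('construct and return your environment here')
  else:
    raise ValueError('Unknown environment: {}'.format(env_name))


def process_observation(observation, reduced_observation=2):
  # With --reduced_observation=2, skill dynamics can be learnt on the x-y
  # coordinates of the agent rather than on the full observation.
  if reduced_observation == 0:
    return observation
  return np.asarray(observation)[:reduced_observation]
```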
37 | 
38 | For evaluation, ensure `--run_eval=1` is set and that the experiment log directory points to the same directory in which the training happened. Set `--num_evals` if you want to record videos of randomly sampled skills from the prior distribution. After that, the script will use the learned models to execute MPC on the latent space to optimize for the task-reward. By default, the code will call `get_environment()` to load `FLAGS.environment + '_goal'`, and will go through the list of goal-coordinates specified in the eval section of the script.
39 | 
40 | We have provided the configuration files in `configs/` to reproduce results from the experiments in the [paper][paper]. Goal evaluation is currently only set up for the MuJoCo Ant environment. The goal distribution can be changed in the evaluation part of the script in `dads_off.py`.
41 | 
42 | ```
43 | cd <path_to_dads>
44 | python unsupervised_skill_learning/dads_off.py --logdir=<experiment_log_directory> --flagfile=configs/<config_name>.txt
45 | ```
46 | 
47 | The specified experiment log directory will contain the TensorBoard files, the saved checkpoints and the skill-evaluation videos.
48 | 
49 | ## Citation
50 | To cite [Dynamics-Aware Unsupervised Discovery of Skills][paper]:
51 | ```
52 | @article{sharma2019dynamics,
53 |   title={Dynamics-aware unsupervised discovery of skills},
54 |   author={Sharma, Archit and Gu, Shixiang and Levine, Sergey and Kumar, Vikash and Hausman, Karol},
55 |   journal={arXiv preprint arXiv:1907.01657},
56 |   year={2019}
57 | }
58 | ```
59 | To cite off-DADS and [Emergent Real-World Robotic Skills via Unsupervised Off-Policy Reinforcement Learning][rss_arxiv]:
60 | ```
61 | @article{sharma2020emergent,
62 |   title={Emergent Real-World Robotic Skills via Unsupervised Off-Policy Reinforcement Learning},
63 |   author={Sharma, Archit and Ahn, Michael and Levine, Sergey and Kumar, Vikash and Hausman, Karol and Gu, Shixiang},
64 |   journal={arXiv preprint arXiv:2004.12974},
65 |   year={2020}
66 | }
67 | ```
68 | ## Disclaimer
69 | This is not an officially supported Google product.
70 | 
71 | [website]: https://sites.google.com/corp/view/dads-skill
72 | [paper]: https://arxiv.org/abs/1907.01657
73 | [iclr]: https://openreview.net/forum?id=HJgLZR4KvH
74 | [mujoco]: http://www.mujoco.org/
75 | [conda]: https://docs.conda.io/en/latest/miniconda.html
76 | [rss_arxiv]: https://arxiv.org/abs/2004.12974
77 | 
--------------------------------------------------------------------------------
/configs/ant_xy_offpolicy.txt:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 15 | ### TRAINING HYPERPARAMETERS ------------------- 16 | --run_train=1 17 | 18 | # metadata flags 19 | --save_model=dads 20 | --save_freq=50 21 | --record_freq=100 22 | --vid_name=skill 23 | 24 | # optimization hyperparmaters 25 | --replay_buffer_capacity=10000 26 | 27 | # (set clear_buffer_every_iter=1 for on-policy optimization) 28 | --clear_buffer_every_iter=0 29 | --initial_collect_steps=2000 30 | --collect_steps=500 31 | --num_epochs=10000 32 | 33 | # skill dynamics optimization hyperparameters 34 | --skill_dyn_train_steps=8 35 | --skill_dynamics_lr=3e-4 36 | --skill_dyn_batch_size=256 37 | 38 | # agent hyperparameters 39 | --agent_gamma=0.99 40 | --agent_lr=3e-4 41 | --agent_entropy=0.1 42 | --agent_train_steps=64 43 | --agent_batch_size=256 44 | 45 | # (optional, do not change for on-policy) relabelling or off-policy corrections 46 | --skill_dynamics_relabel_type=importance_sampling 47 | --num_samples_for_relabelling=1 48 | --is_clip_eps=10. 49 | 50 | # (optional) skills can be resampled within the episodes, relative to max_env_steps 51 | --min_steps_before_resample=2000 52 | --resample_prob=0.02 53 | 54 | # (optional) configure skill dynamics training samples to be only from the current policy 55 | --train_skill_dynamics_on_policy=0 56 | 57 | ### SHARED HYPERPARAMETERS --------------------- 58 | --environment=Ant-v1 59 | --max_env_steps=200 60 | --reduced_observation=2 61 | 62 | # define the type of skills being learnt 63 | --num_skills=2 64 | --skill_type=cont_uniform 65 | --random_skills=100 66 | --num_evals=3 67 | 68 | # (optional) policy, critic and skill dynamics 69 | --hidden_layer_size=512 70 | 71 | # (optional) skill dynamics hyperparameters 72 | --graph_type=default 73 | --num_components=4 74 | --fix_variance=1 75 | --normalize_data=1 76 | 77 | # (optional) clip sampled actions 78 | --action_clipping=1. 79 | 80 | # (optional) debugging 81 | --debug=0 82 | 83 | ### EVALUATION HYPERPARAMETERS ----------------- 84 | --run_eval=0 85 | 86 | # MPC hyperparameters 87 | --planning_horizon=1 88 | --primitive_horizon=10 89 | --num_candidate_sequences=50 90 | --refine_steps=10 91 | --mppi_gamma=10 92 | --prior_type=normal 93 | --smoothing_beta=0.9 94 | --top_primitives=5 95 | 96 | 97 | ### (optional) ENVIRONMENT SPECIFIC HYPERPARAMETERS -------- 98 | # DKitty hyperparameters 99 | --expose_last_action=1 100 | --expose_upright=1 101 | --robot_noise_ratio=0.0 102 | --root_noise_ratio=0.0 103 | --upright_threshold=0.95 104 | --scale_root_position=1 105 | --randomize_hfield=0.0 106 | 107 | # DKitty/DClaw 108 | --observation_omission_size=0 109 | 110 | # Cube Manipulation hyperparameters 111 | --randomized_initial_distribution=1 112 | --horizontal_wrist_constraint=0.3 113 | --vertical_wrist_constraint=1.0 114 | -------------------------------------------------------------------------------- /configs/ant_xy_onpolicy.txt: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | ### TRAINING HYPERPARAMETERS ------------------- 16 | --run_train=1 17 | 18 | # metadata flags 19 | --save_model=dads 20 | --save_freq=50 21 | --record_freq=100 22 | --vid_name=skill 23 | 24 | # optimization hyperparmaters 25 | --replay_buffer_capacity=100000 26 | 27 | # (set clear_buffer_iter=1 for on-policy) 28 | --clear_buffer_every_iter=1 29 | --initial_collect_steps=0 30 | --collect_steps=2000 31 | --num_epochs=10000 32 | 33 | # skill dynamics optimization hyperparameters 34 | --skill_dyn_train_steps=32 35 | --skill_dynamics_lr=3e-4 36 | --skill_dyn_batch_size=256 37 | 38 | # agent hyperparameters 39 | --agent_gamma=0.995 40 | --agent_lr=3e-4 41 | --agent_entropy=0.1 42 | --agent_train_steps=64 43 | --agent_batch_size=256 44 | 45 | # (optional, do not change for on-policy) relabelling or off-policy corrections 46 | --skill_dynamics_relabel_type=importance_sampling 47 | --num_samples_for_relabelling=1 48 | --is_clip_eps=1. 49 | 50 | # (optional) skills can be resampled within the episodes, relative to max_env_steps 51 | --min_steps_before_resample=2000 52 | --resample_prob=0.02 53 | 54 | # (optional) configure skill dynamics training samples to be only from the current policy 55 | --train_skill_dynamics_on_policy=0 56 | 57 | ### SHARED HYPERPARAMETERS --------------------- 58 | --environment=Ant-v1 59 | --max_env_steps=200 60 | --reduced_observation=2 61 | 62 | # define the type of skills being learnt 63 | --num_skills=2 64 | --skill_type=cont_uniform 65 | --random_skills=100 66 | --num_evals=3 67 | 68 | # (optional) policy, critic and skill dynamics 69 | --hidden_layer_size=512 70 | 71 | # (optional) skill dynamics hyperparameters 72 | --graph_type=default 73 | --num_components=4 74 | --fix_variance=1 75 | --normalize_data=1 76 | 77 | # (optional) clip sampled actions 78 | --action_clipping=1. 79 | 80 | # (optional) debugging 81 | --debug=0 82 | 83 | ### EVALUATION HYPERPARAMETERS ----------------- 84 | --run_eval=0 85 | 86 | # MPC hyperparameters 87 | --planning_horizon=1 88 | --primitive_horizon=10 89 | --num_candidate_sequences=50 90 | --refine_steps=10 91 | --mppi_gamma=10 92 | --prior_type=normal 93 | --smoothing_beta=0.9 94 | --top_primitives=5 95 | 96 | 97 | ### (optional) ENVIRONMENT SPECIFIC HYPERPARAMETERS -------- 98 | # DKitty hyperparameters 99 | --expose_last_action=1 100 | --expose_upright=1 101 | --robot_noise_ratio=0.0 102 | --root_noise_ratio=0.0 103 | --upright_threshold=0.95 104 | --scale_root_position=1 105 | --randomize_hfield=0.0 106 | 107 | # DKitty/DClaw 108 | --observation_omission_size=0 109 | 110 | # Cube Manipulation hyperparameters 111 | --randomized_initial_distribution=1 112 | --horizontal_wrist_constraint=0.3 113 | --vertical_wrist_constraint=1.0 114 | -------------------------------------------------------------------------------- /configs/dkitty_randomized_xy_offpolicy.txt: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | ### TRAINING HYPERPARAMETERS ------------------- 16 | --run_train=1 17 | 18 | # metadata flags 19 | --save_model=dads 20 | --save_freq=50 21 | --record_freq=100 22 | --vid_name=skill 23 | 24 | # optimization hyperparmaters 25 | --replay_buffer_capacity=10000 26 | 27 | # (set clear_buffer_iter=1 for on-policy) 28 | --clear_buffer_every_iter=0 29 | --initial_collect_steps=2000 30 | --collect_steps=500 31 | --num_epochs=1000 32 | 33 | # skill dynamics optimization hyperparameters 34 | --skill_dyn_train_steps=8 35 | --skill_dynamics_lr=3e-4 36 | --skill_dyn_batch_size=256 37 | 38 | # agent hyperparameters 39 | --agent_gamma=0.99 40 | --agent_lr=3e-4 41 | --agent_entropy=0.1 42 | --agent_train_steps=64 43 | --agent_batch_size=256 44 | 45 | # (optional, do not change for on-policy) relabelling or off-policy corrections 46 | --skill_dynamics_relabel_type=importance_sampling 47 | --num_samples_for_relabelling=1 48 | --is_clip_eps=10. 49 | 50 | # (optional) skills can be resampled within the episodes, relative to max_env_steps 51 | --min_steps_before_resample=2000 52 | --resample_prob=0.02 53 | 54 | # (optional) configure skill dynamics training samples to be only from the current policy 55 | --train_skill_dynamics_on_policy=0 56 | 57 | ### SHARED HYPERPARAMETERS --------------------- 58 | --environment=DKitty_randomized 59 | --max_env_steps=200 60 | --reduced_observation=2 61 | 62 | # define the type of skills being learnt 63 | --num_skills=2 64 | --skill_type=cont_uniform 65 | --random_skills=100 66 | --num_evals=3 67 | 68 | # (optional) policy, critic and skill dynamics 69 | --hidden_layer_size=512 70 | 71 | # (optional) skill dynamics hyperparameters 72 | --graph_type=default 73 | --num_components=4 74 | --fix_variance=1 75 | --normalize_data=1 76 | 77 | # (optional) clip sampled actions 78 | --action_clipping=1. 
79 | 80 | # (optional) debugging 81 | --debug=0 82 | 83 | ### EVALUATION HYPERPARAMETERS ----------------- 84 | --run_eval=0 85 | 86 | # MPC hyperparameters 87 | --planning_horizon=1 88 | --primitive_horizon=10 89 | --num_candidate_sequences=50 90 | --refine_steps=10 91 | --mppi_gamma=10 92 | --prior_type=normal 93 | --smoothing_beta=0.9 94 | --top_primitives=5 95 | 96 | 97 | ### (optional) ENVIRONMENT SPECIFIC HYPERPARAMETERS -------- 98 | # DKitty hyperparameters 99 | --expose_last_action=1 100 | --expose_upright=1 101 | --robot_noise_ratio=0.0 102 | --root_noise_ratio=0.0 103 | --upright_threshold=0.95 104 | --scale_root_position=1 105 | --randomize_hfield=0.02 106 | 107 | # DKitty/DClaw 108 | --observation_omission_size=2 109 | 110 | # Cube Manipulation hyperparameters 111 | --randomized_initial_distribution=1 112 | --horizontal_wrist_constraint=0.3 113 | --vertical_wrist_constraint=1.0 114 | -------------------------------------------------------------------------------- /configs/humanoid_offpolicy.txt: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | ### TRAINING HYPERPARAMETERS ------------------- 16 | --run_train=1 17 | 18 | # metadata flags 19 | --save_model=dads 20 | --save_freq=50 21 | --record_freq=100 22 | --vid_name=skill 23 | 24 | # optimization hyperparmaters 25 | --replay_buffer_capacity=10000 26 | 27 | # (set clear_buffer_iter=1 for on-policy) 28 | --clear_buffer_every_iter=0 29 | --initial_collect_steps=5000 30 | --collect_steps=2000 31 | --num_epochs=100000 32 | 33 | # skill dynamics optimization hyperparameters 34 | --skill_dyn_train_steps=16 35 | --skill_dynamics_lr=3e-4 36 | --skill_dyn_batch_size=256 37 | 38 | # agent hyperparameters 39 | --agent_gamma=0.995 40 | --agent_lr=3e-4 41 | --agent_entropy=0.1 42 | --agent_train_steps=128 43 | --agent_batch_size=256 44 | 45 | # (optional, do not change for on-policy) relabelling or off-policy corrections 46 | --skill_dynamics_relabel_type=importance_sampling 47 | --num_samples_for_relabelling=1 48 | --is_clip_eps=1. 
49 | 50 | # (optional) skills can be resampled within the episodes, relative to max_env_steps 51 | --min_steps_before_resample=2000 52 | --resample_prob=0.0 53 | 54 | # (optional) configure skill dynamics training samples to be only from the current policy 55 | --train_skill_dynamics_on_policy=0 56 | 57 | ### SHARED HYPERPARAMETERS --------------------- 58 | --environment=Humanoid-v1 59 | --max_env_steps=1000 60 | --reduced_observation=0 61 | 62 | # define the type of skills being learnt 63 | --num_skills=5 64 | --skill_type=cont_uniform 65 | --random_skills=100 66 | 67 | # number of skill-video evaluations 68 | --num_evals=3 69 | 70 | # (optional) policy, critic and skill dynamics 71 | --hidden_layer_size=1024 72 | 73 | # (optional) skill dynamics hyperparameters 74 | --graph_type=default 75 | --num_components=4 76 | --fix_variance=1 77 | --normalize_data=1 78 | 79 | # (optional) clip sampled actions 80 | --action_clipping=1. 81 | 82 | # (optional) debugging 83 | --debug=0 84 | 85 | ### EVALUATION HYPERPARAMETERS ----------------- 86 | --run_eval=0 87 | 88 | # MPC hyperparameters 89 | --planning_horizon=1 90 | --primitive_horizon=10 91 | --num_candidate_sequences=50 92 | --refine_steps=10 93 | --mppi_gamma=10 94 | --prior_type=normal 95 | --smoothing_beta=0.9 96 | --top_primitives=5 97 | 98 | 99 | ### (optional) ENVIRONMENT SPECIFIC HYPERPARAMETERS -------- 100 | # DKitty hyperparameters 101 | --expose_last_action=1 102 | --expose_upright=1 103 | --robot_noise_ratio=0.0 104 | --root_noise_ratio=0.0 105 | --upright_threshold=0.95 106 | --scale_root_position=1 107 | --randomize_hfield=0.0 108 | 109 | # DKitty/DClaw 110 | --observation_omission_size=0 111 | 112 | # Cube Manipulation hyperparameters 113 | --randomized_initial_distribution=1 114 | --horizontal_wrist_constraint=0.3 115 | --vertical_wrist_constraint=1.0 116 | -------------------------------------------------------------------------------- /configs/humanoid_onpolicy.txt: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | ### TRAINING HYPERPARAMETERS ------------------- 16 | --run_train=1 17 | 18 | # metadata flags 19 | --save_model=dads 20 | --save_freq=50 21 | --record_freq=100 22 | --vid_name=skill 23 | 24 | # optimization hyperparmaters 25 | --replay_buffer_capacity=100000 26 | 27 | # (set clear_buffer_iter=1 for on-policy) 28 | --clear_buffer_every_iter=1 29 | --initial_collect_steps=0 30 | --collect_steps=4000 31 | --num_epochs=100000 32 | 33 | # skill dynamics optimization hyperparameters 34 | --skill_dyn_train_steps=32 35 | --skill_dynamics_lr=3e-4 36 | --skill_dyn_batch_size=256 37 | 38 | # agent hyperparameters 39 | --agent_gamma=0.995 40 | --agent_lr=3e-4 41 | --agent_entropy=0.1 42 | --agent_train_steps=64 43 | --agent_batch_size=256 44 | 45 | # (optional, do not change for on-policy) relabelling or off-policy corrections 46 | --skill_dynamics_relabel_type=importance_sampling 47 | --num_samples_for_relabelling=1 48 | --is_clip_eps=1. 49 | 50 | # (optional) skills can be resampled within the episodes, relative to max_env_steps 51 | --min_steps_before_resample=2000 52 | --resample_prob=0.0 53 | 54 | # (optional) configure skill dynamics training samples to be only from the current policy 55 | --train_skill_dynamics_on_policy=0 56 | 57 | ### SHARED HYPERPARAMETERS --------------------- 58 | --environment=Humanoid-v1 59 | --max_env_steps=1000 60 | --reduced_observation=0 61 | 62 | # define the type of skills being learnt 63 | --num_skills=5 64 | --skill_type=cont_uniform 65 | --random_skills=100 66 | 67 | # number of skill-video evaluations 68 | --num_evals=3 69 | 70 | # (optional) policy, critic and skill dynamics 71 | --hidden_layer_size=1024 72 | 73 | # (optional) skill dynamics hyperparameters 74 | --graph_type=default 75 | --num_components=4 76 | --fix_variance=1 77 | --normalize_data=1 78 | 79 | # (optional) clip sampled actions 80 | --action_clipping=1. 81 | 82 | # (optional) debugging 83 | --debug=0 84 | 85 | ### EVALUATION HYPERPARAMETERS ----------------- 86 | --run_eval=0 87 | 88 | # MPC hyperparameters 89 | --planning_horizon=1 90 | --primitive_horizon=10 91 | --num_candidate_sequences=50 92 | --refine_steps=10 93 | --mppi_gamma=10 94 | --prior_type=normal 95 | --smoothing_beta=0.9 96 | --top_primitives=5 97 | 98 | 99 | ### (optional) ENVIRONMENT SPECIFIC HYPERPARAMETERS -------- 100 | # DKitty hyperparameters 101 | --expose_last_action=1 102 | --expose_upright=1 103 | --robot_noise_ratio=0.0 104 | --root_noise_ratio=0.0 105 | --upright_threshold=0.95 106 | --scale_root_position=1 107 | --randomize_hfield=0.0 108 | 109 | # DKitty/DClaw 110 | --observation_omission_size=0 111 | 112 | # Cube Manipulation hyperparameters 113 | --randomized_initial_distribution=1 114 | --horizontal_wrist_constraint=0.3 115 | --vertical_wrist_constraint=1.0 116 | -------------------------------------------------------------------------------- /configs/template_config.txt: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | ### TRAINING HYPERPARAMETERS ------------------- 16 | --run_train=0 17 | 18 | # metadata flags 19 | --save_model=dads 20 | --save_freq=50 21 | --record_freq=100 22 | --vid_name=skill 23 | 24 | # optimization hyperparmaters 25 | --replay_buffer_capacity=100000 26 | 27 | # (set clear_buffer_iter=1 for on-policy) 28 | --clear_buffer_every_iter=0 29 | --initial_collect_steps=2000 30 | --collect_steps=1000 31 | --num_epochs=100 32 | 33 | # skill dynamics optimization hyperparameters 34 | --skill_dyn_train_steps=16 35 | --skill_dynamics_lr=3e-4 36 | --skill_dyn_batch_size=256 37 | 38 | # agent hyperparameters 39 | --agent_gamma=0.99 40 | --agent_lr=3e-4 41 | --agent_entropy=0.1 42 | --agent_train_steps=64 43 | --agent_batch_size=256 44 | 45 | # (optional, do not change for on-policy) relabelling or off-policy corrections 46 | --skill_dynamics_relabel_type=importance_sampling 47 | --num_samples_for_relabelling=1 48 | --is_clip_eps=1. 49 | 50 | # (optional) skills can be resampled within the episodes, relative to max_env_steps 51 | --min_steps_before_resample=2000 52 | --resample_prob=0.02 53 | 54 | # (optional) configure skill dynamics training samples to be only from the current policy 55 | --train_skill_dynamics_on_policy=0 56 | 57 | ### SHARED HYPERPARAMETERS --------------------- 58 | --environment= 59 | --max_env_steps=200 60 | --reduced_observation=0 61 | 62 | # define the type of skills being learnt 63 | --num_skills=2 64 | --skill_type=cont_uniform 65 | --random_skills=100 66 | 67 | # number of skill-video evaluations 68 | --num_evals=3 69 | 70 | # (optional) policy, critic and skill dynamics 71 | --hidden_layer_size=512 72 | 73 | # (optional) skill dynamics hyperparameters 74 | --graph_type=default 75 | --num_components=4 76 | --fix_variance=1 77 | --normalize_data=1 78 | 79 | # (optional) clip sampled actions 80 | --action_clipping=1. 81 | 82 | # (optional) debugging 83 | --debug=0 84 | 85 | ### EVALUATION HYPERPARAMETERS ----------------- 86 | --run_eval=0 87 | 88 | # MPC hyperparameters 89 | --planning_horizon=1 90 | --primitive_horizon=10 91 | --num_candidate_sequences=50 92 | --refine_steps=10 93 | --mppi_gamma=10 94 | --prior_type=normal 95 | --smoothing_beta=0.9 96 | --top_primitives=5 97 | 98 | 99 | ### (optional) ENVIRONMENT SPECIFIC HYPERPARAMETERS -------- 100 | # DKitty hyperparameters 101 | --expose_last_action=1 102 | --expose_upright=1 103 | --robot_noise_ratio=0.0 104 | --root_noise_ratio=0.0 105 | --upright_threshold=0.95 106 | --scale_root_position=1 107 | --randomize_hfield=0.0 108 | 109 | # DKitty/DClaw 110 | --observation_omission_size=0 111 | 112 | # Cube Manipulation hyperparameters 113 | --randomized_initial_distribution=1 114 | --horizontal_wrist_constraint=0.3 115 | --vertical_wrist_constraint=1.0 116 | -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | name: dads-env
16 | channels:
17 |   - defaults
18 |   - conda-forge
19 | dependencies:
20 |   - python=3.6.8
21 |   - pip>=18.1
22 |   - conda>=4.6.7
23 |   - pip:
24 |     - numpy<2.0,>=1.16.0
25 |     - tensorflow-probability==0.10.0
26 |     - tensorflow==2.2.0
27 |     - tf-agents==0.4.0
28 |     - tensorflow-estimator==2.2.0
29 |     - gym==0.11.0
30 |     - matplotlib==3.0.2
31 |     - robel==0.1.2
32 |     - mujoco-py==2.0.2.5
33 |     - click
34 |     - transforms3d
35 | 
--------------------------------------------------------------------------------
/envs/assets/ant.xml:
--------------------------------------------------------------------------------
[MuJoCo XML model for the Ant robot; the markup was not preserved in this dump]
--------------------------------------------------------------------------------
/envs/assets/ant_footsensor.xml:
--------------------------------------------------------------------------------
[MuJoCo XML model for the Ant robot with foot contact sensors; the markup was not preserved in this dump]
--------------------------------------------------------------------------------
/envs/assets/half_cheetah.xml:
--------------------------------------------------------------------------------
[MuJoCo XML model for the Half-Cheetah robot; the markup was not preserved in this dump]
--------------------------------------------------------------------------------
/envs/assets/humanoid.xml:
--------------------------------------------------------------------------------
[MuJoCo XML model for the Humanoid robot; the markup was not preserved in this dump]
--------------------------------------------------------------------------------
/envs/assets/point.xml:
--------------------------------------------------------------------------------
[MuJoCo XML model for the point-mass agent; the markup was not preserved in this dump]
--------------------------------------------------------------------------------
/envs/dclaw.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Turn tasks with DClaw robots.
16 | 17 | This is a single rotation of an object from an initial angle to a target angle. 18 | """ 19 | 20 | import abc 21 | import collections 22 | from typing import Dict, Optional, Sequence 23 | 24 | import numpy as np 25 | 26 | from robel.components.robot.dynamixel_robot import DynamixelRobotState 27 | from robel.dclaw.base_env import BaseDClawObjectEnv 28 | from robel.simulation.randomize import SimRandomizer 29 | from robel.utils.configurable import configurable 30 | from robel.utils.resources import get_asset_path 31 | 32 | # The observation keys that are concatenated as the environment observation. 33 | DEFAULT_OBSERVATION_KEYS = ( 34 | 'object_x', 35 | 'object_y', 36 | 'claw_qpos', 37 | 'last_action', 38 | ) 39 | 40 | # Reset pose for the claw joints. 41 | RESET_POSE = [0, -np.pi / 3, np.pi / 3] * 3 42 | 43 | DCLAW3_ASSET_PATH = 'robel/dclaw/assets/dclaw3xh_valve3_v0.xml' 44 | 45 | 46 | class BaseDClawTurn(BaseDClawObjectEnv, metaclass=abc.ABCMeta): 47 | """Shared logic for DClaw turn tasks.""" 48 | 49 | def __init__(self, 50 | asset_path: str = DCLAW3_ASSET_PATH, 51 | observation_keys: Sequence[str] = DEFAULT_OBSERVATION_KEYS, 52 | frame_skip: int = 40, 53 | **kwargs): 54 | """Initializes the environment. 55 | 56 | Args: 57 | asset_path: The XML model file to load. 58 | observation_keys: The keys in `get_obs_dict` to concatenate as the 59 | observations returned by `step` and `reset`. 60 | frame_skip: The number of simulation steps per environment step. 61 | interactive: If True, allows the hardware guide motor to freely 62 | rotate and its current angle is used as the goal. 63 | success_threshold: The difference threshold (in radians) of the 64 | object position and the goal position within which we consider 65 | as a sucesss. 66 | """ 67 | super().__init__( 68 | sim_model=get_asset_path(asset_path), 69 | observation_keys=observation_keys, 70 | frame_skip=frame_skip, 71 | **kwargs) 72 | 73 | self._desired_claw_pos = RESET_POSE 74 | 75 | # The following are modified (possibly every reset) by subclasses. 76 | self._initial_object_pos = 0 77 | self._initial_object_vel = 0 78 | 79 | def _reset(self): 80 | """Resets the environment.""" 81 | self._reset_dclaw_and_object( 82 | claw_pos=RESET_POSE, 83 | object_pos=self._initial_object_pos, 84 | object_vel=self._initial_object_vel) 85 | 86 | def _step(self, action: np.ndarray): 87 | """Applies an action to the robot.""" 88 | self.robot.step({ 89 | 'dclaw': action, 90 | }) 91 | 92 | def get_obs_dict(self) -> Dict[str, np.ndarray]: 93 | """Returns the current observation of the environment. 94 | 95 | Returns: 96 | A dictionary of observation values. This should be an ordered 97 | dictionary if `observation_keys` isn't set. 98 | """ 99 | claw_state, object_state = self.robot.get_state( 100 | ['dclaw', 'object']) 101 | 102 | obs_dict = collections.OrderedDict(( 103 | ('claw_qpos', claw_state.qpos), 104 | ('claw_qvel', claw_state.qvel), 105 | ('object_x', np.cos(object_state.qpos)), 106 | ('object_y', np.sin(object_state.qpos)), 107 | ('object_qvel', object_state.qvel), 108 | ('last_action', self._get_last_action()), 109 | )) 110 | # Add hardware-specific state if present. 
111 | if isinstance(claw_state, DynamixelRobotState): 112 | obs_dict['claw_current'] = claw_state.current 113 | 114 | return obs_dict 115 | 116 | def get_reward_dict( 117 | self, 118 | action: np.ndarray, 119 | obs_dict: Dict[str, np.ndarray], 120 | ) -> Dict[str, np.ndarray]: 121 | """Returns the reward for the given action and observation.""" 122 | reward_dict = collections.OrderedDict(()) 123 | return reward_dict 124 | 125 | def get_score_dict( 126 | self, 127 | obs_dict: Dict[str, np.ndarray], 128 | reward_dict: Dict[str, np.ndarray], 129 | ) -> Dict[str, np.ndarray]: 130 | """Returns a standardized measure of success for the environment.""" 131 | return collections.OrderedDict(()) 132 | 133 | def get_done( 134 | self, 135 | obs_dict: Dict[str, np.ndarray], 136 | reward_dict: Dict[str, np.ndarray], 137 | ) -> np.ndarray: 138 | """Returns whether the episode should terminate.""" 139 | return np.zeros_like([0], dtype=bool) 140 | 141 | 142 | @configurable(pickleable=True) 143 | class DClawTurnRandom(BaseDClawTurn): 144 | """Turns the object with a random initial and random target position.""" 145 | 146 | def _reset(self): 147 | # Initial position is +/- 60 degrees. 148 | self._initial_object_pos = self.np_random.uniform( 149 | low=-np.pi / 3, high=np.pi / 3) 150 | super()._reset() 151 | 152 | 153 | @configurable(pickleable=True) 154 | class DClawTurnRandomDynamics(DClawTurnRandom): 155 | """Turns the object with a random initial and random target position. 156 | 157 | The dynamics of the simulation are randomized each episode. 158 | """ 159 | 160 | def __init__(self, 161 | *args, 162 | sim_observation_noise: Optional[float] = 0.05, 163 | **kwargs): 164 | super().__init__( 165 | *args, sim_observation_noise=sim_observation_noise, **kwargs) 166 | self._randomizer = SimRandomizer(self) 167 | self._dof_indices = ( 168 | self.robot.get_config('dclaw').qvel_indices.tolist() + 169 | self.robot.get_config('object').qvel_indices.tolist()) 170 | 171 | def _reset(self): 172 | # Randomize joint dynamics. 173 | self._randomizer.randomize_dofs( 174 | self._dof_indices, 175 | damping_range=(0.005, 0.1), 176 | friction_loss_range=(0.001, 0.005), 177 | ) 178 | self._randomizer.randomize_actuators( 179 | all_same=True, 180 | kp_range=(1, 3), 181 | ) 182 | # Randomize friction on all geoms in the scene. 183 | self._randomizer.randomize_geoms( 184 | all_same=True, 185 | friction_slide_range=(0.8, 1.2), 186 | friction_spin_range=(0.003, 0.007), 187 | friction_roll_range=(0.00005, 0.00015), 188 | ) 189 | self._randomizer.randomize_bodies( 190 | ['mount'], 191 | position_perturb_range=(-0.01, 0.01), 192 | ) 193 | self._randomizer.randomize_geoms( 194 | ['mount'], 195 | color_range=(0.2, 0.9), 196 | ) 197 | self._randomizer.randomize_geoms( 198 | parent_body_names=['valve'], 199 | color_range=(0.2, 0.9), 200 | ) 201 | super()._reset() 202 | -------------------------------------------------------------------------------- /envs/dkitty_redesign.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """DKitty redesign 16 | """ 17 | 18 | import abc 19 | import collections 20 | from typing import Dict, Optional, Sequence, Tuple, Union 21 | 22 | import numpy as np 23 | 24 | from robel.components.tracking import TrackerState 25 | from robel.dkitty.base_env import BaseDKittyUprightEnv 26 | from robel.simulation.randomize import SimRandomizer 27 | from robel.utils.configurable import configurable 28 | from robel.utils.math_utils import calculate_cosine 29 | from robel.utils.resources import get_asset_path 30 | 31 | DKITTY_ASSET_PATH = 'robel/dkitty/assets/dkitty_walk-v0.xml' 32 | 33 | DEFAULT_OBSERVATION_KEYS = ( 34 | 'root_pos', 35 | 'root_euler', 36 | 'kitty_qpos', 37 | # 'root_vel', 38 | # 'root_angular_vel', 39 | 'kitty_qvel', 40 | 'last_action', 41 | 'upright', 42 | ) 43 | 44 | 45 | class BaseDKittyWalk(BaseDKittyUprightEnv, metaclass=abc.ABCMeta): 46 | """Shared logic for DKitty walk tasks.""" 47 | 48 | def __init__( 49 | self, 50 | asset_path: str = DKITTY_ASSET_PATH, 51 | observation_keys: Sequence[str] = DEFAULT_OBSERVATION_KEYS, 52 | device_path: Optional[str] = None, 53 | torso_tracker_id: Optional[Union[str, int]] = None, 54 | frame_skip: int = 40, 55 | sticky_action_probability: float = 0., 56 | upright_threshold: float = 0.9, 57 | upright_reward: float = 1, 58 | falling_reward: float = -500, 59 | expose_last_action: bool = True, 60 | expose_upright: bool = True, 61 | robot_noise_ratio: float = 0.05, 62 | **kwargs): 63 | """Initializes the environment. 64 | 65 | Args: 66 | asset_path: The XML model file to load. 67 | observation_keys: The keys in `get_obs_dict` to concatenate as the 68 | observations returned by `step` and `reset`. 69 | device_path: The device path to Dynamixel hardware. 70 | torso_tracker_id: The device index or serial of the tracking device 71 | for the D'Kitty torso. 72 | frame_skip: The number of simulation steps per environment step. 73 | sticky_action_probability: Repeat previous action with this 74 | probability. Default 0 (no sticky actions). 75 | upright_threshold: The threshold (in [0, 1]) above which the D'Kitty 76 | is considered to be upright. If the cosine similarity of the 77 | D'Kitty's z-axis with the global z-axis is below this threshold, 78 | the D'Kitty is considered to have fallen. 79 | upright_reward: The reward multiplier for uprightedness. 80 | falling_reward: The reward multipler for falling. 
81 | """ 82 | self._expose_last_action = expose_last_action 83 | self._expose_upright = expose_upright 84 | observation_keys = observation_keys[:-2] 85 | if self._expose_last_action: 86 | observation_keys += ('last_action',) 87 | if self._expose_upright: 88 | observation_keys += ('upright',) 89 | 90 | # robot_config = self.get_robot_config(device_path) 91 | # if 'sim_observation_noise' in robot_config.keys(): 92 | # robot_config['sim_observation_noise'] = robot_noise_ratio 93 | 94 | super().__init__( 95 | sim_model=get_asset_path(asset_path), 96 | # robot_config=robot_config, 97 | # tracker_config=self.get_tracker_config( 98 | # torso=torso_tracker_id, 99 | # ), 100 | observation_keys=observation_keys, 101 | frame_skip=frame_skip, 102 | upright_threshold=upright_threshold, 103 | upright_reward=upright_reward, 104 | falling_reward=falling_reward, 105 | **kwargs) 106 | 107 | self._last_action = np.zeros(12) 108 | self._sticky_action_probability = sticky_action_probability 109 | self._time_step = 0 110 | 111 | def _reset(self): 112 | """Resets the environment.""" 113 | self._reset_dkitty_standing() 114 | 115 | # Set the tracker locations. 116 | self.tracker.set_state({ 117 | 'torso': TrackerState(pos=np.zeros(3), rot=np.identity(3)), 118 | }) 119 | 120 | self._time_step = 0 121 | 122 | def _step(self, action: np.ndarray): 123 | """Applies an action to the robot.""" 124 | self._time_step += 1 125 | 126 | # Sticky actions 127 | rand = self.np_random.uniform() < self._sticky_action_probability 128 | action_to_apply = np.where(rand, self._last_action, action) 129 | 130 | # Apply action. 131 | self.robot.step({ 132 | 'dkitty': action_to_apply, 133 | }) 134 | # Save the action to add to the observation. 135 | self._last_action = action 136 | 137 | def get_obs_dict(self) -> Dict[str, np.ndarray]: 138 | """Returns the current observation of the environment. 139 | 140 | Returns: 141 | A dictionary of observation values. This should be an ordered 142 | dictionary if `observation_keys` isn't set. 143 | """ 144 | robot_state = self.robot.get_state('dkitty') 145 | torso_track_state = self.tracker.get_state( 146 | ['torso'])[0] 147 | obs_dict = (('root_pos', torso_track_state.pos), 148 | ('root_euler', torso_track_state.rot_euler), 149 | ('root_vel', torso_track_state.vel), 150 | ('root_angular_vel', torso_track_state.angular_vel), 151 | ('kitty_qpos', robot_state.qpos), 152 | ('kitty_qvel', robot_state.qvel)) 153 | 154 | if self._expose_last_action: 155 | obs_dict += (('last_action', self._last_action),) 156 | 157 | # Add observation terms relating to being upright. 
158 | if self._expose_upright: 159 | obs_dict += (*self._get_upright_obs(torso_track_state).items(),) 160 | 161 | return collections.OrderedDict(obs_dict) 162 | 163 | def get_reward_dict( 164 | self, 165 | action: np.ndarray, 166 | obs_dict: Dict[str, np.ndarray], 167 | ) -> Dict[str, np.ndarray]: 168 | """Returns the reward for the given action and observation.""" 169 | reward_dict = collections.OrderedDict(()) 170 | return reward_dict 171 | 172 | def get_score_dict( 173 | self, 174 | obs_dict: Dict[str, np.ndarray], 175 | reward_dict: Dict[str, np.ndarray], 176 | ) -> Dict[str, np.ndarray]: 177 | """Returns a standardized measure of success for the environment.""" 178 | return collections.OrderedDict(()) 179 | 180 | @configurable(pickleable=True) 181 | class DKittyRandomDynamics(BaseDKittyWalk): 182 | """Walk straight towards a random location.""" 183 | 184 | def __init__(self, *args, randomize_hfield=0.0, **kwargs): 185 | super().__init__(*args, **kwargs) 186 | self._randomizer = SimRandomizer(self) 187 | self._randomize_hfield = randomize_hfield 188 | self._dof_indices = ( 189 | self.robot.get_config('dkitty').qvel_indices.tolist()) 190 | 191 | def _reset(self): 192 | """Resets the environment.""" 193 | # Randomize joint dynamics. 194 | self._randomizer.randomize_dofs( 195 | self._dof_indices, 196 | all_same=True, 197 | damping_range=(0.1, 0.2), 198 | friction_loss_range=(0.001, 0.005), 199 | ) 200 | self._randomizer.randomize_actuators( 201 | all_same=True, 202 | kp_range=(2.8, 3.2), 203 | ) 204 | # Randomize friction on all geoms in the scene. 205 | self._randomizer.randomize_geoms( 206 | all_same=True, 207 | friction_slide_range=(0.8, 1.2), 208 | friction_spin_range=(0.003, 0.007), 209 | friction_roll_range=(0.00005, 0.00015), 210 | ) 211 | # Generate a random height field. 212 | self._randomizer.randomize_global( 213 | total_mass_range=(1.6, 2.0), 214 | height_field_range=(0, self._randomize_hfield), 215 | ) 216 | # if self._randomize_hfield > 0.0: 217 | # self.sim_scene.upload_height_field(0) 218 | super()._reset() 219 | -------------------------------------------------------------------------------- /envs/gym_mujoco/ant.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
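#
# Annotation: the AntEnv below supports task values "forward", "backward",
# "left", "right", "goal" (which needs `goal`) and "motion"; observations are
# MuJoCo qpos/qvel slices, optionally extended with body COMs, COM velocities
# and foot-sensor readings. A minimal usage sketch, assuming mujoco-py is
# installed and the process runs from the repo root so "envs/assets/ant.xml"
# resolves (note that get_ori() further down calls math.atan2, so it assumes
# `math` is imported in this module):
#
#   env = AntEnv(task="forward", expose_all_qpos=True)
#   obs = env.reset()
#   obs, reward, done, info = env.step(env.action_space.sample())
#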
14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | 21 | from gym import utils 22 | import numpy as np 23 | from gym.envs.mujoco import mujoco_env 24 | 25 | def q_inv(a): 26 | return [a[0], -a[1], -a[2], -a[3]] 27 | 28 | 29 | def q_mult(a, b): # multiply two quaternion 30 | w = a[0] * b[0] - a[1] * b[1] - a[2] * b[2] - a[3] * b[3] 31 | i = a[0] * b[1] + a[1] * b[0] + a[2] * b[3] - a[3] * b[2] 32 | j = a[0] * b[2] - a[1] * b[3] + a[2] * b[0] + a[3] * b[1] 33 | k = a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + a[3] * b[0] 34 | return [w, i, j, k] 35 | 36 | # pylint: disable=missing-docstring 37 | class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle): 38 | 39 | def __init__(self, 40 | task="forward", 41 | goal=None, 42 | expose_all_qpos=False, 43 | expose_body_coms=None, 44 | expose_body_comvels=None, 45 | expose_foot_sensors=False, 46 | use_alt_path=False, 47 | model_path="ant.xml"): 48 | self._task = task 49 | self._goal = goal 50 | self._expose_all_qpos = expose_all_qpos 51 | self._expose_body_coms = expose_body_coms 52 | self._expose_body_comvels = expose_body_comvels 53 | self._expose_foot_sensors = expose_foot_sensors 54 | self._body_com_indices = {} 55 | self._body_comvel_indices = {} 56 | 57 | # Settings from 58 | # https://github.com/openai/gym/blob/master/gym/envs/__init__.py 59 | 60 | xml_path = "envs/assets/" 61 | model_path = os.path.abspath(os.path.join(xml_path, model_path)) 62 | mujoco_env.MujocoEnv.__init__(self, model_path, 5) 63 | utils.EzPickle.__init__(self) 64 | 65 | def compute_reward(self, ob, next_ob, action=None): 66 | xposbefore = ob[:, 0] 67 | yposbefore = ob[:, 1] 68 | xposafter = next_ob[:, 0] 69 | yposafter = next_ob[:, 1] 70 | 71 | forward_reward = (xposafter - xposbefore) / self.dt 72 | sideward_reward = (yposafter - yposbefore) / self.dt 73 | 74 | if action is not None: 75 | ctrl_cost = .5 * np.square(action).sum(axis=1) 76 | survive_reward = 1.0 77 | if self._task == "forward": 78 | reward = forward_reward - ctrl_cost + survive_reward 79 | elif self._task == "backward": 80 | reward = -forward_reward - ctrl_cost + survive_reward 81 | elif self._task == "left": 82 | reward = sideward_reward - ctrl_cost + survive_reward 83 | elif self._task == "right": 84 | reward = -sideward_reward - ctrl_cost + survive_reward 85 | elif self._task == "goal": 86 | reward = -np.linalg.norm( 87 | np.array([xposafter, yposafter]).T - self._goal, axis=1) 88 | 89 | return reward 90 | 91 | def step(self, a): 92 | xposbefore = self.get_body_com("torso")[0] 93 | yposbefore = self.sim.data.qpos.flat[1] 94 | self.do_simulation(a, self.frame_skip) 95 | xposafter = self.get_body_com("torso")[0] 96 | yposafter = self.sim.data.qpos.flat[1] 97 | 98 | forward_reward = (xposafter - xposbefore) / self.dt 99 | sideward_reward = (yposafter - yposbefore) / self.dt 100 | 101 | ctrl_cost = .5 * np.square(a).sum() 102 | survive_reward = 1.0 103 | if self._task == "forward": 104 | reward = forward_reward - ctrl_cost + survive_reward 105 | elif self._task == "backward": 106 | reward = -forward_reward - ctrl_cost + survive_reward 107 | elif self._task == "left": 108 | reward = sideward_reward - ctrl_cost + survive_reward 109 | elif self._task == "right": 110 | reward = -sideward_reward - ctrl_cost + survive_reward 111 | elif self._task == "goal": 112 | reward = -np.linalg.norm(np.array([xposafter, yposafter]) - self._goal) 113 | elif self._task == "motion": 114 | reward = np.max(np.abs(np.array([forward_reward, 
sideward_reward 115 | ]))) - ctrl_cost + survive_reward 116 | 117 | state = self.state_vector() 118 | notdone = np.isfinite(state).all() 119 | done = not notdone 120 | ob = self._get_obs() 121 | return ob, reward, done, dict( 122 | reward_forward=forward_reward, 123 | reward_sideward=sideward_reward, 124 | reward_ctrl=-ctrl_cost, 125 | reward_survive=survive_reward) 126 | 127 | def _get_obs(self): 128 | # No crfc observation 129 | if self._expose_all_qpos: 130 | obs = np.concatenate([ 131 | self.sim.data.qpos.flat[:15], 132 | self.sim.data.qvel.flat[:14], 133 | ]) 134 | else: 135 | obs = np.concatenate([ 136 | self.sim.data.qpos.flat[2:15], 137 | self.sim.data.qvel.flat[:14], 138 | ]) 139 | 140 | if self._expose_body_coms is not None: 141 | for name in self._expose_body_coms: 142 | com = self.get_body_com(name) 143 | if name not in self._body_com_indices: 144 | indices = range(len(obs), len(obs) + len(com)) 145 | self._body_com_indices[name] = indices 146 | obs = np.concatenate([obs, com]) 147 | 148 | if self._expose_body_comvels is not None: 149 | for name in self._expose_body_comvels: 150 | comvel = self.get_body_comvel(name) 151 | if name not in self._body_comvel_indices: 152 | indices = range(len(obs), len(obs) + len(comvel)) 153 | self._body_comvel_indices[name] = indices 154 | obs = np.concatenate([obs, comvel]) 155 | 156 | if self._expose_foot_sensors: 157 | obs = np.concatenate([obs, self.sim.data.sensordata]) 158 | return obs 159 | 160 | def reset_model(self): 161 | qpos = self.init_qpos + self.np_random.uniform( 162 | size=self.sim.model.nq, low=-.1, high=.1) 163 | qvel = self.init_qvel + self.np_random.randn(self.sim.model.nv) * .1 164 | 165 | qpos[15:] = self.init_qpos[15:] 166 | qvel[14:] = 0. 167 | 168 | self.set_state(qpos, qvel) 169 | return self._get_obs() 170 | 171 | def viewer_setup(self): 172 | self.viewer.cam.distance = self.model.stat.extent * 2.5 173 | 174 | def get_ori(self): 175 | ori = [0, 1, 0, 0] 176 | rot = self.sim.data.qpos[3:7] # take the quaternion 177 | ori = q_mult(q_mult(rot, ori), q_inv(rot))[1:3] # project onto x-y plane 178 | ori = math.atan2(ori[1], ori[0]) 179 | return ori 180 | 181 | @property 182 | def body_com_indices(self): 183 | return self._body_com_indices 184 | 185 | @property 186 | def body_comvel_indices(self): 187 | return self._body_comvel_indices 188 | -------------------------------------------------------------------------------- /envs/gym_mujoco/half_cheetah.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
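#
# Annotation: HalfCheetahEnv below supports task values 'default' (run forward),
# 'run_back' and 'target_velocity' (which needs `target_velocity`). Note that
# the 'target_velocity' branch of step() never assigns reward_run, so the info
# dict returned at the end of step() assumes one of the other two tasks. A
# minimal usage sketch, assuming mujoco-py and a repo-root working directory so
# "envs/assets/half_cheetah.xml" resolves:
#
#   env = HalfCheetahEnv(task='default', expose_all_qpos=True)
#   obs = env.reset()
#   obs, reward, done, info = env.step(env.action_space.sample())
#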
14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | 21 | from gym import utils 22 | import numpy as np 23 | from gym.envs.mujoco import mujoco_env 24 | 25 | 26 | class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle): 27 | 28 | def __init__(self, 29 | expose_all_qpos=False, 30 | task='default', 31 | target_velocity=None, 32 | model_path='half_cheetah.xml'): 33 | # Settings from 34 | # https://github.com/openai/gym/blob/master/gym/envs/__init__.py 35 | self._expose_all_qpos = expose_all_qpos 36 | self._task = task 37 | self._target_velocity = target_velocity 38 | 39 | xml_path = "envs/assets/" 40 | model_path = os.path.abspath(os.path.join(xml_path, model_path)) 41 | 42 | mujoco_env.MujocoEnv.__init__( 43 | self, 44 | model_path, 45 | 5) 46 | utils.EzPickle.__init__(self) 47 | 48 | def step(self, action): 49 | xposbefore = self.sim.data.qpos[0] 50 | self.do_simulation(action, self.frame_skip) 51 | xposafter = self.sim.data.qpos[0] 52 | xvelafter = self.sim.data.qvel[0] 53 | ob = self._get_obs() 54 | reward_ctrl = -0.1 * np.square(action).sum() 55 | 56 | if self._task == 'default': 57 | reward_vel = 0. 58 | reward_run = (xposafter - xposbefore) / self.dt 59 | reward = reward_ctrl + reward_run 60 | elif self._task == 'target_velocity': 61 | reward_vel = -(self._target_velocity - xvelafter)**2 62 | reward = reward_ctrl + reward_vel 63 | elif self._task == 'run_back': 64 | reward_vel = 0. 65 | reward_run = (xposbefore - xposafter) / self.dt 66 | reward = reward_ctrl + reward_run 67 | 68 | done = False 69 | return ob, reward, done, dict( 70 | reward_run=reward_run, reward_ctrl=reward_ctrl, reward_vel=reward_vel) 71 | 72 | def _get_obs(self): 73 | if self._expose_all_qpos: 74 | return np.concatenate( 75 | [self.sim.data.qpos.flat, self.sim.data.qvel.flat]) 76 | return np.concatenate([ 77 | self.sim.data.qpos.flat[1:], 78 | self.sim.data.qvel.flat, 79 | ]) 80 | 81 | def reset_model(self): 82 | qpos = self.init_qpos + self.np_random.uniform( 83 | low=-.1, high=.1, size=self.sim.model.nq) 84 | qvel = self.init_qvel + self.np_random.randn(self.sim.model.nv) * .1 85 | self.set_state(qpos, qvel) 86 | return self._get_obs() 87 | 88 | def viewer_setup(self): 89 | self.viewer.cam.distance = self.model.stat.extent * 0.5 90 | -------------------------------------------------------------------------------- /envs/gym_mujoco/humanoid.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
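#
# Annotation: HumanoidEnv below adds goal-reaching variants on top of the
# standard Gym humanoid reward: task="goal" rewards the negative torso (x, y)
# distance to `goal`, task="follow_goals" steps through a fixed goal list, and
# "sparse_goal" is only used by compute_reward(). A minimal usage sketch,
# assuming mujoco-py and a repo-root working directory; the goal value is an
# arbitrary example:
#
#   env = HumanoidEnv(task="goal", goal=np.array([2.0, 2.0]),
#                     expose_all_qpos=True)
#   obs = env.reset()
#   obs, reward, done, info = env.step(env.action_space.sample())
#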
14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | 21 | from gym import utils 22 | import numpy as np 23 | from gym.envs.mujoco import mujoco_env 24 | 25 | 26 | def mass_center(sim): 27 | mass = np.expand_dims(sim.model.body_mass, 1) 28 | xpos = sim.data.xipos 29 | return (np.sum(mass * xpos, 0) / np.sum(mass))[0] 30 | 31 | 32 | # pylint: disable=missing-docstring 33 | class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle): 34 | 35 | def __init__(self, 36 | expose_all_qpos=False, 37 | model_path='humanoid.xml', 38 | task=None, 39 | goal=None): 40 | 41 | self._task = task 42 | self._goal = goal 43 | if self._task == "follow_goals": 44 | self._goal_list = [ 45 | np.array([3.0, -0.5]), 46 | np.array([6.0, 8.0]), 47 | np.array([12.0, 12.0]), 48 | ] 49 | self._goal = self._goal_list[0] 50 | print("Following a trajectory of goals:", self._goal_list) 51 | 52 | self._expose_all_qpos = expose_all_qpos 53 | xml_path = "envs/assets/" 54 | model_path = os.path.abspath(os.path.join(xml_path, model_path)) 55 | mujoco_env.MujocoEnv.__init__(self, model_path, 5) 56 | utils.EzPickle.__init__(self) 57 | 58 | def _get_obs(self): 59 | data = self.sim.data 60 | if self._expose_all_qpos: 61 | return np.concatenate([ 62 | data.qpos.flat, data.qvel.flat, 63 | # data.cinert.flat, data.cvel.flat, 64 | # data.qfrc_actuator.flat, data.cfrc_ext.flat 65 | ]) 66 | return np.concatenate([ 67 | data.qpos.flat[2:], data.qvel.flat, data.cinert.flat, data.cvel.flat, 68 | data.qfrc_actuator.flat, data.cfrc_ext.flat 69 | ]) 70 | 71 | def compute_reward(self, ob, next_ob, action=None): 72 | xposbefore = ob[:, 0] 73 | yposbefore = ob[:, 1] 74 | xposafter = next_ob[:, 0] 75 | yposafter = next_ob[:, 1] 76 | 77 | forward_reward = (xposafter - xposbefore) / self.dt 78 | sideward_reward = (yposafter - yposbefore) / self.dt 79 | 80 | if action is not None: 81 | ctrl_cost = .5 * np.square(action).sum(axis=1) 82 | survive_reward = 1.0 83 | if self._task == "forward": 84 | reward = forward_reward - ctrl_cost + survive_reward 85 | elif self._task == "backward": 86 | reward = -forward_reward - ctrl_cost + survive_reward 87 | elif self._task == "left": 88 | reward = sideward_reward - ctrl_cost + survive_reward 89 | elif self._task == "right": 90 | reward = -sideward_reward - ctrl_cost + survive_reward 91 | elif self._task in ["goal", "follow_goals"]: 92 | reward = -np.linalg.norm( 93 | np.array([xposafter, yposafter]).T - self._goal, axis=1) 94 | elif self._task in ["sparse_goal"]: 95 | reward = (-np.linalg.norm( 96 | np.array([xposafter, yposafter]).T - self._goal, axis=1) > 97 | -0.3).astype(np.float32) 98 | return reward 99 | 100 | def step(self, a): 101 | pos_before = mass_center(self.sim) 102 | self.do_simulation(a, self.frame_skip) 103 | pos_after = mass_center(self.sim) 104 | alive_bonus = 5.0 105 | data = self.sim.data 106 | lin_vel_cost = 0.25 * ( 107 | pos_after - pos_before) / self.sim.model.opt.timestep 108 | quad_ctrl_cost = 0.1 * np.square(data.ctrl).sum() 109 | quad_impact_cost = .5e-6 * np.square(data.cfrc_ext).sum() 110 | quad_impact_cost = min(quad_impact_cost, 10) 111 | reward = lin_vel_cost - quad_ctrl_cost - quad_impact_cost + alive_bonus 112 | 113 | if self._task == "follow_goals": 114 | xposafter = self.sim.data.qpos.flat[0] 115 | yposafter = self.sim.data.qpos.flat[1] 116 | reward = -np.linalg.norm(np.array([xposafter, yposafter]).T - self._goal) 117 | # update goal 118 | if np.abs(reward) < 0.5: 119 | self._goal = 
self._goal_list[0] 120 | self._goal_list = self._goal_list[1:] 121 | print("Goal Updated:", self._goal) 122 | 123 | elif self._task == "goal": 124 | xposafter = self.sim.data.qpos.flat[0] 125 | yposafter = self.sim.data.qpos.flat[1] 126 | reward = -np.linalg.norm(np.array([xposafter, yposafter]).T - self._goal) 127 | 128 | qpos = self.sim.data.qpos 129 | done = bool((qpos[2] < 1.0) or (qpos[2] > 2.0)) 130 | return self._get_obs(), reward, done, dict( 131 | reward_linvel=lin_vel_cost, 132 | reward_quadctrl=-quad_ctrl_cost, 133 | reward_alive=alive_bonus, 134 | reward_impact=-quad_impact_cost) 135 | 136 | def reset_model(self): 137 | c = 0.01 138 | self.set_state( 139 | self.init_qpos + self.np_random.uniform( 140 | low=-c, high=c, size=self.sim.model.nq), 141 | self.init_qvel + self.np_random.uniform( 142 | low=-c, 143 | high=c, 144 | size=self.sim.model.nv, 145 | )) 146 | 147 | if self._task == "follow_goals": 148 | self._goal = self._goal_list[0] 149 | self._goal_list = self._goal_list[1:] 150 | print("Current goal:", self._goal) 151 | 152 | return self._get_obs() 153 | 154 | def viewer_setup(self): 155 | self.viewer.cam.distance = self.model.stat.extent * 2.0 156 | -------------------------------------------------------------------------------- /envs/gym_mujoco/point_mass.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import math 20 | import os 21 | 22 | from gym import utils 23 | import numpy as np 24 | from gym.envs.mujoco import mujoco_env 25 | 26 | 27 | # pylint: disable=missing-docstring 28 | class PointMassEnv(mujoco_env.MujocoEnv, utils.EzPickle): 29 | 30 | def __init__(self, 31 | target=None, 32 | wiggly_weight=0., 33 | alt_xml=False, 34 | expose_velocity=True, 35 | expose_goal=True, 36 | use_simulator=False, 37 | model_path='point.xml'): 38 | self._sample_target = target 39 | if self._sample_target is not None: 40 | self.goal = np.array([1.0, 1.0]) 41 | 42 | self._expose_velocity = expose_velocity 43 | self._expose_goal = expose_goal 44 | self._use_simulator = use_simulator 45 | self._wiggly_weight = abs(wiggly_weight) 46 | self._wiggle_direction = +1 if wiggly_weight > 0. 
else -1 47 | 48 | xml_path = "envs/assets/" 49 | model_path = os.path.abspath(os.path.join(xml_path, model_path)) 50 | 51 | if self._use_simulator: 52 | mujoco_env.MujocoEnv.__init__(self, model_path, 5) 53 | else: 54 | mujoco_env.MujocoEnv.__init__(self, model_path, 1) 55 | utils.EzPickle.__init__(self) 56 | 57 | def step(self, action): 58 | if self._use_simulator: 59 | self.do_simulation(action, self.frame_skip) 60 | else: 61 | force = 0.2 * action[0] 62 | rot = 1.0 * action[1] 63 | qpos = self.sim.data.qpos.flat.copy() 64 | qpos[2] += rot 65 | ori = qpos[2] 66 | dx = math.cos(ori) * force 67 | dy = math.sin(ori) * force 68 | qpos[0] = np.clip(qpos[0] + dx, -2, 2) 69 | qpos[1] = np.clip(qpos[1] + dy, -2, 2) 70 | qvel = self.sim.data.qvel.flat.copy() 71 | self.set_state(qpos, qvel) 72 | 73 | ob = self._get_obs() 74 | if self._sample_target is not None and self.goal is not None: 75 | reward = -np.linalg.norm(self.sim.data.qpos.flat[:2] - self.goal)**2 76 | else: 77 | reward = 0. 78 | 79 | if self._wiggly_weight > 0.: 80 | reward = (np.exp(-((-reward)**0.5))**(1. - self._wiggly_weight)) * ( 81 | max(self._wiggle_direction * action[1], 0)**self._wiggly_weight) 82 | done = False 83 | return ob, reward, done, None 84 | 85 | def _get_obs(self): 86 | new_obs = [self.sim.data.qpos.flat] 87 | if self._expose_velocity: 88 | new_obs += [self.sim.data.qvel.flat] 89 | if self._expose_goal and self.goal is not None: 90 | new_obs += [self.goal] 91 | return np.concatenate(new_obs) 92 | 93 | def reset_model(self): 94 | qpos = self.init_qpos + np.append( 95 | self.np_random.uniform(low=-.2, high=.2, size=2), 96 | self.np_random.uniform(-np.pi, np.pi, size=1)) 97 | qvel = self.init_qvel + self.np_random.randn(self.sim.model.nv) * .01 98 | if self._sample_target is not None: 99 | self.goal = self._sample_target(qpos[:2]) 100 | self.set_state(qpos, qvel) 101 | return self._get_obs() 102 | 103 | # only works when goal is not exposed 104 | def set_qpos(self, state): 105 | qvel = np.copy(self.sim.data.qvel.flat) 106 | self.set_state(state, qvel) 107 | 108 | def viewer_setup(self): 109 | self.viewer.cam.distance = self.model.stat.extent * 0.5 110 | -------------------------------------------------------------------------------- /envs/hand_block.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
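
# Annotation: a minimal usage sketch for HandBlockCustomEnv below, assuming
# gym's robotics hand environments and mujoco_py are available. The constraint
# values are illustrative; they shrink the first two (wrist) action dimensions,
# and the overridden step() terminates the episode once the cube's center drops
# off the palm. The helper name is hypothetical and only resolves the class
# when called, after the module has loaded.
def _example_hand_block_rollout(num_steps=5):
  env = HandBlockCustomEnv(horizontal_wrist_constraint=0.3,
                           vertical_wrist_constraint=1.0)
  obs = env.reset()
  for _ in range(num_steps):
    # Dict observation with 'observation', 'achieved_goal', 'desired_goal'.
    obs, reward, done, info = env.step(env.action_space.sample())
    if done:
      obs = env.reset()
  return obs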
14 | 15 | import numpy as np 16 | import gym 17 | import os 18 | from gym import spaces 19 | from gym.envs.robotics.hand.manipulate import ManipulateEnv 20 | import mujoco_py 21 | 22 | MANIPULATE_BLOCK_XML = os.path.join('hand', 'manipulate_block.xml') 23 | 24 | class HandBlockCustomEnv(ManipulateEnv): 25 | def __init__(self, 26 | model_path=MANIPULATE_BLOCK_XML, 27 | target_position='random', 28 | target_rotation='xyz', 29 | reward_type='sparse', 30 | horizontal_wrist_constraint=1.0, 31 | vertical_wrist_constraint=1.0, 32 | **kwargs): 33 | ManipulateEnv.__init__(self, 34 | model_path=MANIPULATE_BLOCK_XML, 35 | target_position=target_position, 36 | target_rotation=target_rotation, 37 | target_position_range=np.array([(-0.04, 0.04), (-0.06, 0.02), (0.0, 0.06)]), 38 | reward_type=reward_type, 39 | **kwargs) 40 | 41 | self._viewers = {} 42 | 43 | # constraining the movement of wrist (vertical movement more important than horizontal) 44 | self.action_space.low[0] = -horizontal_wrist_constraint 45 | self.action_space.high[0] = horizontal_wrist_constraint 46 | self.action_space.low[1] = -vertical_wrist_constraint 47 | self.action_space.high[1] = vertical_wrist_constraint 48 | 49 | def _get_viewer(self, mode): 50 | self.viewer = self._viewers.get(mode) 51 | if self.viewer is None: 52 | if mode == 'human': 53 | self.viewer = mujoco_py.MjViewer(self.sim) 54 | elif mode == 'rgb_array': 55 | self.viewer = mujoco_py.MjRenderContextOffscreen(self.sim, device_id=-1) 56 | self._viewer_setup() 57 | self._viewers[mode] = self.viewer 58 | return self.viewer 59 | 60 | def _viewer_setup(self): 61 | body_id = self.sim.model.body_name2id('robot0:palm') 62 | lookat = self.sim.data.body_xpos[body_id] 63 | for idx, value in enumerate(lookat): 64 | self.viewer.cam.lookat[idx] = value 65 | self.viewer.cam.distance = 0.5 66 | self.viewer.cam.azimuth = 55. 67 | self.viewer.cam.elevation = -25. 68 | 69 | def step(self, action): 70 | 71 | def is_on_palm(): 72 | self.sim.forward() 73 | cube_middle_idx = self.sim.model.site_name2id('object:center') 74 | cube_middle_pos = self.sim.data.site_xpos[cube_middle_idx] 75 | is_on_palm = (cube_middle_pos[2] > 0.04) 76 | return is_on_palm 77 | 78 | obs, reward, done, info = super().step(action) 79 | done = not is_on_palm() 80 | return obs, reward, done, info 81 | 82 | def render(self, mode='human', width=500, height=500): 83 | self._render_callback() 84 | if mode == 'rgb_array': 85 | self._get_viewer(mode).render(width, height) 86 | # window size used for old mujoco-py: 87 | data = self._get_viewer(mode).read_pixels(width, height, depth=False) 88 | # original image is upside-down, so flip it 89 | return data[::-1, :, :] 90 | elif mode == 'human': 91 | self._get_viewer(mode).render() 92 | -------------------------------------------------------------------------------- /envs/skill_wrapper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import numpy as np 20 | 21 | import gym 22 | from gym import Wrapper 23 | 24 | class SkillWrapper(Wrapper): 25 | 26 | def __init__( 27 | self, 28 | env, 29 | # skill type and dimension 30 | num_latent_skills=None, 31 | skill_type='discrete_uniform', 32 | # execute an episode with the same predefined skill, does not resample 33 | preset_skill=None, 34 | # resample skills within episode 35 | min_steps_before_resample=10, 36 | resample_prob=0.): 37 | 38 | super(SkillWrapper, self).__init__(env) 39 | self._skill_type = skill_type 40 | if num_latent_skills is None: 41 | self._num_skills = 0 42 | else: 43 | self._num_skills = num_latent_skills 44 | self._preset_skill = preset_skill 45 | 46 | # attributes for controlling skill resampling 47 | self._min_steps_before_resample = min_steps_before_resample 48 | self._resample_prob = resample_prob 49 | 50 | if isinstance(self.env.observation_space, gym.spaces.Dict): 51 | size = self.env.observation_space.spaces['observation'].shape[0] + self._num_skills 52 | else: 53 | size = self.env.observation_space.shape[0] + self._num_skills 54 | self.observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(size,), dtype='float32') 55 | 56 | def _remake_time_step(self, cur_obs): 57 | if isinstance(self.env.observation_space, gym.spaces.Dict): 58 | cur_obs = cur_obs['observation'] 59 | 60 | if self._num_skills == 0: 61 | return cur_obs 62 | else: 63 | return np.concatenate([cur_obs, self.skill]) 64 | 65 | def _set_skill(self): 66 | if self._num_skills: 67 | if self._preset_skill is not None: 68 | self.skill = self._preset_skill 69 | print('Skill:', self.skill) 70 | elif self._skill_type == 'discrete_uniform': 71 | self.skill = np.random.multinomial( 72 | 1, [1. / self._num_skills] * self._num_skills) 73 | elif self._skill_type == 'gaussian': 74 | self.skill = np.random.multivariate_normal( 75 | np.zeros(self._num_skills), np.eye(self._num_skills)) 76 | elif self._skill_type == 'cont_uniform': 77 | self.skill = np.random.uniform( 78 | low=-1.0, high=1.0, size=self._num_skills) 79 | 80 | def reset(self): 81 | cur_obs = self.env.reset() 82 | self._set_skill() 83 | self._step_count = 0 84 | return self._remake_time_step(cur_obs) 85 | 86 | def step(self, action): 87 | cur_obs, reward, done, info = self.env.step(action) 88 | self._step_count += 1 89 | if self._preset_skill is None and self._step_count >= self._min_steps_before_resample and np.random.random( 90 | ) < self._resample_prob: 91 | self._set_skill() 92 | self._step_count = 0 93 | return self._remake_time_step(cur_obs), reward, done, info 94 | 95 | def close(self): 96 | return self.env.close() 97 | -------------------------------------------------------------------------------- /envs/video_wrapper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | 21 | import gym 22 | from gym import Wrapper 23 | from gym.wrappers.monitoring import video_recorder 24 | 25 | class VideoWrapper(Wrapper): 26 | 27 | def __init__(self, env, base_path, base_name=None, new_video_every_reset=False): 28 | super(VideoWrapper, self).__init__(env) 29 | 30 | self._base_path = base_path 31 | self._base_name = base_name 32 | 33 | self._new_video_every_reset = new_video_every_reset 34 | if self._new_video_every_reset: 35 | self._counter = 0 36 | self._recorder = None 37 | else: 38 | if self._base_name is not None: 39 | self._vid_name = os.path.join(self._base_path, self._base_name) 40 | else: 41 | self._vid_name = self._base_path 42 | self._recorder = video_recorder.VideoRecorder(self.env, path=self._vid_name + '.mp4') 43 | 44 | def reset(self): 45 | if self._new_video_every_reset: 46 | if self._recorder is not None: 47 | self._recorder.close() 48 | 49 | self._counter += 1 50 | if self._base_name is not None: 51 | self._vid_name = os.path.join(self._base_path, self._base_name + '_' + str(self._counter)) 52 | else: 53 | self._vid_name = self._base_path + '_' + str(self._counter) 54 | 55 | self._recorder = video_recorder.VideoRecorder(self.env, path=self._vid_name + '.mp4') 56 | 57 | return self.env.reset() 58 | 59 | def step(self, action): 60 | self._recorder.capture_frame() 61 | return self.env.step(action) 62 | 63 | def close(self): 64 | self._recorder.encoder.proc.stdin.flush() 65 | self._recorder.close() 66 | return self.env.close() -------------------------------------------------------------------------------- /lib/py_tf_policy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
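#
# Annotation: the intended PyTFPolicy workflow, sketched under the assumption of
# the TF1-style tf_agents API used throughout this file (a session is attached
# through the SessionUser `session` property before initialize() builds the
# placeholder graph). `tf_policy_instance` and `example_time_step` are
# hypothetical caller-supplied objects:
#
#   wrapped = PyTFPolicy(tf_policy_instance)
#   with tf.compat.v1.Session() as sess:
#     wrapped.session = sess
#     wrapped.initialize(batch_size=None)
#     policy_step = wrapped.action(example_time_step)  # PolicyStep(action, state, info)
#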
14 | 15 | """Converts TensorFlow Policies into Python Policies.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | from absl import logging 21 | 22 | import tensorflow as tf 23 | from tf_agents.policies import py_policy 24 | from tf_agents.policies import tf_policy 25 | from tf_agents.specs import tensor_spec 26 | from tf_agents.trajectories import policy_step 27 | from tf_agents.utils import common 28 | from tf_agents.utils import nest_utils 29 | from tf_agents.utils import session_utils 30 | 31 | 32 | class PyTFPolicy(py_policy.Base, session_utils.SessionUser): 33 | """Exposes a Python policy as wrapper over a TF Policy.""" 34 | 35 | # TODO(damienv): currently, the initial policy state must be batched 36 | # if batch_size is given. Without losing too much generality, the initial 37 | # policy state could be the same for every element in the batch. 38 | # In that case, the initial policy state could be given with no batch 39 | # dimension. 40 | # TODO(sfishman): Remove batch_size param entirely. 41 | def __init__(self, policy, batch_size=None, seed=None): 42 | """Initializes a new `PyTFPolicy`. 43 | 44 | Args: 45 | policy: A TF Policy implementing `tf_policy.Base`. 46 | batch_size: (deprecated) 47 | seed: Seed to use if policy performs random actions (optional). 48 | """ 49 | if not isinstance(policy, tf_policy.Base): 50 | logging.warning('Policy should implement tf_policy.Base') 51 | 52 | if batch_size is not None: 53 | logging.warning('In PyTFPolicy constructor, `batch_size` is deprecated, ' 54 | 'this parameter has no effect. This argument will be ' 55 | 'removed on 2019-05-01') 56 | 57 | time_step_spec = tensor_spec.to_nest_array_spec(policy.time_step_spec) 58 | action_spec = tensor_spec.to_nest_array_spec(policy.action_spec) 59 | super(PyTFPolicy, self).__init__( 60 | time_step_spec, action_spec, policy_state_spec=(), info_spec=()) 61 | 62 | self._tf_policy = policy 63 | self.session = None 64 | 65 | self._policy_state_spec = tensor_spec.to_nest_array_spec( 66 | self._tf_policy.policy_state_spec) 67 | 68 | self._batch_size = None 69 | self._batched = None 70 | self._seed = seed 71 | self._built = False 72 | 73 | def _construct(self, batch_size, graph): 74 | """Construct the agent graph through placeholders.""" 75 | 76 | self._batch_size = batch_size 77 | self._batched = batch_size is not None 78 | 79 | outer_dims = [self._batch_size] if self._batched else [1] 80 | with graph.as_default(): 81 | self._time_step = tensor_spec.to_nest_placeholder( 82 | self._tf_policy.time_step_spec, outer_dims=outer_dims) 83 | self._tf_initial_state = self._tf_policy.get_initial_state( 84 | batch_size=self._batch_size or 1) 85 | 86 | self._policy_state = tf.nest.map_structure( 87 | lambda ps: tf.compat.v1.placeholder( # pylint: disable=g-long-lambda 88 | ps.dtype, 89 | ps.shape, 90 | name='policy_state'), 91 | self._tf_initial_state) 92 | self._action_step = self._tf_policy.action( 93 | self._time_step, self._policy_state, seed=self._seed) 94 | 95 | self._actions = tensor_spec.to_nest_placeholder( 96 | self._tf_policy.action_spec, outer_dims=outer_dims) 97 | self._action_distribution = self._tf_policy.distribution( 98 | self._time_step, policy_state=self._policy_state).action 99 | self._log_prob = common.log_probability(self._action_distribution, 100 | self._actions, 101 | self._tf_policy.action_spec) 102 | 103 | def initialize(self, batch_size, graph=None): 104 | if self._built: 105 | raise RuntimeError('PyTFPolicy 
can only be initialized once.') 106 | 107 | if not graph: 108 | graph = tf.compat.v1.get_default_graph() 109 | 110 | self._construct(batch_size, graph) 111 | var_list = tf.nest.flatten(self._tf_policy.variables()) 112 | common.initialize_uninitialized_variables(self.session, var_list) 113 | self._built = True 114 | 115 | def save(self, policy_dir=None, graph=None): 116 | if not self._built: 117 | raise RuntimeError('PyTFPolicy has not been initialized yet.') 118 | 119 | if not graph: 120 | graph = tf.compat.v1.get_default_graph() 121 | 122 | with graph.as_default(): 123 | global_step = tf.compat.v1.train.get_or_create_global_step() 124 | policy_checkpointer = common.Checkpointer( 125 | ckpt_dir=policy_dir, policy=self._tf_policy, global_step=global_step) 126 | policy_checkpointer.initialize_or_restore(self.session) 127 | with self.session.as_default(): 128 | policy_checkpointer.save(global_step) 129 | 130 | def restore(self, policy_dir, graph=None, assert_consumed=True): 131 | """Restores the policy from the checkpoint. 132 | 133 | Args: 134 | policy_dir: Directory with the checkpoint. 135 | graph: A graph, inside which policy the is restored (optional). 136 | assert_consumed: If true, contents of the checkpoint will be checked 137 | for a match against graph variables. 138 | 139 | Returns: 140 | step: Global step associated with the restored policy checkpoint. 141 | 142 | Raises: 143 | RuntimeError: if the policy is not initialized. 144 | AssertionError: if the checkpoint contains variables which do not have 145 | matching names in the graph, and assert_consumed is set to True. 146 | 147 | """ 148 | 149 | if not self._built: 150 | raise RuntimeError( 151 | 'PyTFPolicy must be initialized before being restored.') 152 | if not graph: 153 | graph = tf.compat.v1.get_default_graph() 154 | 155 | with graph.as_default(): 156 | global_step = tf.compat.v1.train.get_or_create_global_step() 157 | policy_checkpointer = common.Checkpointer( 158 | ckpt_dir=policy_dir, policy=self._tf_policy, global_step=global_step) 159 | status = policy_checkpointer.initialize_or_restore(self.session) 160 | with self.session.as_default(): 161 | if assert_consumed: 162 | status.assert_consumed() 163 | status.run_restore_ops() 164 | return self.session.run(global_step) 165 | 166 | def _build_from_time_step(self, time_step): 167 | outer_shape = nest_utils.get_outer_array_shape(time_step, 168 | self._time_step_spec) 169 | if len(outer_shape) == 1: 170 | self.initialize(outer_shape[0]) 171 | elif not outer_shape: 172 | self.initialize(None) 173 | else: 174 | raise ValueError( 175 | 'Cannot handle more than one outer dimension. Saw {} outer ' 176 | 'dimensions: {}'.format(len(outer_shape), outer_shape)) 177 | 178 | def _get_initial_state(self, batch_size): 179 | if not self._built: 180 | self.initialize(batch_size) 181 | if batch_size != self._batch_size: 182 | raise ValueError( 183 | '`batch_size` argument is different from the batch size provided ' 184 | 'previously. Expected {}, but saw {}.'.format(self._batch_size, 185 | batch_size)) 186 | return self.session.run(self._tf_initial_state) 187 | 188 | def _action(self, time_step, policy_state): 189 | if not self._built: 190 | self._build_from_time_step(time_step) 191 | 192 | batch_size = None 193 | if time_step.step_type.shape: 194 | batch_size = time_step.step_type.shape[0] 195 | if self._batch_size != batch_size: 196 | raise ValueError( 197 | 'The batch size of time_step is different from the batch size ' 198 | 'provided previously. 
Expected {}, but saw {}.'.format( 199 | self._batch_size, batch_size)) 200 | 201 | if not self._batched: 202 | # Since policy_state is given in a batched form from the policy and we 203 | # simply have to send it back we do not need to worry about it. Only 204 | # update time_step. 205 | time_step = nest_utils.batch_nested_array(time_step) 206 | 207 | tf.nest.assert_same_structure(self._time_step, time_step) 208 | feed_dict = {self._time_step: time_step} 209 | if policy_state is not None: 210 | # Flatten policy_state to handle specs that are not hashable due to lists. 211 | for state_ph, state in zip( 212 | tf.nest.flatten(self._policy_state), tf.nest.flatten(policy_state)): 213 | feed_dict[state_ph] = state 214 | 215 | action_step = self.session.run(self._action_step, feed_dict) 216 | action, state, info = action_step 217 | 218 | if not self._batched: 219 | action, info = nest_utils.unbatch_nested_array([action, info]) 220 | 221 | return policy_step.PolicyStep(action, state, info) 222 | 223 | def log_prob(self, time_step, action_step, policy_state=None): 224 | if not self._built: 225 | self._build_from_time_step(time_step) 226 | tf.nest.assert_same_structure(self._time_step, time_step) 227 | tf.nest.assert_same_structure(self._actions, action_step) 228 | feed_dict = {self._time_step: time_step, self._actions: action_step} 229 | if policy_state is not None: 230 | feed_dict[self._policy_state] = policy_state 231 | return self.session.run(self._log_prob, feed_dict) 232 | -------------------------------------------------------------------------------- /lib/py_uniform_replay_buffer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Uniform replay buffer in Python. 16 | 17 | The base class provides all the functionalities of a uniform replay buffer: 18 | - add samples in a First In First Out way. 19 | - read samples uniformly. 20 | 21 | PyHashedReplayBuffer is a flavor of the base class which 22 | compresses the observations when the observations have some partial overlap 23 | (e.g. when using frame stacking). 24 | """ 25 | from __future__ import absolute_import 26 | from __future__ import division 27 | from __future__ import print_function 28 | 29 | import threading 30 | 31 | import numpy as np 32 | import tensorflow as tf 33 | from tf_agents.replay_buffers import replay_buffer 34 | from tf_agents.specs import array_spec 35 | from tf_agents.utils import nest_utils 36 | from tf_agents.utils import numpy_storage 37 | 38 | 39 | class PyUniformReplayBuffer(replay_buffer.ReplayBuffer): 40 | """A Python-based replay buffer that supports uniform sampling. 41 | 42 | Writing and reading to this replay buffer is thread safe. 43 | 44 | This replay buffer can be subclassed to change the encoding used for the 45 | underlying storage by overriding _encoded_data_spec, _encode, _decode, and 46 | _on_delete. 
47 | """ 48 | 49 | def __init__(self, data_spec, capacity): 50 | """Creates a PyUniformReplayBuffer. 51 | 52 | Args: 53 | data_spec: An ArraySpec or a list/tuple/nest of ArraySpecs describing a 54 | single item that can be stored in this buffer. 55 | capacity: The maximum number of items that can be stored in the buffer. 56 | """ 57 | super(PyUniformReplayBuffer, self).__init__(data_spec, capacity) 58 | 59 | self._storage = numpy_storage.NumpyStorage(self._encoded_data_spec(), 60 | capacity) 61 | self._lock = threading.Lock() 62 | self._np_state = numpy_storage.NumpyState() 63 | 64 | # Adding elements to the replay buffer is done in a circular way. 65 | # Keeps track of the actual size of the replay buffer and the location 66 | # where to add new elements. 67 | self._np_state.size = np.int64(0) 68 | self._np_state.cur_id = np.int64(0) 69 | 70 | # Total number of items that went through the replay buffer. 71 | self._np_state.item_count = np.int64(0) 72 | 73 | def _encoded_data_spec(self): 74 | """Spec of data items after encoding using _encode.""" 75 | return self._data_spec 76 | 77 | def _encode(self, item): 78 | """Encodes an item (before adding it to the buffer).""" 79 | return item 80 | 81 | def _decode(self, item): 82 | """Decodes an item.""" 83 | return item 84 | 85 | def _on_delete(self, encoded_item): 86 | """Do any necessary cleanup.""" 87 | pass 88 | 89 | @property 90 | def size(self): 91 | return self._np_state.size 92 | 93 | def _add_batch(self, items): 94 | outer_shape = nest_utils.get_outer_array_shape(items, self._data_spec) 95 | if outer_shape[0] != 1: 96 | raise NotImplementedError('PyUniformReplayBuffer only supports a batch ' 97 | 'size of 1, but received `items` with batch ' 98 | 'size {}.'.format(outer_shape[0])) 99 | 100 | item = nest_utils.unbatch_nested_array(items) 101 | with self._lock: 102 | if self._np_state.size == self._capacity: 103 | # If we are at capacity, we are deleting element cur_id. 104 | self._on_delete(self._storage.get(self._np_state.cur_id)) 105 | self._storage.set(self._np_state.cur_id, self._encode(item)) 106 | self._np_state.size = np.minimum(self._np_state.size + 1, 107 | self._capacity) 108 | self._np_state.cur_id = (self._np_state.cur_id + 1) % self._capacity 109 | self._np_state.item_count += 1 110 | 111 | def _get_next(self, 112 | sample_batch_size=None, 113 | num_steps=None, 114 | time_stacked=True): 115 | num_steps_value = num_steps if num_steps is not None else 1 116 | def get_single(): 117 | """Gets a single item from the replay buffer.""" 118 | with self._lock: 119 | if self._np_state.size <= 0: 120 | def empty_item(spec): 121 | return np.empty(spec.shape, dtype=spec.dtype) 122 | if num_steps is not None: 123 | item = [tf.nest.map_structure(empty_item, self.data_spec) 124 | for n in range(num_steps)] 125 | if time_stacked: 126 | item = nest_utils.stack_nested_arrays(item) 127 | else: 128 | item = tf.nest.map_structure(empty_item, self.data_spec) 129 | return item 130 | idx = np.random.randint(self._np_state.size - num_steps_value + 1) 131 | if self._np_state.size == self._capacity: 132 | # If the buffer is full, add cur_id (head of circular buffer) so that 133 | # we sample from the range [cur_id, cur_id + size - num_steps_value]. 134 | # We will modulo the size below. 135 | idx += self._np_state.cur_id 136 | 137 | if num_steps is not None: 138 | # TODO(b/120242830): Try getting data from numpy in one shot rather 139 | # than num_steps_value. 
140 | item = [self._decode(self._storage.get((idx + n) % self._capacity)) 141 | for n in range(num_steps)] 142 | else: 143 | item = self._decode(self._storage.get(idx % self._capacity)) 144 | 145 | if num_steps is not None and time_stacked: 146 | item = nest_utils.stack_nested_arrays(item) 147 | return item 148 | 149 | if sample_batch_size is None: 150 | return get_single() 151 | else: 152 | samples = [get_single() for _ in range(sample_batch_size)] 153 | return nest_utils.stack_nested_arrays(samples) 154 | 155 | def _as_dataset(self, sample_batch_size=None, num_steps=None, 156 | num_parallel_calls=None): 157 | if num_parallel_calls is not None: 158 | raise NotImplementedError('PyUniformReplayBuffer does not support ' 159 | 'num_parallel_calls (must be None).') 160 | 161 | data_spec = self._data_spec 162 | if sample_batch_size is not None: 163 | data_spec = array_spec.add_outer_dims_nest( 164 | data_spec, (sample_batch_size,)) 165 | if num_steps is not None: 166 | data_spec = (data_spec,) * num_steps 167 | shapes = tuple(s.shape for s in tf.nest.flatten(data_spec)) 168 | dtypes = tuple(s.dtype for s in tf.nest.flatten(data_spec)) 169 | 170 | def generator_fn(): 171 | while True: 172 | if sample_batch_size is not None: 173 | batch = [self._get_next(num_steps=num_steps, time_stacked=False) 174 | for _ in range(sample_batch_size)] 175 | item = nest_utils.stack_nested_arrays(batch) 176 | else: 177 | item = self._get_next(num_steps=num_steps, time_stacked=False) 178 | yield tuple(tf.nest.flatten(item)) 179 | 180 | def time_stack(*structures): 181 | time_axis = 0 if sample_batch_size is None else 1 182 | return tf.nest.map_structure( 183 | lambda *elements: tf.stack(elements, axis=time_axis), *structures) 184 | 185 | ds = tf.data.Dataset.from_generator( 186 | generator_fn, dtypes, 187 | shapes).map(lambda *items: tf.nest.pack_sequence_as(data_spec, items)) 188 | if num_steps is not None: 189 | return ds.map(time_stack) 190 | else: 191 | return ds 192 | 193 | def _gather_all(self): 194 | data = [self._decode(self._storage.get(idx)) 195 | for idx in range(self._capacity)] 196 | stacked = nest_utils.stack_nested_arrays(data) 197 | batched = tf.nest.map_structure(lambda t: np.expand_dims(t, 0), stacked) 198 | return batched 199 | 200 | def _clear(self): 201 | self._np_state.size = np.int64(0) 202 | self._np_state.cur_id = np.int64(0) 203 | 204 | def gather_all_transitions(self): 205 | num_steps_value = 2 206 | 207 | def get_single(idx): 208 | """Gets the idx item from the replay buffer.""" 209 | with self._lock: 210 | if self._np_state.size <= idx: 211 | 212 | def empty_item(spec): 213 | return np.empty(spec.shape, dtype=spec.dtype) 214 | 215 | item = [ 216 | tf.nest.map_structure(empty_item, self.data_spec) 217 | for n in range(num_steps_value) 218 | ] 219 | item = nest_utils.stack_nested_arrays(item) 220 | return item 221 | 222 | if self._np_state.size == self._capacity: 223 | # If the buffer is full, add cur_id (head of circular buffer) so that 224 | # we sample from the range [cur_id, cur_id + size - num_steps_value]. 225 | # We will modulo the size below. 
226 | idx += self._np_state.cur_id 227 | 228 | item = [ 229 | self._decode(self._storage.get((idx + n) % self._capacity)) 230 | for n in range(num_steps_value) 231 | ] 232 | 233 | item = nest_utils.stack_nested_arrays(item) 234 | return item 235 | 236 | samples = [ 237 | get_single(idx) 238 | for idx in range(self._np_state.size - num_steps_value + 1) 239 | ] 240 | return nest_utils.stack_nested_arrays(samples) 241 | -------------------------------------------------------------------------------- /unsupervised_skill_learning/dads_agent.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """TF-Agents Class for DADS. Builds on top of the SAC agent.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | 23 | import sys 24 | sys.path.append(os.path.abspath('./')) 25 | 26 | import numpy as np 27 | import tensorflow as tf 28 | 29 | from tf_agents.agents.sac import sac_agent 30 | 31 | import skill_dynamics 32 | 33 | nest = tf.nest 34 | 35 | 36 | class DADSAgent(sac_agent.SacAgent): 37 | 38 | def __init__(self, 39 | save_directory, 40 | skill_dynamics_observation_size, 41 | observation_modify_fn=None, 42 | restrict_input_size=0, 43 | latent_size=2, 44 | latent_prior='cont_uniform', 45 | prior_samples=100, 46 | fc_layer_params=(256, 256), 47 | normalize_observations=True, 48 | network_type='default', 49 | num_mixture_components=4, 50 | fix_variance=True, 51 | skill_dynamics_learning_rate=3e-4, 52 | reweigh_batches=False, 53 | agent_graph=None, 54 | skill_dynamics_graph=None, 55 | *sac_args, 56 | **sac_kwargs): 57 | self._skill_dynamics_learning_rate = skill_dynamics_learning_rate 58 | self._latent_size = latent_size 59 | self._latent_prior = latent_prior 60 | self._prior_samples = prior_samples 61 | self._save_directory = save_directory 62 | self._restrict_input_size = restrict_input_size 63 | self._process_observation = observation_modify_fn 64 | 65 | if agent_graph is None: 66 | self._graph = tf.compat.v1.get_default_graph() 67 | else: 68 | self._graph = agent_graph 69 | 70 | if skill_dynamics_graph is None: 71 | skill_dynamics_graph = self._graph 72 | 73 | # instantiate the skill dynamics 74 | self._skill_dynamics = skill_dynamics.SkillDynamics( 75 | observation_size=skill_dynamics_observation_size, 76 | action_size=self._latent_size, 77 | restrict_observation=self._restrict_input_size, 78 | normalize_observations=normalize_observations, 79 | fc_layer_params=fc_layer_params, 80 | network_type=network_type, 81 | num_components=num_mixture_components, 82 | fix_variance=fix_variance, 83 | reweigh_batches=reweigh_batches, 84 | graph=skill_dynamics_graph) 85 | 86 | super(DADSAgent, self).__init__(*sac_args, **sac_kwargs) 87 | self._placeholders_in_place = False 88 | 89 | def compute_dads_reward(self, input_obs, cur_skill, target_obs): 90 | if 
self._process_observation is not None: 91 | input_obs, target_obs = self._process_observation( 92 | input_obs), self._process_observation(target_obs) 93 | 94 | num_reps = self._prior_samples if self._prior_samples > 0 else self._latent_size - 1 95 | input_obs_altz = np.concatenate([input_obs] * num_reps, axis=0) 96 | target_obs_altz = np.concatenate([target_obs] * num_reps, axis=0) 97 | 98 | # for marginalization of the denominator 99 | if self._latent_prior == 'discrete_uniform' and not self._prior_samples: 100 | alt_skill = np.concatenate( 101 | [np.roll(cur_skill, i, axis=1) for i in range(1, num_reps + 1)], 102 | axis=0) 103 | elif self._latent_prior == 'discrete_uniform': 104 | alt_skill = np.random.multinomial( 105 | 1, [1. / self._latent_size] * self._latent_size, 106 | size=input_obs_altz.shape[0]) 107 | elif self._latent_prior == 'gaussian': 108 | alt_skill = np.random.multivariate_normal( 109 | np.zeros(self._latent_size), 110 | np.eye(self._latent_size), 111 | size=input_obs_altz.shape[0]) 112 | elif self._latent_prior == 'cont_uniform': 113 | alt_skill = np.random.uniform( 114 | low=-1.0, high=1.0, size=(input_obs_altz.shape[0], self._latent_size)) 115 | 116 | logp = self._skill_dynamics.get_log_prob(input_obs, cur_skill, target_obs) 117 | 118 | # denominator may require more memory than that of a GPU, break computation 119 | split_group = 20 * 4000 120 | if input_obs_altz.shape[0] <= split_group: 121 | logp_altz = self._skill_dynamics.get_log_prob(input_obs_altz, alt_skill, 122 | target_obs_altz) 123 | else: 124 | logp_altz = [] 125 | for split_idx in range(input_obs_altz.shape[0] // split_group): 126 | start_split = split_idx * split_group 127 | end_split = (split_idx + 1) * split_group 128 | logp_altz.append( 129 | self._skill_dynamics.get_log_prob( 130 | input_obs_altz[start_split:end_split], 131 | alt_skill[start_split:end_split], 132 | target_obs_altz[start_split:end_split])) 133 | if input_obs_altz.shape[0] % split_group: 134 | start_split = input_obs_altz.shape[0] % split_group 135 | logp_altz.append( 136 | self._skill_dynamics.get_log_prob(input_obs_altz[-start_split:], 137 | alt_skill[-start_split:], 138 | target_obs_altz[-start_split:])) 139 | logp_altz = np.concatenate(logp_altz) 140 | logp_altz = np.array(np.array_split(logp_altz, num_reps)) 141 | 142 | # final DADS reward 143 | intrinsic_reward = np.log(num_reps + 1) - np.log(1 + np.exp( 144 | np.clip(logp_altz - logp.reshape(1, -1), -50, 50)).sum(axis=0)) 145 | 146 | return intrinsic_reward, {'logp': logp, 'logp_altz': logp_altz.flatten()} 147 | 148 | def get_experience_placeholder(self): 149 | self._placeholders_in_place = True 150 | self._placeholders = [] 151 | for item in nest.flatten(self.collect_data_spec): 152 | self._placeholders += [ 153 | tf.compat.v1.placeholder( 154 | item.dtype, 155 | shape=(None, 2) if len(item.shape) == 0 else 156 | (None, 2, item.shape[-1]), 157 | name=item.name) 158 | ] 159 | self._policy_experience_ph = nest.pack_sequence_as(self.collect_data_spec, 160 | self._placeholders) 161 | return self._policy_experience_ph 162 | 163 | def build_agent_graph(self): 164 | with self._graph.as_default(): 165 | self.get_experience_placeholder() 166 | self.agent_train_op = self.train(self._policy_experience_ph) 167 | self.summary_ops = tf.compat.v1.summary.all_v2_summary_ops() 168 | return self.agent_train_op 169 | 170 | def build_skill_dynamics_graph(self): 171 | self._skill_dynamics.make_placeholders() 172 | self._skill_dynamics.build_graph() 173 | self._skill_dynamics.increase_prob_op( 174 | 
learning_rate=self._skill_dynamics_learning_rate) 175 | 176 | def create_savers(self): 177 | self._skill_dynamics.create_saver( 178 | save_prefix=os.path.join(self._save_directory, 'dynamics')) 179 | 180 | def set_sessions(self, initialize_or_restore_skill_dynamics, session=None): 181 | if session is not None: 182 | self._session = session 183 | else: 184 | self._session = tf.compat.v1.Session(graph=self._graph) 185 | self._skill_dynamics.set_session( 186 | initialize_or_restore_variables=initialize_or_restore_skill_dynamics, 187 | session=session) 188 | 189 | def save_variables(self, global_step): 190 | self._skill_dynamics.save_variables(global_step=global_step) 191 | 192 | def _get_dict(self, trajectories, batch_size=-1): 193 | tf.nest.assert_same_structure(self.collect_data_spec, trajectories) 194 | if batch_size > 0: 195 | shuffled_batch = np.random.permutation( 196 | trajectories.observation.shape[0])[:batch_size] 197 | else: 198 | shuffled_batch = np.arange(trajectories.observation.shape[0]) 199 | 200 | return_dict = {} 201 | 202 | for placeholder, val in zip(self._placeholders, nest.flatten(trajectories)): 203 | return_dict[placeholder] = val[shuffled_batch] 204 | 205 | return return_dict 206 | 207 | def train_loop(self, 208 | trajectories, 209 | recompute_reward=False, 210 | batch_size=-1, 211 | num_steps=1): 212 | if not self._placeholders_in_place: 213 | return 214 | 215 | if recompute_reward: 216 | input_obs = trajectories.observation[:, 0, :-self._latent_size] 217 | cur_skill = trajectories.observation[:, 0, -self._latent_size:] 218 | target_obs = trajectories.observation[:, 1, :-self._latent_size] 219 | new_reward, info = self.compute_dads_reward(input_obs, cur_skill, 220 | target_obs) 221 | trajectories = trajectories._replace( 222 | reward=np.concatenate( 223 | [np.expand_dims(new_reward, axis=1), trajectories.reward[:, 1:]], 224 | axis=1)) 225 | 226 | # TODO(architsh):all agent specs should be the same as env specs, shift preprocessing to actor/critic networks 227 | if self._restrict_input_size > 0: 228 | trajectories = trajectories._replace( 229 | observation=trajectories.observation[:, :, 230 | self._restrict_input_size:]) 231 | 232 | for _ in range(num_steps): 233 | self._session.run([self.agent_train_op, self.summary_ops], 234 | feed_dict=self._get_dict( 235 | trajectories, batch_size=batch_size)) 236 | 237 | if recompute_reward: 238 | return new_reward, info 239 | else: 240 | return None, None 241 | 242 | @property 243 | def skill_dynamics(self): 244 | return self._skill_dynamics 245 | -------------------------------------------------------------------------------- /unsupervised_skill_learning/skill_discriminator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
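#
# Annotation: SkillDiscriminator below is the DIAYN-style counterpart to
# skill_dynamics: it predicts the skill from observations (or observation
# deltas, depending on input_type) under a 'gaussian', 'cont_uniform' or
# 'discrete_uniform' prior. A minimal construction sketch with illustrative
# sizes (TF1-style graph/session API assumed):
#
#   discriminator = SkillDiscriminator(
#       observation_size=10, skill_size=2, skill_type='cont_uniform')
#   discriminator.make_placeholders()
#   discriminator.build_graph()
#   discriminator.set_session(initialize_or_restore_variables=False)
#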
14 | 15 | """Skill Discriminator Prediction and Training.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import numpy as np 23 | import tensorflow as tf 24 | import tensorflow_probability as tfp 25 | 26 | from tf_agents.distributions import tanh_bijector_stable 27 | 28 | class SkillDiscriminator: 29 | 30 | def __init__( 31 | self, 32 | observation_size, 33 | skill_size, 34 | skill_type, 35 | normalize_observations=False, 36 | # network properties 37 | fc_layer_params=(256, 256), 38 | fix_variance=False, 39 | input_type='diayn', 40 | # probably do not need to change these 41 | graph=None, 42 | scope_name='skill_discriminator'): 43 | 44 | self._observation_size = observation_size 45 | self._skill_size = skill_size 46 | self._skill_type = skill_type 47 | self._normalize_observations = normalize_observations 48 | 49 | # tensorflow requirements 50 | if graph is not None: 51 | self._graph = graph 52 | else: 53 | self._graph = tf.get_default_graph() 54 | self._scope_name = scope_name 55 | 56 | # discriminator network properties 57 | self._fc_layer_params = fc_layer_params 58 | self._fix_variance = fix_variance 59 | if not self._fix_variance: 60 | self._std_lower_clip = 0.3 61 | self._std_upper_clip = 10.0 62 | self._input_type = input_type 63 | 64 | self._use_placeholders = False 65 | self.log_probability = None 66 | self.disc_max_op = None 67 | self.disc_min_op = None 68 | self._session = None 69 | 70 | # saving/restoring variables 71 | self._saver = None 72 | 73 | def _get_distributions(self, out): 74 | if self._skill_type in ['gaussian', 'cont_uniform']: 75 | mean = tf.layers.dense( 76 | out, self._skill_size, name='mean', reuse=tf.AUTO_REUSE) 77 | if not self._fix_variance: 78 | stddev = tf.clip_by_value( 79 | tf.layers.dense( 80 | out, 81 | self._skill_size, 82 | activation=tf.nn.softplus, 83 | name='stddev', 84 | reuse=tf.AUTO_REUSE), self._std_lower_clip, 85 | self._std_upper_clip) 86 | else: 87 | stddev = tf.fill([tf.shape(out)[0], self._skill_size], 1.0) 88 | 89 | inference_distribution = tfp.distributions.MultivariateNormalDiag( 90 | loc=mean, scale_diag=stddev) 91 | 92 | if self._skill_type == 'gaussian': 93 | prior_distribution = tfp.distributions.MultivariateNormalDiag( 94 | loc=[0.] * self._skill_size, scale_diag=[1.] * self._skill_size) 95 | elif self._skill_type == 'cont_uniform': 96 | prior_distribution = tfp.distributions.Independent( 97 | tfp.distributions.Uniform( 98 | low=[-1.] * self._skill_size, high=[1.] * self._skill_size), 99 | reinterpreted_batch_ndims=1) 100 | 101 | # squash posterior to the right range of [-1, 1] 102 | bijectors = [] 103 | bijectors.append(tanh_bijector_stable.Tanh()) 104 | bijector_chain = tfp.bijectors.Chain(bijectors) 105 | inference_distribution = tfp.distributions.TransformedDistribution( 106 | distribution=inference_distribution, bijector=bijector_chain) 107 | 108 | elif self._skill_type == 'discrete_uniform': 109 | logits = tf.layers.dense( 110 | out, self._skill_size, name='logits', reuse=tf.AUTO_REUSE) 111 | inference_distribution = tfp.distributions.OneHotCategorical( 112 | logits=logits) 113 | prior_distribution = tfp.distributions.OneHotCategorical( 114 | probs=[1. 
/ self._skill_size] * self._skill_size) 115 | elif self._skill_type == 'multivariate_bernoulli': 116 | print('Not supported yet') 117 | 118 | return inference_distribution, prior_distribution 119 | 120 | # simple dynamics graph 121 | def _default_graph(self, timesteps): 122 | out = timesteps 123 | for idx, layer_size in enumerate(self._fc_layer_params): 124 | out = tf.layers.dense( 125 | out, 126 | layer_size, 127 | activation=tf.nn.relu, 128 | name='hid_' + str(idx), 129 | reuse=tf.AUTO_REUSE) 130 | 131 | return self._get_distributions(out) 132 | 133 | def _get_dict(self, 134 | input_steps, 135 | target_skills, 136 | input_next_steps=None, 137 | batch_size=-1, 138 | batch_norm=False): 139 | if batch_size > 0: 140 | shuffled_batch = np.random.permutation(len(input_steps))[:batch_size] 141 | else: 142 | shuffled_batch = np.arange(len(input_steps)) 143 | 144 | batched_input = input_steps[shuffled_batch, :] 145 | batched_skills = target_skills[shuffled_batch, :] 146 | if self._input_type in ['diff', 'both']: 147 | batched_targets = input_next_steps[shuffled_batch, :] 148 | 149 | return_dict = { 150 | self.timesteps_pl: batched_input, 151 | self.skills_pl: batched_skills, 152 | } 153 | 154 | if self._input_type in ['diff', 'both']: 155 | return_dict[self.next_timesteps_pl] = batched_targets 156 | if self._normalize_observations: 157 | return_dict[self.is_training_pl] = batch_norm 158 | 159 | return return_dict 160 | 161 | def make_placeholders(self): 162 | self._use_placeholders = True 163 | with self._graph.as_default(), tf.variable_scope(self._scope_name): 164 | self.timesteps_pl = tf.placeholder( 165 | tf.float32, shape=(None, self._observation_size), name='timesteps_pl') 166 | self.skills_pl = tf.placeholder( 167 | tf.float32, shape=(None, self._skill_size), name='skills_pl') 168 | if self._input_type in ['diff', 'both']: 169 | self.next_timesteps_pl = tf.placeholder( 170 | tf.float32, 171 | shape=(None, self._observation_size), 172 | name='next_timesteps_pl') 173 | if self._normalize_observations: 174 | self.is_training_pl = tf.placeholder(tf.bool, name='batch_norm_pl') 175 | 176 | def set_session(self, session=None, initialize_or_restore_variables=False): 177 | if session is None: 178 | self._session = tf.Session(graph=self._graph) 179 | else: 180 | self._session = session 181 | 182 | # only initialize uninitialized variables 183 | if initialize_or_restore_variables: 184 | if tf.gfile.Exists(self._save_prefix): 185 | self.restore_variables() 186 | with self._graph.as_default(): 187 | is_initialized = self._session.run([ 188 | tf.compat.v1.is_variable_initialized(v) 189 | for key, v in self._variable_list.items() 190 | ]) 191 | uninitialized_vars = [] 192 | for flag, v in zip(is_initialized, self._variable_list.items()): 193 | if not flag: 194 | uninitialized_vars.append(v[1]) 195 | 196 | if uninitialized_vars: 197 | self._session.run( 198 | tf.compat.v1.variables_initializer(uninitialized_vars)) 199 | 200 | def build_graph(self, 201 | timesteps=None, 202 | skills=None, 203 | next_timesteps=None, 204 | is_training=None): 205 | with self._graph.as_default(), tf.variable_scope(self._scope_name): 206 | if self._use_placeholders: 207 | timesteps = self.timesteps_pl 208 | skills = self.skills_pl 209 | if self._input_type in ['diff', 'both']: 210 | next_timesteps = self.next_timesteps_pl 211 | if self._normalize_observations: 212 | is_training = self.is_training_pl 213 | 214 | # use deltas 215 | if self._input_type == 'both': 216 | next_timesteps -= timesteps 217 | timesteps = 
tf.concat([timesteps, next_timesteps], axis=1) 218 | if self._input_type == 'diff': 219 | timesteps = next_timesteps - timesteps 220 | 221 | if self._normalize_observations: 222 | timesteps = tf.layers.batch_normalization( 223 | timesteps, 224 | training=is_training, 225 | name='input_normalization', 226 | reuse=tf.AUTO_REUSE) 227 | 228 | inference_distribution, prior_distribution = self._default_graph( 229 | timesteps) 230 | 231 | self.log_probability = inference_distribution.log_prob(skills) 232 | self.prior_probability = prior_distribution.log_prob(skills) 233 | return self.log_probability, self.prior_probability 234 | 235 | def increase_prob_op(self, learning_rate=3e-4): 236 | with self._graph.as_default(): 237 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 238 | with tf.control_dependencies(update_ops): 239 | self.disc_max_op = tf.train.AdamOptimizer( 240 | learning_rate=learning_rate).minimize( 241 | -tf.reduce_mean(self.log_probability)) 242 | return self.disc_max_op 243 | 244 | def decrease_prob_op(self, learning_rate=3e-4): 245 | with self._graph.as_default(): 246 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 247 | with tf.control_dependencies(update_ops): 248 | self.disc_min_op = tf.train.AdamOptimizer( 249 | learning_rate=learning_rate).minimize( 250 | tf.reduce_mean(self.log_probability)) 251 | return self.disc_min_op 252 | 253 | # only useful when training use placeholders, otherwise use ops directly 254 | def train(self, 255 | timesteps, 256 | skills, 257 | next_timesteps=None, 258 | batch_size=512, 259 | num_steps=1, 260 | increase_probs=True): 261 | if not self._use_placeholders: 262 | return 263 | 264 | if increase_probs: 265 | run_op = self.disc_max_op 266 | else: 267 | run_op = self.disc_min_op 268 | 269 | for _ in range(num_steps): 270 | self._session.run( 271 | run_op, 272 | feed_dict=self._get_dict( 273 | timesteps, 274 | skills, 275 | input_next_steps=next_timesteps, 276 | batch_size=batch_size, 277 | batch_norm=True)) 278 | 279 | def get_log_probs(self, timesteps, skills, next_timesteps=None): 280 | if not self._use_placeholders: 281 | return 282 | 283 | return self._session.run([self.log_probability, self.prior_probability], 284 | feed_dict=self._get_dict( 285 | timesteps, 286 | skills, 287 | input_next_steps=next_timesteps, 288 | batch_norm=False)) 289 | 290 | def create_saver(self, save_prefix): 291 | if self._saver is not None: 292 | return self._saver 293 | else: 294 | with self._graph.as_default(): 295 | self._variable_list = {} 296 | for var in tf.get_collection( 297 | tf.GraphKeys.GLOBAL_VARIABLES, scope=self._scope_name): 298 | self._variable_list[var.name] = var 299 | self._saver = tf.train.Saver(self._variable_list, save_relative_paths=True) 300 | self._save_prefix = save_prefix 301 | 302 | def save_variables(self, global_step): 303 | if not tf.gfile.Exists(self._save_prefix): 304 | tf.gfile.MakeDirs(self._save_prefix) 305 | 306 | self._saver.save( 307 | self._session, 308 | os.path.join(self._save_prefix, 'ckpt'), 309 | global_step=global_step) 310 | 311 | def restore_variables(self): 312 | self._saver.restore(self._session, 313 | tf.train.latest_checkpoint(self._save_prefix)) 314 | -------------------------------------------------------------------------------- /unsupervised_skill_learning/skill_dynamics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file 
except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Dynamics Prediction and Training.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import numpy as np 23 | import tensorflow as tf 24 | import tensorflow_probability as tfp 25 | 26 | 27 | # TODO(architsh): Implement the dynamics with last K step input 28 | class SkillDynamics: 29 | 30 | def __init__( 31 | self, 32 | observation_size, 33 | action_size, 34 | restrict_observation=0, 35 | normalize_observations=False, 36 | # network properties 37 | fc_layer_params=(256, 256), 38 | network_type='default', 39 | num_components=1, 40 | fix_variance=False, 41 | reweigh_batches=False, 42 | graph=None, 43 | scope_name='skill_dynamics'): 44 | 45 | self._observation_size = observation_size 46 | self._action_size = action_size 47 | self._normalize_observations = normalize_observations 48 | self._restrict_observation = restrict_observation 49 | self._reweigh_batches = reweigh_batches 50 | 51 | # tensorflow requirements 52 | if graph is not None: 53 | self._graph = graph 54 | else: 55 | self._graph = tf.compat.v1.get_default_graph() 56 | self._scope_name = scope_name 57 | 58 | # dynamics network properties 59 | self._fc_layer_params = fc_layer_params 60 | self._network_type = network_type 61 | self._num_components = num_components 62 | self._fix_variance = fix_variance 63 | if not self._fix_variance: 64 | self._std_lower_clip = 0.3 65 | self._std_upper_clip = 10.0 66 | 67 | self._use_placeholders = False 68 | self.log_probability = None 69 | self.dyn_max_op = None 70 | self.dyn_min_op = None 71 | self._session = None 72 | self._use_modal_mean = False 73 | 74 | # saving/restoring variables 75 | self._saver = None 76 | 77 | def _get_distribution(self, out): 78 | if self._num_components > 1: 79 | self.logits = tf.compat.v1.layers.dense( 80 | out, self._num_components, name='logits', reuse=tf.compat.v1.AUTO_REUSE) 81 | means, scale_diags = [], [] 82 | for component_id in range(self._num_components): 83 | means.append( 84 | tf.compat.v1.layers.dense( 85 | out, 86 | self._observation_size, 87 | name='mean_' + str(component_id), 88 | reuse=tf.compat.v1.AUTO_REUSE)) 89 | if not self._fix_variance: 90 | scale_diags.append( 91 | tf.clip_by_value( 92 | tf.compat.v1.layers.dense( 93 | out, 94 | self._observation_size, 95 | activation=tf.nn.softplus, 96 | name='stddev_' + str(component_id), 97 | reuse=tf.compat.v1.AUTO_REUSE), self._std_lower_clip, 98 | self._std_upper_clip)) 99 | else: 100 | scale_diags.append( 101 | tf.fill([tf.shape(out)[0], self._observation_size], 1.0)) 102 | 103 | self.means = tf.stack(means, axis=1) 104 | self.scale_diags = tf.stack(scale_diags, axis=1) 105 | return tfp.distributions.MixtureSameFamily( 106 | mixture_distribution=tfp.distributions.Categorical( 107 | logits=self.logits), 108 | components_distribution=tfp.distributions.MultivariateNormalDiag( 109 | loc=self.means, scale_diag=self.scale_diags)) 110 | 111 | else: 112 | mean = tf.compat.v1.layers.dense( 113 | out, self._observation_size, 
name='mean', reuse=tf.compat.v1.AUTO_REUSE) 114 | if not self._fix_variance: 115 | stddev = tf.clip_by_value( 116 | tf.compat.v1.layers.dense( 117 | out, 118 | self._observation_size, 119 | activation=tf.nn.softplus, 120 | name='stddev', 121 | reuse=tf.compat.v1.AUTO_REUSE), self._std_lower_clip, 122 | self._std_upper_clip) 123 | else: 124 | stddev = tf.fill([tf.shape(out)[0], self._observation_size], 1.0) 125 | return tfp.distributions.MultivariateNormalDiag( 126 | loc=mean, scale_diag=stddev) 127 | 128 | # dynamics graph with separate pipeline for skills and timesteps 129 | def _graph_with_separate_skill_pipe(self, timesteps, actions): 130 | skill_out = actions 131 | with tf.compat.v1.variable_scope('action_pipe'): 132 | for idx, layer_size in enumerate((self._fc_layer_params[0] // 2,)): 133 | skill_out = tf.compat.v1.layers.dense( 134 | skill_out, 135 | layer_size, 136 | activation=tf.nn.relu, 137 | name='hid_' + str(idx), 138 | reuse=tf.compat.v1.AUTO_REUSE) 139 | 140 | ts_out = timesteps 141 | with tf.compat.v1.variable_scope('ts_pipe'): 142 | for idx, layer_size in enumerate((self._fc_layer_params[0] // 2,)): 143 | ts_out = tf.compat.v1.layers.dense( 144 | ts_out, 145 | layer_size, 146 | activation=tf.nn.relu, 147 | name='hid_' + str(idx), 148 | reuse=tf.compat.v1.AUTO_REUSE) 149 | 150 | # out = tf.compat.v1.layers.flatten(tf.einsum('ai,aj->aij', ts_out, skill_out)) 151 | out = tf.concat([ts_out, skill_out], axis=1) 152 | with tf.compat.v1.variable_scope('joint'): 153 | for idx, layer_size in enumerate(self._fc_layer_param[1:]): 154 | out = tf.compat.v1.layers.dense( 155 | out, 156 | layer_size, 157 | activation=tf.nn.relu, 158 | name='hid_' + str(idx), 159 | reuse=tf.compat.v1.AUTO_REUSE) 160 | 161 | return self._get_distribution(out) 162 | 163 | # simple dynamics graph 164 | def _default_graph(self, timesteps, actions): 165 | out = tf.concat([timesteps, actions], axis=1) 166 | for idx, layer_size in enumerate(self._fc_layer_params): 167 | out = tf.compat.v1.layers.dense( 168 | out, 169 | layer_size, 170 | activation=tf.nn.relu, 171 | name='hid_' + str(idx), 172 | reuse=tf.compat.v1.AUTO_REUSE) 173 | 174 | return self._get_distribution(out) 175 | 176 | def _get_dict(self, 177 | input_data, 178 | input_actions, 179 | target_data, 180 | batch_size=-1, 181 | batch_weights=None, 182 | batch_norm=False, 183 | noise_targets=False, 184 | noise_std=0.5): 185 | if batch_size > 0: 186 | shuffled_batch = np.random.permutation(len(input_data))[:batch_size] 187 | else: 188 | shuffled_batch = np.arange(len(input_data)) 189 | 190 | # if we are noising the input, it is better to create a new copy of the numpy arrays 191 | batched_input = input_data[shuffled_batch, :] 192 | batched_skills = input_actions[shuffled_batch, :] 193 | batched_targets = target_data[shuffled_batch, :] 194 | 195 | if self._reweigh_batches and batch_weights is not None: 196 | example_weights = batch_weights[shuffled_batch] 197 | 198 | if noise_targets: 199 | batched_targets += np.random.randn(*batched_targets.shape) * noise_std 200 | 201 | return_dict = { 202 | self.timesteps_pl: batched_input, 203 | self.actions_pl: batched_skills, 204 | self.next_timesteps_pl: batched_targets 205 | } 206 | if self._normalize_observations: 207 | return_dict[self.is_training_pl] = batch_norm 208 | if self._reweigh_batches and batch_weights is not None: 209 | return_dict[self.batch_weights] = example_weights 210 | 211 | return return_dict 212 | 213 | def _get_run_dict(self, input_data, input_actions): 214 | return_dict = { 215 | 
self.timesteps_pl: input_data, 216 | self.actions_pl: input_actions 217 | } 218 | if self._normalize_observations: 219 | return_dict[self.is_training_pl] = False 220 | 221 | return return_dict 222 | 223 | def make_placeholders(self): 224 | self._use_placeholders = True 225 | with self._graph.as_default(), tf.compat.v1.variable_scope(self._scope_name): 226 | self.timesteps_pl = tf.compat.v1.placeholder( 227 | tf.float32, shape=(None, self._observation_size), name='timesteps_pl') 228 | self.actions_pl = tf.compat.v1.placeholder( 229 | tf.float32, shape=(None, self._action_size), name='actions_pl') 230 | self.next_timesteps_pl = tf.compat.v1.placeholder( 231 | tf.float32, 232 | shape=(None, self._observation_size), 233 | name='next_timesteps_pl') 234 | if self._normalize_observations: 235 | self.is_training_pl = tf.compat.v1.placeholder(tf.bool, name='batch_norm_pl') 236 | if self._reweigh_batches: 237 | self.batch_weights = tf.compat.v1.placeholder( 238 | tf.float32, shape=(None,), name='importance_sampled_weights') 239 | 240 | def set_session(self, session=None, initialize_or_restore_variables=False): 241 | if session is None: 242 | self._session = tf.Session(graph=self._graph) 243 | else: 244 | self._session = session 245 | 246 | # only initialize uninitialized variables 247 | if initialize_or_restore_variables: 248 | if tf.io.gfile.exists(self._save_prefix): 249 | self.restore_variables() 250 | with self._graph.as_default(): 251 | var_list = tf.compat.v1.global_variables( 252 | ) + tf.compat.v1.local_variables() 253 | is_initialized = self._session.run( 254 | [tf.compat.v1.is_variable_initialized(v) for v in var_list]) 255 | uninitialized_vars = [] 256 | for flag, v in zip(is_initialized, var_list): 257 | if not flag: 258 | uninitialized_vars.append(v) 259 | 260 | if uninitialized_vars: 261 | self._session.run( 262 | tf.compat.v1.variables_initializer(uninitialized_vars)) 263 | 264 | def build_graph(self, 265 | timesteps=None, 266 | actions=None, 267 | next_timesteps=None, 268 | is_training=None): 269 | with self._graph.as_default(), tf.compat.v1.variable_scope( 270 | self._scope_name, reuse=tf.compat.v1.AUTO_REUSE): 271 | if self._use_placeholders: 272 | timesteps = self.timesteps_pl 273 | actions = self.actions_pl 274 | next_timesteps = self.next_timesteps_pl 275 | if self._normalize_observations: 276 | is_training = self.is_training_pl 277 | 278 | # predict deltas instead of observations 279 | next_timesteps -= timesteps 280 | 281 | if self._restrict_observation > 0: 282 | timesteps = timesteps[:, self._restrict_observation:] 283 | 284 | if self._normalize_observations: 285 | timesteps = tf.compat.v1.layers.batch_normalization( 286 | timesteps, 287 | training=is_training, 288 | name='input_normalization', 289 | reuse=tf.compat.v1.AUTO_REUSE) 290 | self.output_norm_layer = tf.compat.v1.layers.BatchNormalization( 291 | scale=False, center=False, name='output_normalization') 292 | next_timesteps = self.output_norm_layer( 293 | next_timesteps, training=is_training) 294 | 295 | if self._network_type == 'default': 296 | self.base_distribution = self._default_graph(timesteps, actions) 297 | elif self._network_type == 'separate': 298 | self.base_distribution = self._graph_with_separate_skill_pipe( 299 | timesteps, actions) 300 | 301 | # if building multiple times, be careful about which log_prob you are optimizing 302 | self.log_probability = self.base_distribution.log_prob(next_timesteps) 303 | self.mean = self.base_distribution.mean() 304 | 305 | return self.log_probability 306 | 307 | def 
increase_prob_op(self, learning_rate=3e-4, weights=None): 308 | with self._graph.as_default(): 309 | update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS) 310 | with tf.control_dependencies(update_ops): 311 | if self._reweigh_batches: 312 | self.dyn_max_op = tf.compat.v1.train.AdamOptimizer( 313 | learning_rate=learning_rate, 314 | name='adam_max').minimize(-tf.reduce_mean(self.log_probability * 315 | self.batch_weights)) 316 | elif weights is not None: 317 | self.dyn_max_op = tf.compat.v1.train.AdamOptimizer( 318 | learning_rate=learning_rate, 319 | name='adam_max').minimize(-tf.reduce_mean(self.log_probability * 320 | weights)) 321 | else: 322 | self.dyn_max_op = tf.compat.v1.train.AdamOptimizer( 323 | learning_rate=learning_rate, 324 | name='adam_max').minimize(-tf.reduce_mean(self.log_probability)) 325 | 326 | return self.dyn_max_op 327 | 328 | def decrease_prob_op(self, learning_rate=3e-4, weights=None): 329 | with self._graph.as_default(): 330 | update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS) 331 | with tf.control_dependencies(update_ops): 332 | if self._reweigh_batches: 333 | self.dyn_min_op = tf.compat.v1.train.AdamOptimizer( 334 | learning_rate=learning_rate, name='adam_min').minimize( 335 | tf.reduce_mean(self.log_probability * self.batch_weights)) 336 | elif weights is not None: 337 | self.dyn_min_op = tf.compat.v1.train.AdamOptimizer( 338 | learning_rate=learning_rate, name='adam_min').minimize( 339 | tf.reduce_mean(self.log_probability * weights)) 340 | else: 341 | self.dyn_min_op = tf.compat.v1.train.AdamOptimizer( 342 | learning_rate=learning_rate, 343 | name='adam_min').minimize(tf.reduce_mean(self.log_probability)) 344 | return self.dyn_min_op 345 | 346 | def create_saver(self, save_prefix): 347 | if self._saver is not None: 348 | return self._saver 349 | else: 350 | with self._graph.as_default(): 351 | self._variable_list = {} 352 | for var in tf.compat.v1.get_collection( 353 | tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, scope=self._scope_name): 354 | self._variable_list[var.name] = var 355 | self._saver = tf.compat.v1.train.Saver( 356 | self._variable_list, save_relative_paths=True) 357 | self._save_prefix = save_prefix 358 | 359 | def save_variables(self, global_step): 360 | if not tf.io.gfile.exists(self._save_prefix): 361 | tf.io.gfile.makedirs(self._save_prefix) 362 | 363 | self._saver.save( 364 | self._session, 365 | os.path.join(self._save_prefix, 'ckpt'), 366 | global_step=global_step) 367 | 368 | def restore_variables(self): 369 | self._saver.restore(self._session, 370 | tf.compat.v1.train.latest_checkpoint(self._save_prefix)) 371 | 372 | # all functions here-on require placeholders---------------------------------- 373 | def train(self, 374 | timesteps, 375 | actions, 376 | next_timesteps, 377 | batch_weights=None, 378 | batch_size=512, 379 | num_steps=1, 380 | increase_probs=True): 381 | if not self._use_placeholders: 382 | return 383 | 384 | if increase_probs: 385 | run_op = self.dyn_max_op 386 | else: 387 | run_op = self.dyn_min_op 388 | 389 | for _ in range(num_steps): 390 | self._session.run( 391 | run_op, 392 | feed_dict=self._get_dict( 393 | timesteps, 394 | actions, 395 | next_timesteps, 396 | batch_weights=batch_weights, 397 | batch_size=batch_size, 398 | batch_norm=True)) 399 | 400 | def get_log_prob(self, timesteps, actions, next_timesteps): 401 | if not self._use_placeholders: 402 | return 403 | 404 | return self._session.run( 405 | self.log_probability, 406 | feed_dict=self._get_dict( 407 | timesteps, 
actions, next_timesteps, batch_norm=False)) 408 | 409 | def predict_state(self, timesteps, actions): 410 | if not self._use_placeholders: 411 | return 412 | 413 | if self._use_modal_mean: 414 | all_means, modal_mean_indices = self._session.run( 415 | [self.means, tf.argmax(self.logits, axis=1)], 416 | feed_dict=self._get_run_dict(timesteps, actions)) 417 | pred_state = all_means[[ 418 | np.arange(all_means.shape[0]), modal_mean_indices 419 | ]] 420 | else: 421 | pred_state = self._session.run( 422 | self.mean, feed_dict=self._get_run_dict(timesteps, actions)) 423 | 424 | if self._normalize_observations: 425 | with self._session.as_default(), self._graph.as_default(): 426 | mean_correction, variance_correction = self.output_norm_layer.get_weights( 427 | ) 428 | 429 | pred_state = pred_state * np.sqrt(variance_correction + 430 | 1e-3) + mean_correction 431 | 432 | pred_state += timesteps 433 | return pred_state 434 | --------------------------------------------------------------------------------
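The central quantity produced by this code is the intrinsic reward computed in DADSAgent.compute_dads_reward (dads_agent.py above): a transition (s, z, s') is rewarded according to how much better the learned skill dynamics q(s' | s, z) explains it under the executed skill z than under skills resampled from the prior. Below is a minimal NumPy sketch of that computation for the 'cont_uniform' prior branch; `log_prob_fn` stands in for SkillDynamics.get_log_prob, and the `prior_samples` default here is illustrative rather than a value taken from the configs.

# Minimal NumPy sketch of the DADS intrinsic reward ('cont_uniform' prior
# branch of DADSAgent.compute_dads_reward). `log_prob_fn` is a stand-in for
# SkillDynamics.get_log_prob: given batches (obs, skill, next_obs) it returns
# per-sample log q(next_obs | obs, skill).
import numpy as np


def dads_reward_sketch(log_prob_fn, input_obs, cur_skill, target_obs,
                       latent_size, prior_samples=100):
  """r(s, z, s') = log q(s'|s,z) - log[(q(s'|s,z) + sum_i q(s'|s,z_i)) / (L+1)]."""
  batch = input_obs.shape[0]
  # L alternative skills per transition; corresponds to self._prior_samples.
  num_reps = prior_samples if prior_samples > 0 else latent_size - 1

  # Pair every transition with `num_reps` skills drawn from the prior
  # (uniform on [-1, 1]^latent_size for 'cont_uniform').
  input_alt = np.concatenate([input_obs] * num_reps, axis=0)
  target_alt = np.concatenate([target_obs] * num_reps, axis=0)
  alt_skill = np.random.uniform(
      low=-1.0, high=1.0, size=(num_reps * batch, latent_size))

  # Numerator: log q(s' | s, z) under the skill that was actually executed.
  logp = np.asarray(log_prob_fn(input_obs, cur_skill, target_obs))      # (batch,)

  # Denominator samples: log q(s' | s, z_i) for the prior-sampled skills.
  logp_alt = np.asarray(log_prob_fn(input_alt, alt_skill, target_alt))
  logp_alt = logp_alt.reshape(num_reps, batch)

  # r = log(L + 1) - log(1 + sum_i exp(log q_i - log q)); the clip keeps the
  # exponential from overflowing when an alternative skill fits much better.
  diff = np.clip(logp_alt - logp.reshape(1, -1), -50, 50)
  return np.log(num_reps + 1) - np.log(1 + np.exp(diff).sum(axis=0))

In the agent itself the alternative-skill batch has num_reps * batch rows, so compute_dads_reward additionally splits it into chunks of split_group = 20 * 4000 rows before calling get_log_prob; the chunking only bounds GPU memory and does not change the result of the formula above. For the 'discrete_uniform' prior with prior_samples == 0, the agent instead enumerates all latent_size - 1 alternative one-hot skills by rolling cur_skill, which makes the denominator an exact marginalization rather than a sampled estimate.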