├── .gitignore ├── LICENSE ├── README.md ├── conf ├── agent │ └── ddpg_multimodal_skill_torch.yaml ├── benchmark │ └── dmc.yaml ├── config.yaml ├── hydra │ └── job_logging │ │ └── custom.yaml └── intrinsic │ └── multimodal_cic.yaml ├── core ├── __init__.py ├── agents │ ├── __init__.py │ ├── agent_base.py │ ├── ddpg.py │ ├── ddpg_multimodal_skill_torch.py │ └── ddpg_skill.py ├── calculations │ ├── __init__.py │ ├── augmentations.py │ ├── distributions.py │ ├── layers.py │ ├── losses.py │ ├── misc.py │ ├── params_utils.py │ └── skill_utils.py ├── custom_dmc_tasks │ ├── __init__.py │ ├── cheetah.py │ ├── cheetah.xml │ ├── hopper.py │ ├── hopper.xml │ ├── jaco.py │ ├── quadruped.py │ ├── quadruped.xml │ ├── walker.py │ └── walker.xml ├── data │ ├── __init__.py │ ├── replay_buffer.py │ └── replay_buffer_torch.py ├── envs │ ├── __init__.py │ ├── dmc.py │ ├── dmc_benchmark.py │ └── wrappers.py ├── exp_utils │ ├── __init__.py │ ├── checkpointing.py │ ├── loggers.py │ └── video.py └── intrinsic │ ├── __init__.py │ ├── cic.py │ ├── intrinsic_reward_base.py │ └── multimodal_cic.py ├── figures ├── MOSS_robot.png ├── fraction_rliable.png └── rliable.png ├── finetune_multimodal.py ├── helpers.py └── pretrain_multimodal.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # General 132 | .DS_Store 133 | .AppleDouble 134 | .LSOverride 135 | 136 | # Icon must end with two \r 137 | Icon 138 | 139 | # Thumbnails 140 | ._* 141 | 142 | # Files that might appear in the root of a volume 143 | .DocumentRevisions-V100 144 | .fseventsd 145 | .Spotlight-V100 146 | .TemporaryItems 147 | .Trashes 148 | .VolumeIcon.icns 149 | .com.apple.timemachine.donotpresent 150 | 151 | # Directories potentially created on remote AFP share 152 | .AppleDB 153 | .AppleDesktop 154 | Network Trash Folder 155 | Temporary Items 156 | .apdisk 157 | 158 | # Created by .ignore support plugin (hsz.mobi) 159 | ### Python template 160 | # Byte-compiled / optimized / DLL files 161 | __pycache__/ 162 | *.py[cod] 163 | *$py.class 164 | 165 | # C extensions 166 | *.so 167 | 168 | # Distribution / packaging 169 | .Python 170 | build/ 171 | develop-eggs/ 172 | dist/ 173 | downloads/ 174 | eggs/ 175 | .eggs/ 176 | lib/ 177 | lib64/ 178 | parts/ 179 | sdist/ 180 | var/ 181 | *.egg-info/ 182 | .installed.cfg 183 | *.egg 184 | 185 | # PyInstaller 186 | # Usually these files are written by a python script from a template 187 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
188 | *.manifest 189 | *.spec 190 | 191 | # Installer logs 192 | pip-log.txt 193 | pip-delete-this-directory.txt 194 | 195 | # Unit test / coverage reports 196 | htmlcov/ 197 | .tox/ 198 | .coverage 199 | .coverage.* 200 | .cache 201 | nosetests.xml 202 | coverage.xml 203 | *,cover 204 | .hypothesis/ 205 | 206 | # Translations 207 | *.mo 208 | *.pot 209 | 210 | # Django stuff: 211 | *.log 212 | local_settings.py 213 | 214 | # Flask stuff: 215 | instance/ 216 | .webassets-cache 217 | 218 | # Scrapy stuff: 219 | .scrapy 220 | 221 | # Sphinx documentation 222 | docs/_build/ 223 | 224 | # PyBuilder 225 | target/ 226 | 227 | # IPython Notebook 228 | .ipynb_checkpoints 229 | 230 | # pyenv 231 | .python-version 232 | 233 | # celery beat schedule file 234 | celerybeat-schedule 235 | 236 | # dotenv 237 | #.env 238 | 239 | # virtualenv 240 | #venv/ 241 | #ENV/ 242 | 243 | # Spyder project settings 244 | .spyderproject 245 | 246 | # Rope project settings 247 | .ropeproject 248 | ### VirtualEnv template 249 | # Virtualenv 250 | # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/ 251 | .Python 252 | [Bb]in 253 | [Ii]nclude 254 | [Ll]ib 255 | [Ll]ib64 256 | [Ll]ocal 257 | [Ss]cripts 258 | pyvenv.cfg 259 | .venv 260 | pip-selfcheck.json 261 | ### JetBrains template 262 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 263 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 264 | 265 | # User-specific stuff: 266 | .idea/workspace.xml 267 | .idea/tasks.xml 268 | .idea/dictionaries 269 | .idea/vcs.xml 270 | .idea/jsLibraryMappings.xml 271 | 272 | # Sensitive or high-churn files: 273 | .idea/dataSources.ids 274 | .idea/dataSources.xml 275 | .idea/dataSources.local.xml 276 | .idea/sqlDataSources.xml 277 | .idea/dynamic.xml 278 | .idea/uiDesigner.xml 279 | 280 | # Gradle: 281 | .idea/gradle.xml 282 | .idea/libraries 283 | 284 | # Mongo Explorer plugin: 285 | .idea/mongoSettings.xml 286 | 287 | .idea/ 288 | 289 | ## File-based project format: 290 | *.iws 291 | 292 | ## Plugin-specific files: 293 | 294 | # IntelliJ 295 | /out/ 296 | 297 | # mpeltonen/sbt-idea plugin 298 | .idea_modules/ 299 | 300 | # JIRA plugin 301 | atlassian-ide-plugin.xml 302 | 303 | # Crashlytics plugin (for Android Studio and IntelliJ) 304 | com_crashlytics_export_strings.xml 305 | crashlytics.properties 306 | crashlytics-build.properties 307 | fabric.properties 308 | 309 | outputs/ 310 | commands.md 311 | testing.ipynb 312 | testing_wrapper.py 313 | is_pretrain*/ 314 | .vscode* 315 | script/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mixture Of SurpriseS (MOSS) 2 | This repo contains the official [**Jax/Haiku**](https://github.com/google/jax) code for A Mixture Of Surprises for Unsupervised Reinforcement Learning. [[arxiv]](https://arxiv.org/abs/2210.06702) 3 | 4 | ## Introduction 5 | 6 | ![moss_robot](figures/MOSS_robot.png) 7 | 8 | We investigated a method that uses mixture of skills to alleviate the assumptions needed on the environment for unsupervised reinforcement learning. 9 | 10 | ## Results 11 | 12 | [RLiable](https://github.com/google-research/rliable) Aggregated Results on the [Unsupervised Reinforcement Learning Benchmark](https://github.com/rll-research/url_benchmark) 13 | 14 | ![rliable](figures/rliable.png) 15 | 16 | ## Numerical Results 17 | 18 | | Domain | | Walker | | | | Quadruped | | | | Jaco | | 19 | |----------------------------------------------------|-----------------|-----------------|----------------|----------------|-----------------|----------------|-----------------|-----------------|----------------|-----------------|----------------| 20 | | Method\Task | Flip | Run | Stand | Walk | Jump | Run | Stand | Walk | Bottom Left | Bottom Right | Top Left | 21 | | ICM | 381±10 | 180±15 | 868±30 | 568±38 | 337±18 | 221±14 | 452±15 | 234±18 | 112±7 | 94±5 | 90±6 | 22 | | Disagreement | 313±8 | 166±9 | 658±33 | 453±37 | 512±14 | 395±12 | 686±30 | 358±25 | 120±7 | 132±5 | 111±10 | 23 | | RND | 412±18 | 267±18 | 842±19 | 694±26 | **681±11** | 455±7 | 875±25 | 581±42 | 106±6 | 111±6 | 83±7 | 24 | | ICM APT | 596±24 | 491±18 | 949±3 | 850±22 | 508±44 | 390±24 | 676±44 | 464±52 | 114±5 | 120±3 | 116±4 | 25 | | IND APT | 576±20 | 467±21 | 947±4 | 888±19 | 542±34 | 328±18 | 605±32 | 367±24 | 126±5 | 131±4 | 109±6 | 26 | | Proto | 378±4 | 225±16 | 828±24 | 610±40 | 426±32 | 310±22 | 702±59 | 348±55 | 130±12 | 131±11 | 134±12 | 27 | | AS-Bob | 475±16 | 247±23 | 917±36 | 675±21 | 449±24 | 285±23 | 594±37 | 353±39 | 116±21 | **166±12** | 143±12 | 28 | | AS-Alice | 491±20 | 211±9 | 868±47 | 655±36 | 415±20 | 296±18 | 590±41 | 337±17 | 109±20 | 141±19 | 140±17 | 29 | | SMM | 428±8 | 345±31 | 924±9 | 731±43 | 271±35 | 222±23 | 388±51 | 167±20 | 52±5 | 55±2 | 53±2 | 30 | | DIAYN | 306±12 | 146±7 | 631±46 | 394±22 | 491±38 | 325±21 | 662±38 | 273±19 | 35±5 | 35±6 | 23±3 | 31 | | APS | 355±18 | 166±15 | 667±56 | 500±40 | 283±22 | 206±16 | 379±31 | 192±17 | 61±6 | 79±12 | 51±5 | 32 | | CIC | 715±40 | **535±25** | **968±2** | 914±12 | 541±31 | 376±19 | 717±46 | 460±36 | 147±8 | 150±6 | 145±9 | 33 | | MOSS (Ours) | **729±40** | 531±20 | 962±3 | **942±5** | 674±11 | **485±6** | **911±11** | **635±36** | **151±5** | 150±5 | **150±5** | 34 | 35 | 36 | ## Get Started 37 | ### Pretraining 38 | ``` 39 | # example for pretraining on the jaco domain 40 | python pretrain_multimodal.py \ 41 | reward_free=true \ 42 | agent=ddpg_multimodal_skill_torch \ 43 | agent.skill_mode=sign \ 44 | agent.partitions=1.5 \ 45 | 
agent.skills_cfg.update_skill_every=50 \ 46 | intrinsic=multimodal_cic \ 47 | intrinsic.temperature=0.5 \ 48 | intrinsic.network_cfg.skill_dim=64\ 49 | intrinsic.knn_entropy_config.minus_mean=true \ 50 | benchmark=dmc \ 51 | benchmark.task=jaco_reach_top_left \ 52 | seed=0 \ 53 | wandb_note=moss_pretrain_base_sign 54 | ``` 55 | ### Finetuning 56 | ``` 57 | # example for finetuning on the jaco domain 58 | python finetune_multimodal.py \ 59 | reward_free=false \ 60 | agent=ddpg_multimodal_skill_torch \ 61 | intrinsic.network_cfg.skill_dim=64 \ 62 | agent.search_mode=constant \ 63 | benchmark=dmc \ 64 | benchmark.task=jaco_reach_top_left \ 65 | seed=0 \ 66 | checkpoint=../../../../is_pretrain_True/jaco_reach_top_left/0/moss_pretrain_base_sign/checkpoints/2000000.pth \ 67 | num_finetune_frames=100000 \ 68 | wandb_note=moss_finetune_base_sign 69 | ``` 70 | 71 | ## Contact 72 | 73 | If you have any question, please feel free to contact the authors. Andrew Zhao: [zqc21@mails.tsinghua.edu.cn](mailto:zqc21@mails.tsinghua.edu.cn). 74 | 75 | ## Acknowledgment 76 | 77 | Our code is based on [Contrastive Intrinsic Control](https://github.com/rll-research/cic) and [URL Benchmark](https://github.com/rll-research/url_benchmark). 78 | 79 | ## Citation 80 | 81 | If you find our work is useful in your research, please consider citing: 82 | 83 | ```bibtex 84 | @article{zhao2022mixture, 85 | title={A Mixture of Surprises for Unsupervised Reinforcement Learning}, 86 | author={Zhao, Andrew and Lin, Matthieu Gaetan and Li, Yangguang and Liu, Yong-Jin and Huang, Gao}, 87 | journal={arXiv preprint arXiv:2210.06702}, 88 | year={2022} 89 | } 90 | ``` -------------------------------------------------------------------------------- /conf/agent/ddpg_multimodal_skill_torch.yaml: -------------------------------------------------------------------------------- 1 | _target_: core.agents.ddpg_multimodal_skill_torch.DDPGAgentMultiModalSkill 2 | 3 | action_type: continuous # [continuous, discrete] 4 | to_jit: true 5 | stddev_schedule: 0.2 6 | stddev_clip: 0.3 7 | critic_target_tau: 0.01 8 | l2_weight: 0.0 9 | #lr_encoder: 1e-4 10 | lr_actor: 1e-4 11 | lr_critic: 1e-4 12 | network_cfg: 13 | obs_type: ${benchmark.obs_type} 14 | action_shape: ??? 
15 | feature_dim: 50 16 | hidden_dim: 1024 17 | ln_config: 18 | axis: -1 19 | create_scale: True 20 | create_offset: True 21 | 22 | # replay buffer 23 | replay_buffer_cfg: 24 | nstep: 3 25 | replay_buffer_size: 1000000 26 | batch_size: 1024 #2048 # 27 | discount: 0.99 28 | num_workers: 4 29 | skill_dim: ${intrinsic.network_cfg.skill_dim} 30 | 31 | search_mode: grid_search 32 | skill_mode: half 33 | partitions: 2 # 4 34 | reward_free: ${reward_free} 35 | # additional for skill based DDPG 36 | skills_cfg: 37 | update_skill_every: 50 38 | skill_dim: ${intrinsic.network_cfg.skill_dim} 39 | -------------------------------------------------------------------------------- /conf/benchmark/dmc.yaml: -------------------------------------------------------------------------------- 1 | task: quadruped_walk 2 | obs_type: states # [states, pixels] 3 | frame_stack: 3 4 | action_repeat: 1 5 | seed: ${seed} 6 | reward_scale: 1.0 7 | -------------------------------------------------------------------------------- /conf/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - agent: ddpg_multimodal_skill_torch 4 | - intrinsic: multimodal_cic 5 | - benchmark: dmc 6 | - override hydra/job_logging: custom 7 | 8 | 9 | # mode 10 | reward_free: false 11 | 12 | # train settings 13 | num_pretrain_frames: 2000000 14 | num_finetune_frames: 4001 #100000 15 | num_seed_frames: 4000 16 | update_every_steps: 2 17 | 18 | # eval 19 | eval_every_frames: 10000 20 | num_eval_episodes: 10 21 | 22 | # wandb 23 | log_params_to_wandb_every: 100000 24 | run_id: '1' # used for resuming 25 | resume: allow 26 | use_wandb: false 27 | wandb_note: 'entropy_calc' 28 | wandb_project_name: unsupervisedRL 29 | 30 | # misc 31 | seed: 0 32 | save_video: true 33 | save_train_video: false 34 | checkpoint: null 35 | save_dir: checkpoints 36 | snapshots: [100000, 500000, 1000000, 2000000] 37 | 38 | 39 | hydra: 40 | run: 41 | dir: is_pretrain_${reward_free}/${benchmark.task}/${seed}/${wandb_note} 42 | -------------------------------------------------------------------------------- /conf/hydra/job_logging/custom.yaml: -------------------------------------------------------------------------------- 1 | version: 1 2 | formatters: 3 | simple: 4 | format: "[%(asctime)s](%(filename)s %(lineno)d): %(message)s" 5 | colored: 6 | (): colorlog.ColoredFormatter 7 | # format: "[%(cyan)s%(asctime)s%(reset)s][%(purple)sHYDRA%(reset)s] %(message)s" 8 | format: '%(green)s[%(asctime)s](%(filename)s %(lineno)d): %(white)s%(message)s' 9 | handlers: 10 | console: # console handler 11 | class: logging.StreamHandler 12 | level: INFO 13 | formatter: colored 14 | stream: ext://sys.stdout 15 | file: # file handler 16 | class: logging.FileHandler 17 | formatter: colored 18 | level: INFO 19 | filename: output.log 20 | loggers: # parents 21 | finetune: 22 | level: INFO 23 | handlers: [console, file] 24 | propagate: no 25 | root: # default one 26 | level: INFO #DEBUG 27 | handlers: [console, file] 28 | 29 | disable_existing_loggers: false -------------------------------------------------------------------------------- /conf/intrinsic/multimodal_cic.yaml: -------------------------------------------------------------------------------- 1 | name: multimodal_cic 2 | _target_: core.intrinsic.MultimodalCICReward 3 | lr: 1e-4 4 | to_jit: true 5 | temperature: 0.5 6 | network_cfg: 7 | hidden_dim: 1024 8 | skill_dim: 32 9 | project_skill: true 10 | knn_entropy_config: 11 | knn_clip: 0.0005 12 | knn_k: 16 13 | knn_avg: 
true 14 | knn_rm: true 15 | minus_mean: true -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeapLabTHU/MOSS/534820b5cb5713389f416f60c867de505e791166/core/__init__.py -------------------------------------------------------------------------------- /core/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, NamedTuple 2 | 3 | import hydra 4 | 5 | from .agent_base import Agent 6 | from .ddpg import DDPGAgent, DDPGTrainState 7 | from .ddpg_skill import DDPGAgentSkill 8 | from .ddpg_multimodal_skill_torch import DDPGAgentMultiModalSkill 9 | 10 | 11 | def make_agent(obs_type, action_shape, agent_cfg): 12 | if agent_cfg.action_type == 'continuous': 13 | return make_continuous_agent(action_shape, agent_cfg) 14 | elif agent_cfg.action_type == 'discrete': 15 | return make_discrete_agent(obs_type, action_shape, agent_cfg) 16 | else: 17 | raise NotImplementedError 18 | 19 | def make_continuous_agent(action_shape, agent_cfg): 20 | agent_cfg.network_cfg.action_shape = action_shape 21 | return hydra.utils.instantiate(agent_cfg) 22 | 23 | def make_discrete_agent(obs_type: str, action_shape: Tuple[int], cfg): 24 | cfg.network_cfg.obs_type = obs_type 25 | cfg.network_cfg.action_shape = action_shape 26 | return hydra.utils.instantiate(cfg) 27 | -------------------------------------------------------------------------------- /core/agents/agent_base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | class Agent(ABC): 4 | 5 | @abstractmethod 6 | def init_params(self, 7 | init_key, 8 | dummy_obs, 9 | summarize = True 10 | ): 11 | raise NotImplementedError 12 | 13 | @abstractmethod 14 | def select_action(self, *args, **kwargs): 15 | """act function""" 16 | raise NotImplementedError 17 | 18 | @abstractmethod 19 | def update(self, *args, **kwargs): 20 | raise NotImplementedError 21 | 22 | @abstractmethod 23 | def get_meta_specs(self, *args, **kwargs): 24 | raise NotImplementedError 25 | 26 | @abstractmethod 27 | def init_meta(self, *args, **kwargs): 28 | raise NotImplementedError 29 | 30 | @abstractmethod 31 | def update_meta(self, *args, **kwargs): 32 | raise NotImplementedError 33 | 34 | @abstractmethod 35 | def init_replay_buffer(self, *args, **kwargs): 36 | raise NotImplementedError 37 | 38 | @abstractmethod 39 | def store_timestep(self, *args, **kwargs): 40 | raise NotImplementedError 41 | 42 | @abstractmethod 43 | def sample_timesteps(self, *args, **kwargs): 44 | raise NotImplementedError 45 | -------------------------------------------------------------------------------- /core/agents/ddpg_multimodal_skill_torch.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Any 2 | from collections import OrderedDict 3 | from functools import partial 4 | 5 | import dm_env 6 | import jax 7 | import jax.numpy as jnp 8 | import numpy as np 9 | 10 | from core import agents 11 | from core.envs import wrappers 12 | from core import calculations 13 | 14 | 15 | def episode_partition_mode_selector( 16 | time_step: wrappers.InformativeTimeStep, 17 | partitions: int = 2, 18 | ) -> bool: 19 | """ 20 | True for control 21 | False for explore 22 | """ 23 | max_time_step = time_step.max_timestep 24 | interval = max_time_step // partitions 25 | 
current_timestep = time_step.current_timestep 26 | return bool(current_timestep // interval % 2) 27 | 28 | 29 | def get_meta_specs(skill_dim: int, 30 | reward_free: bool 31 | ) -> Tuple: 32 | # noinspection PyRedundantParentheses 33 | if reward_free: 34 | return ( 35 | dm_env.specs.Array((skill_dim,), np.float32, 'skill'), 36 | dm_env.specs.Array((), np.bool, 'mode') # for pytorch replay buffer 37 | ) 38 | else: 39 | return ( 40 | dm_env.specs.Array((skill_dim,), np.float32, 'skill'), 41 | ) 42 | 43 | def init_meta(key, 44 | time_step: wrappers.InformativeTimeStep, 45 | reward_free: bool, 46 | skill_dim: int, 47 | partitions: int, 48 | search_mode = 'random_grid_search', 49 | skill_mode = 'half', 50 | skill_tracker: calculations.skill_utils.SkillRewardTracker=None, 51 | step: int = None, 52 | ) -> Tuple[OrderedDict, Any]: 53 | """ 54 | :param key: only parameter needed in forward pass 55 | :param reward_free: defined as a constant during init of ddpg skill 56 | :param step: global step, at a certain step it only outputs the best skill 57 | :param skill_dim: defined as a constant during init of ddpg skill 58 | :param time_step: used to get current step in the episode for mode only 59 | :param skill_tracker: keep track in a NamedTuple of the best skill 60 | :return: during pretrain runing meta with skill and mode. Finetune return best skill in skill_tracker 61 | """ 62 | meta = OrderedDict() 63 | if reward_free: 64 | # mode_key, skill_key = jax.random.split(key) 65 | skill_key = key 66 | # mode = bool(jax.random.bernoulli(key=mode_key, p=0.4)) 67 | mode = episode_partition_mode_selector(time_step=time_step, partitions=partitions) 68 | if skill_mode == 'half': 69 | first_half_dim = int(skill_dim / 2) 70 | second_half_dim = skill_dim - first_half_dim 71 | zero = jnp.zeros(shape=(first_half_dim,), dtype=jnp.float32) 72 | uniform = jax.random.uniform(skill_key, shape=(second_half_dim,), minval=0., maxval=1.) 73 | if mode: 74 | skill = jnp.concatenate([zero, uniform]) 75 | else: 76 | skill = jnp.concatenate([uniform, zero]) 77 | elif skill_mode == 'sign': 78 | sign = -1. if mode else 1. 79 | skill = jax.random.uniform(skill_key, shape=(skill_dim,), minval=0., maxval=1.) * sign 80 | elif skill_mode == 'same': 81 | skill = jax.random.uniform(skill_key, shape=(skill_dim,), minval=0., maxval=1.) 82 | elif skill_mode == 'discrete': 83 | sign = -1. if mode else 1. 84 | skill = jnp.ones((skill_dim,)) * sign 85 | # sign = 0. if mode else 1. 
86 | # skill = jnp.ones(shape=(skill_dim,), dtype=jnp.float32) * sign 87 | meta['mode'] = mode 88 | else: 89 | # outputs best skill after exploration loop 90 | # use constant skill function for baseline 91 | if search_mode == 'random_grid_search': 92 | skill = calculations.skill_utils.random_grid_search_skill( 93 | skill_dim=skill_dim, 94 | global_timestep=step, 95 | skill_tracker=skill_tracker, 96 | key=key 97 | ) 98 | elif search_mode == 'grid_search': 99 | skill = calculations.skill_utils.grid_search_skill( 100 | skill_dim=skill_dim, 101 | global_timestep=step, 102 | skill_tracker=skill_tracker, 103 | ) 104 | elif search_mode == 'random_search': 105 | skill = calculations.skill_utils.random_search_skill( 106 | skill_dim=skill_dim, 107 | global_timestep=step, 108 | skill_tracker=skill_tracker, 109 | key=key 110 | ) 111 | elif search_mode == 'constant': 112 | skill = calculations.skill_utils.constant_fixed_skill( 113 | skill_dim=skill_dim, 114 | ) 115 | elif search_mode == 'explore': 116 | skill = jnp.ones((skill_dim,)) 117 | 118 | elif search_mode == 'control': 119 | skill = -jnp.ones((skill_dim,)) 120 | 121 | if skill_tracker.update: 122 | # first step 123 | if skill_tracker.score_step == 0: 124 | pass 125 | elif skill_tracker.score_sum / skill_tracker.score_step > skill_tracker.best_score: 126 | skill_tracker = skill_tracker._replace( 127 | best_skill=skill_tracker.current_skill, 128 | best_score=skill_tracker.score_sum / skill_tracker.score_step 129 | ) 130 | skill_tracker = skill_tracker._replace( 131 | score_sum=0., 132 | score_step=0 133 | ) 134 | # skill = jnp.ones(skill_dim, dtype=jnp.float32) * 0.5 135 | skill_tracker = skill_tracker._replace(current_skill=skill) 136 | 137 | meta['skill'] = skill 138 | 139 | return meta, skill_tracker 140 | 141 | class DDPGAgentMultiModalSkill(agents.DDPGAgentSkill): 142 | 143 | """Implement DDPG with skills""" 144 | def __init__(self, 145 | skills_cfg, 146 | reward_free: bool, 147 | search_mode, 148 | skill_mode, 149 | partitions, 150 | **kwargs 151 | ): 152 | super().__init__( 153 | skills_cfg, 154 | reward_free, 155 | **kwargs 156 | ) 157 | # init in exploration mode 158 | self._mode = bool(0) 159 | 160 | to_jit = jax.jit if kwargs['to_jit'] else lambda x: x 161 | 162 | self.get_meta_specs = partial( 163 | get_meta_specs, skill_dim=skills_cfg.skill_dim, reward_free=reward_free 164 | ) 165 | self.init_meta = partial( 166 | init_meta, 167 | partitions=partitions, 168 | reward_free=reward_free, 169 | skill_dim=skills_cfg.skill_dim, 170 | search_mode=search_mode, 171 | skill_mode=skill_mode 172 | ) 173 | 174 | def update_meta(self, 175 | key: jax.random.PRNGKey, 176 | meta: OrderedDict, 177 | step: int, 178 | update_skill_every: int, 179 | time_step, 180 | skill_tracker=None, 181 | ) -> Tuple[OrderedDict, Any]: 182 | if step % update_skill_every == 0: 183 | return self.init_meta(key, step=step, skill_tracker=skill_tracker, time_step=time_step) 184 | return meta, skill_tracker -------------------------------------------------------------------------------- /core/agents/ddpg_skill.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Any 2 | from collections import OrderedDict 3 | from functools import partial 4 | 5 | import dm_env 6 | import jax 7 | import jax.numpy as jnp 8 | import numpy as np 9 | 10 | from core import agents 11 | from core import calculations 12 | 13 | def init_meta(key, 14 | reward_free: bool, 15 | skill_dim: int, 16 | skill_tracker: 
calculations.skill_utils.SkillRewardTracker=None, 17 | step: int = None) -> Tuple[OrderedDict, Any]: 18 | 19 | meta = OrderedDict() 20 | if reward_free: 21 | skill = jax.random.uniform(key, shape=(skill_dim, ), minval=0., maxval=1.) 22 | 23 | else: 24 | # outputs best skill after exploration loop 25 | # use constant skill function for baseline 26 | skill = calculations.skill_utils.grid_search_skill( 27 | skill_dim=skill_dim, 28 | global_timestep=step, 29 | skill_tracker=skill_tracker, 30 | ) 31 | if skill_tracker.update: 32 | # first step 33 | if skill_tracker.score_step == 0: 34 | pass 35 | elif skill_tracker.score_sum / skill_tracker.score_step > skill_tracker.best_score: 36 | skill_tracker = skill_tracker._replace( 37 | best_skill=skill_tracker.current_skill, 38 | best_score=skill_tracker.score_sum / skill_tracker.score_step 39 | ) 40 | skill_tracker = skill_tracker._replace( 41 | score_sum=0., 42 | score_step=0 43 | ) 44 | # skill = jnp.ones(skill_dim, dtype=jnp.float32) * 0.5 45 | skill_tracker = skill_tracker._replace(current_skill=skill) 46 | 47 | meta['skill'] = skill 48 | return meta, skill_tracker 49 | 50 | 51 | def get_meta_specs(skill_dim: int) -> Tuple: 52 | """ 53 | Each element of the tuple represent one spec for a particular element 54 | """ 55 | # noinspection PyRedundantParentheses 56 | return (dm_env.specs.Array((skill_dim,), np.float32, 'skill'),) 57 | 58 | 59 | class DDPGAgentSkill(agents.DDPGAgent): 60 | 61 | """Implement DDPG with skills""" 62 | def __init__(self, 63 | skills_cfg, 64 | reward_free: bool, 65 | **kwargs 66 | ): 67 | super(DDPGAgentSkill, self).__init__(**kwargs) 68 | self.get_meta_specs = partial(get_meta_specs, skill_dim=skills_cfg.skill_dim) 69 | self.init_meta = partial( 70 | init_meta, 71 | reward_free=reward_free, 72 | skill_dim=skills_cfg.skill_dim, 73 | ) 74 | self.update_meta = partial(self.update_meta, update_skill_every=skills_cfg.update_skill_every) 75 | self.init_params = partial( 76 | self.init_params, 77 | obs_type=kwargs['network_cfg'].obs_type 78 | ) 79 | 80 | def init_params(self, 81 | init_key: jax.random.PRNGKey, 82 | dummy_obs: jnp.ndarray, 83 | summarize: bool = True, 84 | checkpoint_state = None, 85 | **kwargs 86 | ): 87 | """ 88 | :param init_key: 89 | :param dummy_obs: 90 | :param summarize: 91 | :param checkpoint_state: 92 | :return: 93 | """ 94 | skill = jnp.empty(self.get_meta_specs()[0].shape) 95 | dummy_obs = jnp.concatenate([dummy_obs, skill], axis=-1) 96 | state = super().init_params(init_key=init_key, 97 | dummy_obs=dummy_obs, 98 | summarize=summarize, 99 | checkpoint_state=checkpoint_state) 100 | return state 101 | 102 | def update_meta(self, 103 | key: jax.random.PRNGKey, 104 | meta: OrderedDict, 105 | step: int, 106 | update_skill_every: int, 107 | time_step=None, 108 | skill_tracker=None, 109 | ) -> Tuple[OrderedDict, Any]: 110 | 111 | if step % update_skill_every == 0: 112 | return self.init_meta(key, step=step, skill_tracker=skill_tracker) 113 | return meta, skill_tracker 114 | 115 | -------------------------------------------------------------------------------- /core/calculations/__init__.py: -------------------------------------------------------------------------------- 1 | from .losses import l2_loss, particle_based_entropy, noise_contrastive_loss, cpc_loss, l2_loss_without_bias, softmax_probabilities 2 | from .layers import Identity, trunk, linear_relu, default_linear_init, feature_extractor, mlp, mlp_bottlneck 3 | from .distributions import TruncNormal 4 | from .params_utils import polyak_averaging 5 | from 
.misc import schedule 6 | from .skill_utils import random_search_skill, constant_fixed_skill, grid_search_skill, random_grid_search_skill -------------------------------------------------------------------------------- /core/calculations/augmentations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | 8 | def _random_flip_single_image(image, rng): 9 | _, flip_rng = jax.random.split(rng) 10 | should_flip_lr = jax.random.uniform(flip_rng, shape=()) <= 0.5 11 | image = jax.lax.cond(should_flip_lr, image, jnp.fliplr, image, lambda x: x) 12 | return image 13 | 14 | 15 | def random_flip(images, rng): 16 | rngs = jax.random.split(rng, images.shape[0]) 17 | return jax.vmap(_random_flip_single_image)(images, rngs) 18 | 19 | 20 | def random_shift_aug(x: jnp.ndarray): 21 | """x: [N, H, W, C]""" 22 | x = x.astype(dtype=jnp.float32) 23 | n, h, w, c = x.shape 24 | assert h == w 25 | 26 | return jax.lax.stop_gradient(x) 27 | 28 | class RandomShiftsAug(nn.Module): 29 | def __init__(self, pad): 30 | super().__init__() 31 | self.pad = pad 32 | 33 | def forward(self, x): 34 | x = x.float() 35 | n, c, h, w = x.size() 36 | assert h == w 37 | padding = tuple([self.pad] * 4) 38 | x = F.pad(x, padding, 'replicate') 39 | eps = 1.0 / (h + 2 * self.pad) 40 | arange = torch.linspace(-1.0 + eps, 41 | 1.0 - eps, 42 | h + 2 * self.pad, 43 | device=x.device, 44 | dtype=x.dtype)[:h] 45 | arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2) 46 | base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2) 47 | base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1) 48 | 49 | shift = torch.randint(0, 50 | 2 * self.pad + 1, 51 | size=(n, 1, 1, 2), 52 | device=x.device, 53 | dtype=x.dtype) 54 | shift *= 2.0 / (h + 2 * self.pad) 55 | 56 | grid = base_grid + shift 57 | return F.grid_sample(x, 58 | grid, 59 | padding_mode='zeros', 60 | align_corners=False) -------------------------------------------------------------------------------- /core/calculations/distributions.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | 5 | class Distribution: 6 | """ 7 | Abstract base class for probability distribution 8 | """ 9 | def __init__(self, batch_shape, event_shape): 10 | self._batch_shape = batch_shape 11 | self._event_shape = event_shape 12 | 13 | def sample(self, sample_shape): 14 | pass 15 | 16 | class TruncNormal: 17 | def __init__(self, loc, scale, low=-1.0, high=1.0): 18 | """ Trunc from -1 to 1 for DMC action space 19 | :param loc: mean (N, action_dim) 20 | :param scale: stddev () 21 | :param low: clamp to low 22 | :param high: clamp to high 23 | """ 24 | self.low = low 25 | self.high = high 26 | self.loc = loc 27 | self.scale = scale 28 | self.eps = 1e-6 29 | 30 | def mean(self): 31 | return self.loc 32 | 33 | def sample(self, 34 | clip=None, 35 | *, 36 | seed: jax.random.PRNGKey, 37 | # sample_shape: Sequence[int] = (), 38 | ): 39 | """Samples an event. 40 | 41 | Args: 42 | clip: implements clipped noise in DrQ-v2 43 | seed: PRNG key or integer seed. 44 | 45 | Returns: 46 | A sample of shape `sample_shape` + `batch_shape` + `event_shape`. 
47 | """ 48 | sample_shape = self.loc.shape 49 | noise = jax.random.normal(seed, sample_shape) # has to be same shape as loc which specifies the mean for each individual Gaussians 50 | noise *= self.scale 51 | 52 | if clip is not None: 53 | # clip N(0, var) of exploration schedule in DrQ-v2 54 | noise = jnp.clip(noise, a_min=-clip, a_max=clip) 55 | x = self.loc + noise 56 | # return jnp.clip(x, a_min=self.low, a_max=self.high) 57 | clamped_x = jnp.clip(x, a_min=self.low + self.eps, a_max=self.high - self.eps) 58 | x = x - jax.lax.stop_gradient(x) + jax.lax.stop_gradient(clamped_x) # trick to backprop on x without clamping affecting it 59 | return x 60 | # 61 | # class TruncNormal(distrax.Normal): 62 | # def __init__(self, loc, scale, low=-1.0, high=1.0): 63 | # """ Trunc from -1 to 1 for DMC action space 64 | # :param loc: mean 65 | # :param scale: stddev 66 | # :param low: 67 | # :param high: 68 | # :param eps: 69 | # """ 70 | # super(TruncNormal, self).__init__(loc=loc, scale=scale) 71 | # 72 | # self.low = low 73 | # self.high = high 74 | # # self.eps = eps 75 | # 76 | # def _clamp(self, x): 77 | # """ Clamping method for TruncNormal""" 78 | # clamped_x = jnp.clip(x, self.low, self.high) 79 | # x = x - jax.lax.stop_gradient(x) + jax.lax.stop_gradient(clamped_x) 80 | # return x 81 | # 82 | # def sample(self, 83 | # clip=None, 84 | # *, 85 | # seed, #: Union[IntLike, PRNGKey], 86 | # sample_shape = (),#: Union[IntLike, Sequence[IntLike]] = () 87 | # ): 88 | # """Samples an event. 89 | # 90 | # Args: 91 | # clip: implements clipped noise in DrQ-v2 92 | # seed: PRNG key or integer seed. 93 | # sample_shape: Additional leading dimensions for sample. 94 | # 95 | # Returns: 96 | # A sample of shape `sample_shape` + `batch_shape` + `event_shape`. 97 | # """ 98 | # # this line check if rng is a PRNG key and sample_shape a tuple if not it converts them. 
99 | # # rng, sample_shape = convert_seed_and_sample_shape(seed, sample_shape) 100 | # num_samples = functools.reduce(operator.mul, sample_shape, 1) # product 101 | # 102 | # eps = self._sample_from_std_normal(seed, num_samples) 103 | # scale = jnp.expand_dims(self._scale, range(eps.ndim - self._scale.ndim)) 104 | # loc = jnp.expand_dims(self._loc, range(eps.ndim - self._loc.ndim)) 105 | # 106 | # eps *= scale 107 | # if clip is not None: 108 | # # clip N(0, var) of exploration schedule in DrQ-v2 109 | # eps = jnp.clip(eps, a_min=-clip, a_max=clip) 110 | # samples = loc + eps 111 | # samples = self._clamp(samples) 112 | # return samples.reshape(sample_shape + samples.shape[1:]) 113 | 114 | # 115 | # import torch 116 | # from torch import distributions as pyd 117 | # from torch.distributions.utils import _standard_normal 118 | # 119 | # 120 | # class TruncatedNormal(pyd.Normal): 121 | # def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6): 122 | # super().__init__(loc, scale, validate_args=False) 123 | # self.low = low 124 | # self.high = high 125 | # self.eps = eps 126 | # 127 | # def _clamp(self, x): 128 | # clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps) 129 | # x = x - x.detach() + clamped_x.detach() 130 | # return x 131 | # 132 | # def sample(self, clip=None, sample_shape=torch.Size()): 133 | # shape = self._extended_shape(sample_shape) 134 | # eps = _standard_normal(shape, 135 | # dtype=self.loc.dtype, 136 | # device=self.loc.device) 137 | # eps *= self.scale 138 | # if clip is not None: 139 | # eps = torch.clamp(eps, -clip, clip) 140 | # x = self.loc + eps 141 | # return self._clamp(x) 142 | # 143 | # if __name__ == "__main__": 144 | # truncNormal = TruncNormal(jnp.ones((3,)), 1.) 145 | # samples_jax = truncNormal.sample(clip=2, seed=jax.random.PRNGKey(666)) 146 | # 147 | # torchtruncNormal = TruncatedNormal(torch.ones(3), 1.) 148 | # samples_torch = torchtruncNormal.sample(clip=2) 149 | # 150 | # print(samples_jax, samples_torch) 151 | # [[0.96648777 1.] 152 | # [0.4025777 1.] 
153 | # [-0.59399736 154 | # 1.]] -------------------------------------------------------------------------------- /core/calculations/layers.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Mapping, Union 2 | 3 | import haiku as hk 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | FloatStrOrBool = Union[str, float, bool] 8 | default_linear_init = hk.initializers.Orthogonal() 9 | 10 | class Identity(hk.Module): 11 | def __init__(self, name = 'identity'): 12 | super(Identity, self).__init__(name=name) 13 | 14 | def __call__(self, inputs): 15 | return inputs 16 | 17 | def trunk(ln_config: Mapping[str, FloatStrOrBool], feature_dim: int, name='trunk') -> Callable: 18 | """Layer""" 19 | return hk.Sequential([ 20 | hk.Linear(output_size=feature_dim, w_init=default_linear_init, name='trunk_linear'), 21 | hk.LayerNorm(**ln_config, name='trunk_ln'), 22 | jax.nn.tanh 23 | ], name=name) 24 | 25 | def linear_relu(dim: int, name='linear_relu') -> Callable: 26 | """Layer""" 27 | return hk.Sequential([ 28 | hk.Linear(output_size=dim, w_init=default_linear_init), #TODO pass it as argument 29 | jax.nn.relu 30 | ], name=name) 31 | 32 | def mlp(dim: int, out_dim: int, name='mlp') -> Callable: 33 | return hk.Sequential( 34 | [linear_relu(dim=dim), 35 | linear_relu(dim=dim), 36 | hk.Linear(out_dim, w_init=default_linear_init) 37 | ], 38 | name=name 39 | ) 40 | 41 | def mlp_bottlneck(dim: int, out_dim: int, name='mlp') -> Callable: 42 | return hk.Sequential( 43 | [linear_relu(dim=dim // 2), 44 | linear_relu(dim=dim), 45 | hk.Linear(out_dim, w_init=default_linear_init) 46 | ], 47 | name=name 48 | ) 49 | 50 | def feature_extractor(obs: jnp.ndarray, obs_type: str, name='encoder') -> jnp.ndarray: 51 | """encoder""" 52 | if obs_type == 'pixels': 53 | encoder = hk.Sequential([ 54 | lambda x: x / 255.0 - 0.5, #FIXME put on GPU instead of CPU 55 | hk.Conv2D(output_channels=32, kernel_shape=3, stride=2, padding='VALID'), 56 | jax.nn.relu, 57 | hk.Conv2D(output_channels=32, kernel_shape=3, stride=1, padding='VALID'), 58 | jax.nn.relu, 59 | hk.Conv2D(output_channels=32, kernel_shape=3, stride=1, padding='VALID'), 60 | jax.nn.relu, 61 | hk.Conv2D(output_channels=32, kernel_shape=3, stride=1, padding='VALID'), 62 | jax.nn.relu, 63 | hk.Flatten(preserve_dims=-3) # [N, H, W, C] -> [N, -1] 64 | ], name=name) 65 | else: 66 | encoder = Identity() 67 | 68 | return encoder(inputs=obs) 69 | 70 | 71 | if __name__ == "__main__": 72 | def network(obs): 73 | def make_q(name): 74 | return hk.Sequential([ 75 | linear_relu(10), 76 | hk.Linear(1, w_init=default_linear_init) 77 | ], name) 78 | 79 | q1 = make_q(name='q1') 80 | q2 = make_q(name='q2') # q1 neq q2 81 | return q1(obs), q2(obs) 82 | 83 | forward = hk.without_apply_rng(hk.transform(network)) 84 | key = jax.random.PRNGKey(2) 85 | obs = jnp.ones((1, 10)) 86 | state = forward.init(rng=key, obs=obs) # state = state2 87 | state_2 = forward.init(rng=key, obs=obs) 88 | print(state) -------------------------------------------------------------------------------- /core/calculations/losses.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Tuple 3 | 4 | import jax.numpy as jnp 5 | import jax 6 | import torch 7 | import chex 8 | import haiku as hk 9 | import tree 10 | 11 | 12 | def l2_loss(preds: jnp.ndarray, 13 | targets: jnp.ndarray = None 14 | ) -> jnp.ndarray: 15 | """Compute l2 loss if target not provided computes l2 loss with target 0""" 16 | if 
targets is None: 17 | targets = jnp.zeros_like(preds) 18 | chex.assert_type([preds, targets], float) 19 | return 0.5 * (preds - targets)**2 20 | 21 | def l2_loss_without_bias(params: hk.Params): 22 | l2_params = [p for ((module_name, x), p) in tree.flatten_with_path(params) if x == 'w'] 23 | return 0.5 * sum(jnp.sum(jnp.square(p)) for p in l2_params) 24 | 25 | 26 | def running_stats( 27 | mean: jnp.ndarray, 28 | std: jnp.ndarray, 29 | x: jnp.ndarray, 30 | num: float, 31 | ): 32 | bs = x.shape[0] 33 | delta = jnp.mean(x, axis=0) - mean 34 | new_mean = mean + delta * bs / (num + bs) 35 | new_std = (std * num + jnp.var(x, axis=0) * bs + 36 | (delta**2) * num * bs / (num + bs)) / (num + bs) 37 | return new_mean, new_std, num + bs 38 | 39 | 40 | def particle_based_entropy(source: jnp.ndarray, 41 | target: jnp.ndarray, 42 | knn_clip: float = 0.0005, # todo remove for minimization 43 | knn_k: int = 16, 44 | knn_avg: bool = True, 45 | knn_rm: bool = True, 46 | minus_mean: bool = True, 47 | mean: jnp.ndarray = None, 48 | std: jnp.ndarray = None, 49 | num: float = None, 50 | ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, float]: 51 | """ Implement Particle Based Entropy Estimator as in APT 52 | :param knn_rm: 53 | :param mean: mean for running mean 54 | :param knn_clip: 55 | :param knn_k: hyperparameter k 56 | :param knn_avg: whether to take the average over k nearest neighbors 57 | :param source: value to compute entropy over [b1, c] 58 | :param target: value to compute entropy over [b1, c] 59 | :return: entropy of rep # (b1, 1) 60 | """ 61 | # source = target = rep #[b1, c] [b2, c] 62 | 63 | b1, b2 = source.shape[0], target.shape[0] 64 | # (b1, 1, c) - (1, b2, c) -> (b1, b2, c) -> (b1, b2) 65 | sim_matrix = jnp.linalg.norm( 66 | source[:, None, :].reshape(b1, 1, -1) - target[None, :, :].reshape(1, b2, -1), 67 | axis=-1, 68 | ord=2 69 | ) 70 | # take the min of the sim_matrix to get largest=False 71 | reward, _ = jax.lax.top_k( 72 | operand=-sim_matrix, #(b1, b2) 73 | k=knn_k 74 | ) 75 | reward = -reward 76 | 77 | if not knn_avg: # only keep k-th nearest neighbor 78 | reward = reward[:, -1] 79 | reward = reward.reshape(-1, 1) # (b1 * k, 1) 80 | if knn_rm: 81 | mean, std, num = running_stats(mean, std, reward, num) 82 | if minus_mean: 83 | reward = (reward - mean) / std 84 | else: 85 | reward = reward/ std 86 | reward = jnp.maximum( 87 | reward - knn_clip, 88 | jnp.zeros_like(reward) 89 | ) 90 | else: # average over all k nearest neigbors 91 | reward = reward.reshape(-1, 1) #(b1 * k, 1) 92 | if knn_rm: 93 | mean, std, num = running_stats(mean, std, reward, num) 94 | if minus_mean: 95 | reward = (reward - mean) / std 96 | else: 97 | reward = reward / std 98 | if knn_clip >= 0.0: 99 | reward = jnp.maximum( 100 | reward - knn_clip, 101 | jnp.zeros_like(reward) 102 | ) 103 | reward = reward.reshape((b1, knn_k)) 104 | reward = jnp.mean(reward, axis=1, keepdims=True) # (b1, 1) 105 | 106 | reward = jnp.log(reward + 1.0) 107 | return reward, mean, std, num 108 | 109 | 110 | 111 | def log_sum_exp(logits: jnp.ndarray): 112 | return jnp.log( 113 | jnp.sum( 114 | jnp.exp(logits),# [N, C] 115 | axis=-1 116 | ) # [N] 117 | ) 118 | 119 | def normalize(x): 120 | return x / (jnp.linalg.norm(x=x, ord=2, axis=-1, keepdims=True) + 1e-12) 121 | # jnp.sqrt(jnp.sum(jnp.square(normalize(a)))) 122 | 123 | def noise_contrastive_loss( 124 | query, 125 | key, 126 | temperature = 0.5 127 | ): 128 | """ 129 | s_i - \sum \exp s_i 130 | """ 131 | query = normalize(query) 132 | key = normalize(key) 133 | logits = query @ key.T 
#(N, N) positive pairs on the diagonal 134 | logits = logits / temperature 135 | shifted_cov = logits - jax.lax.stop_gradient(logits.max(axis=-1, keepdims=True)) # [N, N] 136 | diag_indexes = jnp.arange(shifted_cov.shape[0])[:, None]# [N, 1] 137 | pos = jnp.take_along_axis(arr=shifted_cov, indices=diag_indexes, axis=-1) # [N, 1] 138 | neg = log_sum_exp(shifted_cov) 139 | return -jnp.mean(pos.reshape(-1) - neg.reshape(-1)) 140 | 141 | 142 | def softmax_probabilities(query, key, temperature=0.5): 143 | query = normalize(query) 144 | key = normalize(key) 145 | logits = query @ key.T 146 | logits = logits / temperature 147 | shifted_cov = logits - jax.lax.stop_gradient(logits.max(axis=-1, keepdims=True)) # [N, N] 148 | diag_indexes = jnp.arange(shifted_cov.shape[0])[:, None] # [N, 1] 149 | pos = jnp.take_along_axis(arr=shifted_cov, indices=diag_indexes, axis=-1) # [N, 1] 150 | pos = jnp.exp(pos) 151 | neg = jnp.sum(jnp.exp(logits), axis=-1, keepdims=True) # [N, 1] 152 | return pos / neg 153 | 154 | 155 | def cpc_loss( 156 | query, 157 | key, 158 | temperature = 0.5 159 | ): 160 | 161 | query = normalize(query) 162 | key = normalize(key) 163 | cov = query @ key.T # (N, N) positive pairs on the diagonal 164 | sim = jnp.exp(cov / temperature) 165 | neg = sim.sum(axis=-1) # b 166 | row_sub = jnp.ones_like(neg) * math.exp(1/temperature) 167 | neg = jnp.clip(neg - row_sub, a_min=1e-6) 168 | 169 | pos = jnp.exp(jnp.sum(query * key, axis=-1) / temperature) # b 170 | loss = -jnp.log(pos / (neg + 1e-6)) 171 | return loss.mean() 172 | 173 | if __name__ == "__main__": 174 | # x = jax.random.normal(key=jax.random.PRNGKey(5), shape=(15, 5)) 175 | # 10, 5 176 | jax_input = jnp.array([[ 0.61735314, 0.65116936, 0.37252188, 0.01196358, 177 | -1.0840642 ], 178 | [ 0.40633643, -0.3350711 , 0.433196 , 1.8324155 , 179 | 1.2233032 ], 180 | [ 0.6076932 , 0.62271905, -0.5155139 , -0.8686952 , 181 | 1.3694043 ], 182 | [ 1.5686233 , -1.0647503 , 1.0048455 , 1.4000669 , 183 | 0.30719075], 184 | [ 1.6678249 , -0.5851507 , -1.420454 , -0.05948697, 185 | -1.5111905 ], 186 | [ 1.8621138 , -0.6911869 , -0.94851583, 1.159258 , 187 | 1.5931036 ], 188 | [ 1.9720763 , -1.0973446 , 1.1731594 , 0.0780869 , 189 | 0.143219 ], 190 | [-1.0157285 , 0.50870734, 0.39398482, 1.1644812 , 191 | -0.26890013], 192 | [ 1.6161795 , 1.644653 , -1.0968473 , 1.0495588 , 193 | 0.47088355], 194 | [-0.13400784, 0.5755616 , 0.4617284 , 0.08174139, 195 | -1.0918598 ]]) 196 | 197 | torch_input = torch.tensor([[ 0.61735314, 0.65116936, 0.37252188, 0.01196358, 198 | -1.0840642 ], 199 | [ 0.40633643, -0.3350711 , 0.433196 , 1.8324155 , 200 | 1.2233032 ], 201 | [ 0.6076932 , 0.62271905, -0.5155139 , -0.8686952 , 202 | 1.3694043 ], 203 | [ 1.5686233 , -1.0647503 , 1.0048455 , 1.4000669 , 204 | 0.30719075], 205 | [ 1.6678249 , -0.5851507 , -1.420454 , -0.05948697, 206 | -1.5111905 ], 207 | [ 1.8621138 , -0.6911869 , -0.94851583, 1.159258 , 208 | 1.5931036 ], 209 | [ 1.9720763 , -1.0973446 , 1.1731594 , 0.0780869 , 210 | 0.143219 ], 211 | [-1.0157285 , 0.50870734, 0.39398482, 1.1644812 , 212 | -0.26890013], 213 | [ 1.6161795 , 1.644653 , -1.0968473 , 1.0495588 , 214 | 0.47088355], 215 | [-0.13400784, 0.5755616 , 0.4617284 , 0.08174139, 216 | -1.0918598 ]]) 217 | 218 | ## TEST particle 219 | # knn_k = 3 220 | # knn_clip = 0.0 221 | # mean = 0.0 222 | # knn_avg = True 223 | # knn_rm = True 224 | # particle_based_entropy = partial(particle_based_entropy, knn_k=knn_k, knn_clip=knn_clip, knn_rm=knn_rm, 225 | # knn_avg=knn_avg) 226 | # value = 
particle_based_entropy(rep=jax_input, mean=mean, step=1) 227 | # print(value) 228 | # rms = RMS('cpu') 229 | # pbe = PBE(rms, knn_clip, knn_k, knn_avg, knn_rm, 'cpu') 230 | # value_torch = pbe(torch_input) 231 | # print(value_torch) 232 | 233 | ## TEST nce 234 | # out = noise_contrastive_loss(jax_input, jax_input) 235 | # out = cpc_loss(jax_input, jax_input) 236 | # print(out) 237 | # out_torch = torch_nce(torch_input, torch_input) 238 | # print(out_torch) 239 | # print("Sanity Check value should be close to log(1/N): {}".format(math.log(jax_input.shape[0]))) -------------------------------------------------------------------------------- /core/calculations/misc.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import numpy as np 4 | 5 | 6 | 7 | def schedule(schdl, step): 8 | try: 9 | return float(schdl) 10 | except ValueError: 11 | match = re.match(r'linear\((.+),(.+),(.+)\)', schdl) 12 | if match: 13 | init, final, duration = [float(g) for g in match.groups()] 14 | mix = np.clip(step / duration, 0.0, 1.0) 15 | return (1.0 - mix) * init + mix * final 16 | match = re.match(r'step_linear\((.+),(.+),(.+),(.+),(.+)\)', schdl) 17 | if match: 18 | init, final1, duration1, final2, duration2 = [ 19 | float(g) for g in match.groups() 20 | ] 21 | if step <= duration1: 22 | mix = np.clip(step / duration1, 0.0, 1.0) 23 | return (1.0 - mix) * init + mix * final1 24 | else: 25 | mix = np.clip((step - duration1) / duration2, 0.0, 1.0) 26 | return (1.0 - mix) * final1 + mix * final2 27 | raise NotImplementedError(schdl) 28 | 29 | 30 | -------------------------------------------------------------------------------- /core/calculations/params_utils.py: -------------------------------------------------------------------------------- 1 | import haiku as hk 2 | import jax 3 | import tree 4 | 5 | def count_param(params: hk.Params): 6 | params_count_list = [p.size for ((mod_name, x), p) in tree.flatten_with_path(params)] 7 | return sum(params_count_list) 8 | 9 | 10 | def polyak_averaging(params: hk.Params, 11 | target_params: hk.Params, 12 | tau: float 13 | ): 14 | return jax.tree_multimap( 15 | lambda x, y: tau * x + (1 - tau) * y, 16 | params, target_params 17 | ) 18 | -------------------------------------------------------------------------------- /core/calculations/skill_utils.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple 2 | import jax 3 | from jax import numpy as jnp 4 | import numpy as np 5 | 6 | 7 | class SkillRewardTracker(NamedTuple): 8 | best_skill: jnp.ndarray 9 | best_score: np.float32 10 | score_sum: np.float32 11 | score_step: int 12 | current_skill: jnp.ndarray 13 | search_steps: int 14 | change_interval: int 15 | low: float 16 | update: bool 17 | 18 | 19 | def constant_fixed_skill(skill_dim: int,) -> jnp.ndarray: 20 | return jnp.ones((skill_dim,), dtype=jnp.float32) * 0.5 21 | 22 | 23 | def random_search_skill( 24 | skill_dim: int, 25 | global_timestep: int, 26 | skill_tracker: SkillRewardTracker, 27 | key: jax.random.PRNGKey 28 | ) -> jnp.ndarray: 29 | if global_timestep >= skill_tracker.search_steps: 30 | return skill_tracker.best_skill 31 | return jax.random.uniform(key, shape=(skill_dim, ), minval=0., maxval=1.) 
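# Illustrative sketch of the interval arithmetic used by `random_grid_search_skill`
# below: the search budget is split into `search_steps // change_interval` equal-width
# bins, and the skill is sampled uniformly from the bin that corresponds to the current
# timestep until `search_steps` is reached, after which the best skill found so far is
# reused. The helper below is hypothetical (not part of the module) and assumes
# low=0.0, search_steps=4000, change_interval=1000.
def _example_grid_interval(global_timestep: int,
                           low: float = 0.0,
                           search_steps: int = 4000,
                           change_interval: int = 1000):
    """Return the (start, end) sampling interval for a given timestep."""
    increment = (1 - low) / (search_steps // change_interval)  # 0.25 with these values
    bin_idx = global_timestep // change_interval               # 0, 1, 2, 3
    return bin_idx * increment, (bin_idx + 1) * increment
# e.g. _example_grid_interval(0)    -> (0.0, 0.25)
#      _example_grid_interval(2500) -> (0.5, 0.75)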
32 | 33 | 34 | def random_grid_search_skill(key: jax.random.PRNGKey, 35 | skill_dim: int, 36 | global_timestep: int, 37 | skill_tracker: SkillRewardTracker, 38 | **kwargs) -> jnp.ndarray: 39 | if global_timestep >= skill_tracker.search_steps: 40 | return skill_tracker.best_skill 41 | increment = (1 - skill_tracker.low) / (skill_tracker.search_steps // skill_tracker.change_interval) 42 | start = global_timestep // skill_tracker.change_interval * increment 43 | end = (global_timestep // skill_tracker.change_interval + 1) * increment 44 | return jax.random.uniform(key, shape=(skill_dim,), minval=start, maxval=end) 45 | 46 | 47 | def grid_search_skill(skill_dim: int, global_timestep: int, skill_tracker: SkillRewardTracker) -> jnp.ndarray: 48 | if global_timestep >= skill_tracker.search_steps: 49 | return skill_tracker.best_skill 50 | return jnp.ones((skill_dim,)) * jnp.linspace( 51 | -1., 52 | 0., 53 | num=skill_tracker.search_steps // skill_tracker.change_interval 54 | )[global_timestep // skill_tracker.change_interval] 55 | 56 | -------------------------------------------------------------------------------- /core/custom_dmc_tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from core.custom_dmc_tasks import quadruped, jaco, cheetah, walker, hopper 2 | 3 | 4 | def make(domain, task, 5 | task_kwargs=None, 6 | environment_kwargs=None, 7 | visualize_reward=False): 8 | 9 | if domain == 'cheetah': 10 | return cheetah.make(task, 11 | task_kwargs=task_kwargs, 12 | environment_kwargs=environment_kwargs, 13 | visualize_reward=visualize_reward) 14 | elif domain == 'walker': 15 | return walker.make(task, 16 | task_kwargs=task_kwargs, 17 | environment_kwargs=environment_kwargs, 18 | visualize_reward=visualize_reward) 19 | elif domain == 'hopper': 20 | return hopper.make(task, 21 | task_kwargs=task_kwargs, 22 | environment_kwargs=environment_kwargs, 23 | visualize_reward=visualize_reward) 24 | elif domain == 'quadruped': 25 | return quadruped.make(task, 26 | task_kwargs=task_kwargs, 27 | environment_kwargs=environment_kwargs, 28 | visualize_reward=visualize_reward) 29 | else: 30 | raise f'{task} not found' 31 | 32 | assert None 33 | 34 | 35 | def make_jaco(task, obs_type, seed): 36 | return jaco.make(task, obs_type, seed) -------------------------------------------------------------------------------- /core/custom_dmc_tasks/cheetah.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """Cheetah Domain.""" 16 | 17 | import collections 18 | import os 19 | 20 | from dm_control import mujoco 21 | from dm_control.rl import control 22 | from dm_control.suite import base 23 | from dm_control.suite import common 24 | from dm_control.utils import containers 25 | from dm_control.utils import rewards 26 | from dm_control.utils import io as resources 27 | 28 | # How long the simulation will run, in seconds. 29 | _DEFAULT_TIME_LIMIT = 10 30 | 31 | # Running speed above which reward is 1. 32 | _RUN_SPEED = 10 33 | _SPIN_SPEED = 5 34 | 35 | SUITE = containers.TaggedTasks() 36 | 37 | 38 | def make(task, 39 | task_kwargs=None, 40 | environment_kwargs=None, 41 | visualize_reward=False): 42 | task_kwargs = task_kwargs or {} 43 | if environment_kwargs is not None: 44 | task_kwargs = task_kwargs.copy() 45 | task_kwargs['environment_kwargs'] = environment_kwargs 46 | env = SUITE[task](**task_kwargs) 47 | env.task.visualize_reward = visualize_reward 48 | return env 49 | 50 | 51 | def get_model_and_assets(): 52 | """Returns a tuple containing the model XML string and a dict of assets.""" 53 | root_dir = os.path.dirname(os.path.dirname(__file__)) 54 | xml = resources.GetResource( 55 | os.path.join(root_dir, 'custom_dmc_tasks', 'cheetah.xml')) 56 | return xml, common.ASSETS 57 | 58 | 59 | 60 | @SUITE.add('benchmarking') 61 | def run_backward(time_limit=_DEFAULT_TIME_LIMIT, 62 | random=None, 63 | environment_kwargs=None): 64 | """Returns the run task.""" 65 | physics = Physics.from_xml_string(*get_model_and_assets()) 66 | task = Cheetah(forward=False, flip=False, random=random) 67 | environment_kwargs = environment_kwargs or {} 68 | return control.Environment(physics, 69 | task, 70 | time_limit=time_limit, 71 | **environment_kwargs) 72 | 73 | 74 | @SUITE.add('benchmarking') 75 | def flip(time_limit=_DEFAULT_TIME_LIMIT, 76 | random=None, 77 | environment_kwargs=None): 78 | """Returns the run task.""" 79 | physics = Physics.from_xml_string(*get_model_and_assets()) 80 | task = Cheetah(forward=True, flip=True, random=random) 81 | environment_kwargs = environment_kwargs or {} 82 | return control.Environment(physics, 83 | task, 84 | time_limit=time_limit, 85 | **environment_kwargs) 86 | 87 | 88 | @SUITE.add('benchmarking') 89 | def flip_backward(time_limit=_DEFAULT_TIME_LIMIT, 90 | random=None, 91 | environment_kwargs=None): 92 | """Returns the run task.""" 93 | physics = Physics.from_xml_string(*get_model_and_assets()) 94 | task = Cheetah(forward=False, flip=True, random=random) 95 | environment_kwargs = environment_kwargs or {} 96 | return control.Environment(physics, 97 | task, 98 | time_limit=time_limit, 99 | **environment_kwargs) 100 | 101 | 102 | class Physics(mujoco.Physics): 103 | """Physics simulation with additional features for the Cheetah domain.""" 104 | def speed(self): 105 | """Returns the horizontal speed of the Cheetah.""" 106 | return self.named.data.sensordata['torso_subtreelinvel'][0] 107 | 108 | def angmomentum(self): 109 | """Returns the angular momentum of torso of the Cheetah about Y axis.""" 110 | return self.named.data.subtree_angmom['torso'][1] 111 | 112 | 113 | class Cheetah(base.Task): 114 | """A `Task` to train a running Cheetah.""" 115 | def __init__(self, forward=True, flip=False, random=None): 116 | self._forward = 1 if forward else -1 117 | self._flip = flip 118 | super(Cheetah, self).__init__(random=random) 119 | 120 | def initialize_episode(self, physics): 121 | """Sets the 
state of the environment at the start of each episode.""" 122 | # The indexing below assumes that all joints have a single DOF. 123 | assert physics.model.nq == physics.model.njnt 124 | is_limited = physics.model.jnt_limited == 1 125 | lower, upper = physics.model.jnt_range[is_limited].T 126 | physics.data.qpos[is_limited] = self.random.uniform(lower, upper) 127 | 128 | # Stabilize the model before the actual simulation. 129 | for _ in range(200): 130 | physics.step() 131 | 132 | physics.data.time = 0 133 | self._timeout_progress = 0 134 | super().initialize_episode(physics) 135 | 136 | def get_observation(self, physics): 137 | """Returns an observation of the state, ignoring horizontal position.""" 138 | obs = collections.OrderedDict() 139 | # Ignores horizontal position to maintain translational invariance. 140 | obs['position'] = physics.data.qpos[1:].copy() 141 | obs['velocity'] = physics.velocity() 142 | return obs 143 | 144 | def get_reward(self, physics): 145 | """Returns a reward to the agent.""" 146 | if self._flip: 147 | reward = rewards.tolerance(self._forward * physics.angmomentum(), 148 | bounds=(_SPIN_SPEED, float('inf')), 149 | margin=_SPIN_SPEED, 150 | value_at_margin=0, 151 | sigmoid='linear') 152 | 153 | else: 154 | reward = rewards.tolerance(self._forward * physics.speed(), 155 | bounds=(_RUN_SPEED, float('inf')), 156 | margin=_RUN_SPEED, 157 | value_at_margin=0, 158 | sigmoid='linear') 159 | return reward 160 | -------------------------------------------------------------------------------- /core/custom_dmc_tasks/cheetah.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /core/custom_dmc_tasks/hopper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Hopper domain.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import os 23 | 24 | from dm_control import mujoco 25 | from dm_control.rl import control 26 | from dm_control.suite import base 27 | from dm_control.suite import common 28 | from dm_control.suite.utils import randomizers 29 | from dm_control.utils import containers 30 | from dm_control.utils import rewards 31 | from dm_control.utils import io as resources 32 | import numpy as np 33 | 34 | SUITE = containers.TaggedTasks() 35 | 36 | _CONTROL_TIMESTEP = .02 # (Seconds) 37 | 38 | # Default duration of an episode, in seconds. 39 | _DEFAULT_TIME_LIMIT = 20 40 | 41 | # Minimal height of torso over foot above which stand reward is 1. 
42 | _STAND_HEIGHT = 0.6 43 | 44 | # Hopping speed above which hop reward is 1. 45 | _HOP_SPEED = 2 46 | _SPIN_SPEED = 5 47 | 48 | 49 | def make(task, 50 | task_kwargs=None, 51 | environment_kwargs=None, 52 | visualize_reward=False): 53 | task_kwargs = task_kwargs or {} 54 | if environment_kwargs is not None: 55 | task_kwargs = task_kwargs.copy() 56 | task_kwargs['environment_kwargs'] = environment_kwargs 57 | env = SUITE[task](**task_kwargs) 58 | env.task.visualize_reward = visualize_reward 59 | return env 60 | 61 | def get_model_and_assets(): 62 | """Returns a tuple containing the model XML string and a dict of assets.""" 63 | root_dir = os.path.dirname(os.path.dirname(__file__)) 64 | xml = resources.GetResource( 65 | os.path.join(root_dir, 'custom_dmc_tasks', 'hopper.xml')) 66 | return xml, common.ASSETS 67 | 68 | 69 | 70 | @SUITE.add('benchmarking') 71 | def hop_backward(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): 72 | """Returns a Hopper that strives to hop forward.""" 73 | physics = Physics.from_xml_string(*get_model_and_assets()) 74 | task = Hopper(hopping=True, forward=False, flip=False, random=random) 75 | environment_kwargs = environment_kwargs or {} 76 | return control.Environment(physics, 77 | task, 78 | time_limit=time_limit, 79 | control_timestep=_CONTROL_TIMESTEP, 80 | **environment_kwargs) 81 | 82 | @SUITE.add('benchmarking') 83 | def flip(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): 84 | """Returns a Hopper that strives to hop forward.""" 85 | physics = Physics.from_xml_string(*get_model_and_assets()) 86 | task = Hopper(hopping=True, forward=True, flip=True, random=random) 87 | environment_kwargs = environment_kwargs or {} 88 | return control.Environment(physics, 89 | task, 90 | time_limit=time_limit, 91 | control_timestep=_CONTROL_TIMESTEP, 92 | **environment_kwargs) 93 | 94 | @SUITE.add('benchmarking') 95 | def flip_backward(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): 96 | """Returns a Hopper that strives to hop forward.""" 97 | physics = Physics.from_xml_string(*get_model_and_assets()) 98 | task = Hopper(hopping=True, forward=False, flip=True, random=random) 99 | environment_kwargs = environment_kwargs or {} 100 | return control.Environment(physics, 101 | task, 102 | time_limit=time_limit, 103 | control_timestep=_CONTROL_TIMESTEP, 104 | **environment_kwargs) 105 | 106 | 107 | class Physics(mujoco.Physics): 108 | """Physics simulation with additional features for the Hopper domain.""" 109 | def height(self): 110 | """Returns height of torso with respect to foot.""" 111 | return (self.named.data.xipos['torso', 'z'] - 112 | self.named.data.xipos['foot', 'z']) 113 | 114 | def speed(self): 115 | """Returns horizontal speed of the Hopper.""" 116 | return self.named.data.sensordata['torso_subtreelinvel'][0] 117 | 118 | def touch(self): 119 | """Returns the signals from two foot touch sensors.""" 120 | return np.log1p(self.named.data.sensordata[['touch_toe', 121 | 'touch_heel']]) 122 | 123 | def angmomentum(self): 124 | """Returns the angular momentum of torso of the Cheetah about Y axis.""" 125 | return self.named.data.subtree_angmom['torso'][1] 126 | 127 | 128 | 129 | class Hopper(base.Task): 130 | """A Hopper's `Task` to train a standing and a jumping Hopper.""" 131 | def __init__(self, hopping, forward=True, flip=False, random=None): 132 | """Initialize an instance of `Hopper`. 
133 | 134 | Args: 135 | hopping: Boolean, if True the task is to hop forwards, otherwise it is to 136 | balance upright. 137 | random: Optional, either a `numpy.random.RandomState` instance, an 138 | integer seed for creating a new `RandomState`, or None to select a seed 139 | automatically (default). 140 | """ 141 | self._hopping = hopping 142 | self._forward = 1 if forward else -1 143 | self._flip = flip 144 | super(Hopper, self).__init__(random=random) 145 | 146 | def initialize_episode(self, physics): 147 | """Sets the state of the environment at the start of each episode.""" 148 | randomizers.randomize_limited_and_rotational_joints( 149 | physics, self.random) 150 | self._timeout_progress = 0 151 | super(Hopper, self).initialize_episode(physics) 152 | 153 | def get_observation(self, physics): 154 | """Returns an observation of positions, velocities and touch sensors.""" 155 | obs = collections.OrderedDict() 156 | # Ignores horizontal position to maintain translational invariance: 157 | obs['position'] = physics.data.qpos[1:].copy() 158 | obs['velocity'] = physics.velocity() 159 | obs['touch'] = physics.touch() 160 | return obs 161 | 162 | def get_reward(self, physics): 163 | """Returns a reward applicable to the performed task.""" 164 | standing = rewards.tolerance(physics.height(), (_STAND_HEIGHT, 2)) 165 | assert self._hopping 166 | if self._flip: 167 | hopping = rewards.tolerance(self._forward * physics.angmomentum(), 168 | bounds=(_SPIN_SPEED, float('inf')), 169 | margin=_SPIN_SPEED, 170 | value_at_margin=0, 171 | sigmoid='linear') 172 | else: 173 | hopping = rewards.tolerance(self._forward * physics.speed(), 174 | bounds=(_HOP_SPEED, float('inf')), 175 | margin=_HOP_SPEED / 2, 176 | value_at_margin=0.5, 177 | sigmoid='linear') 178 | return standing * hopping -------------------------------------------------------------------------------- /core/custom_dmc_tasks/hopper.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /core/custom_dmc_tasks/jaco.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | 16 | """A task where the goal is to move the hand close to a target prop or site.""" 17 | 18 | import collections 19 | 20 | from dm_control import composer 21 | from dm_control.composer import initializers 22 | from dm_control.composer.observation import observable 23 | from dm_control.composer.variation import distributions 24 | from dm_control.entities import props 25 | from dm_control.manipulation.shared import arenas 26 | from dm_control.manipulation.shared import cameras 27 | from dm_control.manipulation.shared import constants 28 | from dm_control.manipulation.shared import observations 29 | from dm_control.manipulation.shared import registry 30 | from dm_control.manipulation.shared import robots 31 | from dm_control.manipulation.shared import tags 32 | from dm_control.manipulation.shared import workspaces 33 | from dm_control.utils import rewards 34 | import numpy as np 35 | 36 | 37 | _ReachWorkspace = collections.namedtuple( 38 | '_ReachWorkspace', ['target_bbox', 'tcp_bbox', 'arm_offset']) 39 | 40 | # Ensures that the props are not touching the table before settling. 41 | _PROP_Z_OFFSET = 0.001 42 | 43 | _DUPLO_WORKSPACE = _ReachWorkspace( 44 | target_bbox=workspaces.BoundingBox( 45 | lower=(-0.1, -0.1, _PROP_Z_OFFSET), 46 | upper=(0.1, 0.1, _PROP_Z_OFFSET)), 47 | tcp_bbox=workspaces.BoundingBox( 48 | lower=(-0.1, -0.1, 0.2), 49 | upper=(0.1, 0.1, 0.4)), 50 | arm_offset=robots.ARM_OFFSET) 51 | 52 | _SITE_WORKSPACE = _ReachWorkspace( 53 | target_bbox=workspaces.BoundingBox( 54 | lower=(-0.2, -0.2, 0.02), 55 | upper=(0.2, 0.2, 0.4)), 56 | tcp_bbox=workspaces.BoundingBox( 57 | lower=(-0.2, -0.2, 0.02), 58 | upper=(0.2, 0.2, 0.4)), 59 | arm_offset=robots.ARM_OFFSET) 60 | 61 | _TARGET_RADIUS = 0.05 62 | _TIME_LIMIT = 10. 63 | 64 | TASKS = { 65 | 'reach_top_left': workspaces.BoundingBox( 66 | lower=(-0.09, 0.09, _PROP_Z_OFFSET), 67 | upper=(-0.09, 0.09, _PROP_Z_OFFSET)), 68 | 'reach_top_right': workspaces.BoundingBox( 69 | lower=(0.09, 0.09, _PROP_Z_OFFSET), 70 | upper=(0.09, 0.09, _PROP_Z_OFFSET)), 71 | 'reach_bottom_left': workspaces.BoundingBox( 72 | lower=(-0.09, -0.09, _PROP_Z_OFFSET), 73 | upper=(-0.09, -0.09, _PROP_Z_OFFSET)), 74 | 'reach_bottom_right': workspaces.BoundingBox( 75 | lower=(0.09, -0.09, _PROP_Z_OFFSET), 76 | upper=(0.09, -0.09, _PROP_Z_OFFSET)), 77 | } 78 | 79 | 80 | def make(task_id, obs_type, seed): 81 | obs_settings = observations.VISION if obs_type == 'pixels' else observations.PERFECT_FEATURES 82 | task = _reach(task_id, obs_settings=obs_settings, use_site=False) 83 | return composer.Environment(task, time_limit=_TIME_LIMIT, random_state=seed) 84 | 85 | 86 | 87 | class MTReach(composer.Task): 88 | """Bring the hand close to a target prop or site.""" 89 | 90 | def __init__( 91 | self, task_id, arena, arm, hand, prop, obs_settings, workspace, control_timestep): 92 | """Initializes a new `Reach` task. 93 | 94 | Args: 95 | arena: `composer.Entity` instance. 96 | arm: `robot_base.RobotArm` instance. 97 | hand: `robot_base.RobotHand` instance. 98 | prop: `composer.Entity` instance specifying the prop to reach to, or None 99 | in which case the target is a fixed site whose position is specified by 100 | the workspace. 101 | obs_settings: `observations.ObservationSettings` instance. 102 | workspace: `_ReachWorkspace` specifying the placement of the prop and TCP. 103 | control_timestep: Float specifying the control timestep in seconds. 
104 | """ 105 | self._arena = arena 106 | self._arm = arm 107 | self._hand = hand 108 | self._arm.attach(self._hand) 109 | self._arena.attach_offset(self._arm, offset=workspace.arm_offset) 110 | self.control_timestep = control_timestep 111 | self._tcp_initializer = initializers.ToolCenterPointInitializer( 112 | self._hand, self._arm, 113 | position=distributions.Uniform(*workspace.tcp_bbox), 114 | quaternion=workspaces.DOWN_QUATERNION) 115 | 116 | # Add custom camera observable. 117 | self._task_observables = cameras.add_camera_observables( 118 | arena, obs_settings, cameras.FRONT_CLOSE) 119 | 120 | target_pos_distribution = distributions.Uniform(*TASKS[task_id]) 121 | self._prop = prop 122 | if prop: 123 | # The prop itself is used to visualize the target location. 124 | self._make_target_site(parent_entity=prop, visible=False) 125 | self._target = self._arena.add_free_entity(prop) 126 | self._prop_placer = initializers.PropPlacer( 127 | props=[prop], 128 | position=target_pos_distribution, 129 | quaternion=workspaces.uniform_z_rotation, 130 | settle_physics=True) 131 | else: 132 | self._target = self._make_target_site(parent_entity=arena, visible=True) 133 | self._target_placer = target_pos_distribution 134 | 135 | obs = observable.MJCFFeature('pos', self._target) 136 | obs.configure(**obs_settings.prop_pose._asdict()) 137 | self._task_observables['target_position'] = obs 138 | 139 | # Add sites for visualizing the prop and target bounding boxes. 140 | workspaces.add_bbox_site( 141 | body=self.root_entity.mjcf_model.worldbody, 142 | lower=workspace.tcp_bbox.lower, upper=workspace.tcp_bbox.upper, 143 | rgba=constants.GREEN, name='tcp_spawn_area') 144 | workspaces.add_bbox_site( 145 | body=self.root_entity.mjcf_model.worldbody, 146 | lower=workspace.target_bbox.lower, upper=workspace.target_bbox.upper, 147 | rgba=constants.BLUE, name='target_spawn_area') 148 | 149 | def _make_target_site(self, parent_entity, visible): 150 | return workspaces.add_target_site( 151 | body=parent_entity.mjcf_model.worldbody, 152 | radius=_TARGET_RADIUS, visible=visible, 153 | rgba=constants.RED, name='target_site') 154 | 155 | @property 156 | def root_entity(self): 157 | return self._arena 158 | 159 | @property 160 | def arm(self): 161 | return self._arm 162 | 163 | @property 164 | def hand(self): 165 | return self._hand 166 | 167 | @property 168 | def task_observables(self): 169 | return self._task_observables 170 | 171 | def get_reward(self, physics): 172 | hand_pos = physics.bind(self._hand.tool_center_point).xpos 173 | target_pos = physics.bind(self._target).xpos 174 | distance = np.linalg.norm(hand_pos - target_pos) 175 | return rewards.tolerance( 176 | distance, bounds=(0, _TARGET_RADIUS), margin=_TARGET_RADIUS) 177 | 178 | def initialize_episode(self, physics, random_state): 179 | self._hand.set_grasp(physics, close_factors=random_state.uniform()) 180 | self._tcp_initializer(physics, random_state) 181 | if self._prop: 182 | self._prop_placer(physics, random_state) 183 | else: 184 | physics.bind(self._target).pos = ( 185 | self._target_placer(random_state=random_state)) 186 | 187 | 188 | def _reach(task_id, obs_settings, use_site): 189 | """Configure and instantiate a `Reach` task. 190 | 191 | Args: 192 | obs_settings: An `observations.ObservationSettings` instance. 193 | use_site: Boolean, if True then the target will be a fixed site, otherwise 194 | it will be a moveable Duplo brick. 195 | 196 | Returns: 197 | An instance of `reach.Reach`. 
198 | """ 199 | arena = arenas.Standard() 200 | arm = robots.make_arm(obs_settings=obs_settings) 201 | hand = robots.make_hand(obs_settings=obs_settings) 202 | if use_site: 203 | workspace = _SITE_WORKSPACE 204 | prop = None 205 | else: 206 | workspace = _DUPLO_WORKSPACE 207 | prop = props.Duplo(observable_options=observations.make_options( 208 | obs_settings, observations.FREEPROP_OBSERVABLES)) 209 | task = MTReach(task_id, arena=arena, arm=arm, hand=hand, prop=prop, 210 | obs_settings=obs_settings, 211 | workspace=workspace, 212 | control_timestep=constants.CONTROL_TIMESTEP) 213 | return task -------------------------------------------------------------------------------- /core/custom_dmc_tasks/quadruped.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /core/custom_dmc_tasks/walker.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Planar Walker Domain.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import os 23 | 24 | from dm_control import mujoco 25 | from dm_control.rl import control 26 | from dm_control.suite import base 27 | from dm_control.suite import common 28 | from dm_control.suite.utils import randomizers 29 | from dm_control.utils import containers 30 | from dm_control.utils import rewards 31 | from dm_control.utils import io as resources 32 | from dm_control import suite 33 | 34 | _DEFAULT_TIME_LIMIT = 25 35 | _CONTROL_TIMESTEP = .025 36 | 37 | # Minimal height of torso over foot above which stand reward is 1. 38 | _STAND_HEIGHT = 1.2 39 | 40 | # Horizontal speeds (meters/second) above which move reward is 1. 
41 | _WALK_SPEED = 1 42 | _RUN_SPEED = 8 43 | _SPIN_SPEED = 5 44 | 45 | SUITE = containers.TaggedTasks() 46 | 47 | def make(task, 48 | task_kwargs=None, 49 | environment_kwargs=None, 50 | visualize_reward=False): 51 | task_kwargs = task_kwargs or {} 52 | if environment_kwargs is not None: 53 | task_kwargs = task_kwargs.copy() 54 | task_kwargs['environment_kwargs'] = environment_kwargs 55 | env = SUITE[task](**task_kwargs) 56 | env.task.visualize_reward = visualize_reward 57 | return env 58 | 59 | def get_model_and_assets(): 60 | """Returns a tuple containing the model XML string and a dict of assets.""" 61 | root_dir = os.path.dirname(os.path.dirname(__file__)) 62 | xml = resources.GetResource(os.path.join(root_dir, 'custom_dmc_tasks', 63 | 'walker.xml')) 64 | return xml, common.ASSETS 65 | 66 | 67 | 68 | 69 | 70 | 71 | @SUITE.add('benchmarking') 72 | def flip(time_limit=_DEFAULT_TIME_LIMIT, 73 | random=None, 74 | environment_kwargs=None): 75 | """Returns the Run task.""" 76 | physics = Physics.from_xml_string(*get_model_and_assets()) 77 | task = PlanarWalker(move_speed=_RUN_SPEED, 78 | forward=True, 79 | flip=True, 80 | random=random) 81 | environment_kwargs = environment_kwargs or {} 82 | return control.Environment(physics, 83 | task, 84 | time_limit=time_limit, 85 | control_timestep=_CONTROL_TIMESTEP, 86 | **environment_kwargs) 87 | 88 | 89 | class Physics(mujoco.Physics): 90 | """Physics simulation with additional features for the Walker domain.""" 91 | def torso_upright(self): 92 | """Returns projection from z-axes of torso to the z-axes of world.""" 93 | return self.named.data.xmat['torso', 'zz'] 94 | 95 | def torso_height(self): 96 | """Returns the height of the torso.""" 97 | return self.named.data.xpos['torso', 'z'] 98 | 99 | def horizontal_velocity(self): 100 | """Returns the horizontal velocity of the center-of-mass.""" 101 | return self.named.data.sensordata['torso_subtreelinvel'][0] 102 | 103 | def orientations(self): 104 | """Returns planar orientations of all bodies.""" 105 | return self.named.data.xmat[1:, ['xx', 'xz']].ravel() 106 | 107 | def angmomentum(self): 108 | """Returns the angular momentum of torso of the Cheetah about Y axis.""" 109 | return self.named.data.subtree_angmom['torso'][1] 110 | 111 | 112 | class PlanarWalker(base.Task): 113 | """A planar walker task.""" 114 | def __init__(self, move_speed, forward=True, flip=False, random=None): 115 | """Initializes an instance of `PlanarWalker`. 116 | 117 | Args: 118 | move_speed: A float. If this value is zero, reward is given simply for 119 | standing up. Otherwise this specifies a target horizontal velocity for 120 | the walking task. 121 | random: Optional, either a `numpy.random.RandomState` instance, an 122 | integer seed for creating a new `RandomState`, or None to select a seed 123 | automatically (default). 124 | """ 125 | self._move_speed = move_speed 126 | self._forward = 1 if forward else -1 127 | self._flip = flip 128 | super(PlanarWalker, self).__init__(random=random) 129 | 130 | def initialize_episode(self, physics): 131 | """Sets the state of the environment at the start of each episode. 132 | 133 | In 'standing' mode, use initial orientation and small velocities. 134 | In 'random' mode, randomize joint angles and let fall to the floor. 135 | 136 | Args: 137 | physics: An instance of `Physics`. 
138 | 139 | """ 140 | randomizers.randomize_limited_and_rotational_joints( 141 | physics, self.random) 142 | super(PlanarWalker, self).initialize_episode(physics) 143 | 144 | def get_observation(self, physics): 145 | """Returns an observation of body orientations, height and velocites.""" 146 | obs = collections.OrderedDict() 147 | obs['orientations'] = physics.orientations() 148 | obs['height'] = physics.torso_height() 149 | obs['velocity'] = physics.velocity() 150 | return obs 151 | 152 | def get_reward(self, physics): 153 | """Returns a reward to the agent.""" 154 | standing = rewards.tolerance(physics.torso_height(), 155 | bounds=(_STAND_HEIGHT, float('inf')), 156 | margin=_STAND_HEIGHT / 2) 157 | upright = (1 + physics.torso_upright()) / 2 158 | stand_reward = (3 * standing + upright) / 4 159 | 160 | if self._flip: 161 | move_reward = rewards.tolerance(self._forward * 162 | physics.angmomentum(), 163 | bounds=(_SPIN_SPEED, float('inf')), 164 | margin=_SPIN_SPEED, 165 | value_at_margin=0, 166 | sigmoid='linear') 167 | else: 168 | move_reward = rewards.tolerance( 169 | self._forward * physics.horizontal_velocity(), 170 | bounds=(self._move_speed, float('inf')), 171 | margin=self._move_speed / 2, 172 | value_at_margin=0.5, 173 | sigmoid='linear') 174 | 175 | return stand_reward * (5 * move_reward + 1) / 6 176 | -------------------------------------------------------------------------------- /core/custom_dmc_tasks/walker.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /core/data/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import jax.numpy as jnp 4 | from dm_env import specs 5 | 6 | from .replay_buffer import get_reverb_replay_components, Batch, ReverbReplay, IntraEpisodicBuffer 7 | from .replay_buffer_torch import make_replay_loader, ReplayBufferStorage, ReplayBuffer -------------------------------------------------------------------------------- /core/data/replay_buffer.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, List, Optional, NamedTuple, Dict, Any 2 | from collections import deque 3 | import dataclasses 4 | 5 | import reverb 6 | import numpy as np 7 | import dm_env 8 | from dm_env import specs 9 | from acme import adders, specs, types 10 | from acme.adders import reverb as adders_reverb 11 | from acme.datasets import reverb as datasets 12 | 13 | 14 | ################# 15 | ### From Acme ### 16 | ################# 17 | 18 | class Batch(NamedTuple): 19 | observation: np.ndarray 20 | action: np.ndarray 21 | reward: np.ndarray 22 | discount: np.ndarray 23 | next_observation: np.ndarray 24 | extras: Dict 25 | 26 | @dataclasses.dataclass 27 | class ReverbReplay: 28 | server: reverb.Server 29 | adder: adders.Adder 30 | data_iterator: Iterator[reverb.ReplaySample] 31 | client: Optional[reverb.Client] = None 32 | 33 | def make_replay_tables( 34 | environment_spec: specs.EnvironmentSpec, 35 | replay_table_name: str = 'replay buffer', 36 | max_replay_size: int = 2_000_000, 37 | min_replay_size: int = 100, 38 | extras_spec: types.NestedSpec = () 39 | ) -> List[reverb.Table]: 40 | """Creates reverb tables for the algorithm.""" 41 | return [reverb.Table( 42 | name=replay_table_name, 43 | sampler=reverb.selectors.Uniform(), 44 | remover=reverb.selectors.Fifo(), 45 | max_size=max_replay_size, 46 
| rate_limiter=reverb.rate_limiters.MinSize(min_replay_size), 47 | signature=adders_reverb.NStepTransitionAdder.signature( 48 | environment_spec, extras_spec))] 49 | 50 | def make_dataset_iterator( 51 | replay_client: reverb.Client, batch_size: int, 52 | prefetch_size: int = 4, replay_table_name: str = 'replay buffer', 53 | ) -> Iterator[reverb.ReplaySample]: 54 | """Creates a dataset iterator to use for learning.""" 55 | dataset = datasets.make_reverb_dataset( 56 | table=replay_table_name, 57 | server_address=replay_client.server_address, 58 | batch_size=batch_size, 59 | prefetch_size=prefetch_size) 60 | return dataset.as_numpy_iterator() 61 | 62 | def make_adder( 63 | replay_client: reverb.Client, 64 | n_step: int, discount: float, 65 | replay_table_name: str = 'replay buffer',) -> adders.Adder: 66 | """Creates an adder which handles observations.""" 67 | return adders_reverb.NStepTransitionAdder( 68 | priority_fns={replay_table_name: None}, 69 | client=replay_client, 70 | n_step=n_step, 71 | discount=discount 72 | ) 73 | 74 | def get_reverb_replay_components( 75 | environment_spec: specs.EnvironmentSpec, 76 | n_step: int, discount: float, batch_size: int, 77 | max_replay_size: int = 2_000_000, 78 | min_replay_size: int = 100, 79 | replay_table_name: str = 'replay buffer', 80 | extras_spec: Optional[types.NestedSpec] = () 81 | ) -> ReverbReplay: 82 | replay_table = make_replay_tables(environment_spec, 83 | replay_table_name, max_replay_size, 84 | min_replay_size=min_replay_size, extras_spec=extras_spec) 85 | server = reverb.Server(replay_table, port=None) 86 | address = f'localhost:{server.port}' 87 | client = reverb.Client(address) 88 | adder = make_adder(client, n_step, discount, replay_table_name) 89 | data_iterator = make_dataset_iterator( 90 | client, batch_size, replay_table_name=replay_table_name) 91 | return ReverbReplay( 92 | server, adder, data_iterator, client 93 | ) 94 | 95 | 96 | class IntraEpisodicBuffer: 97 | def __init__(self, maxlen: int = 1001, full_method: str = 'episodic') -> None: 98 | self.timesteps = deque(maxlen=maxlen) 99 | self.extras = deque(maxlen=maxlen) 100 | self._maxlen = maxlen 101 | self.full_method = full_method 102 | self._last_timestep = None 103 | 104 | def add(self, timestep: dm_env.TimeStep, extra: Dict[str, Any]): 105 | self.timesteps.append(timestep) 106 | self.extras.append(extra) 107 | self._last_timestep = timestep 108 | 109 | def reset(self): 110 | self.timesteps = deque(maxlen=self._maxlen) 111 | self.extras = deque(maxlen=self._maxlen) 112 | self._last_timestep = None 113 | 114 | def __len__(self) -> int: 115 | return len(self.timesteps) 116 | 117 | def is_full(self): 118 | if self.full_method == 'episodic': 119 | # buffer is not full when just initialized/resetted 120 | if self._last_timestep is None: 121 | return False 122 | return self._last_timestep.last() 123 | if self.full_method == 'step': 124 | return len(self.timesteps) == self._maxlen 125 | raise NotImplementedError 126 | -------------------------------------------------------------------------------- /core/data/replay_buffer_torch.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple, OrderedDict, List, Tuple, Dict, Union 2 | import datetime 3 | import io 4 | import random 5 | import traceback 6 | from collections import defaultdict 7 | import pathlib 8 | import functools 9 | 10 | import numpy as np 11 | import dm_env 12 | import torch 13 | from torch.utils.data import IterableDataset 14 | 15 | 16 | class 
Batch(NamedTuple): 17 | observation: Union[np.ndarray, List] 18 | action: Union[np.ndarray, List] 19 | reward: Union[np.ndarray, List] 20 | discount: Union[np.ndarray, List] 21 | next_observation: Union[np.ndarray, List] 22 | extras: Dict # List 23 | 24 | 25 | def compute_episode_len(episode): 26 | # subtract -1 because the dummy first transition 27 | return next(iter(episode.values())).shape[0] - 1 28 | 29 | 30 | def save_episode(episode, fn): 31 | with io.BytesIO() as bs: 32 | np.savez_compressed(bs, **episode) 33 | bs.seek(0) 34 | with fn.open('wb') as f: 35 | f.write(bs.read()) 36 | 37 | 38 | def _preload(replay_dir: pathlib.Path) -> Tuple[int, int]: 39 | """ 40 | returns the number of episode and transitions in the replay_dir, 41 | it assumes that each episode's name has the format {}_{}_{episode_len}.npz 42 | """ 43 | n_episodes, n_transitions = 0, 0 44 | for file in replay_dir.glob('*.npz'): 45 | _, _, episode_len = file.stem.split('_') 46 | n_episodes += 1 47 | n_transitions += int(episode_len) 48 | 49 | return n_episodes, n_transitions 50 | 51 | 52 | class ReplayBufferStorage: 53 | def __init__(self, 54 | data_specs, #: Tuple[specs, ...], 55 | meta_specs, #: Tuple[specs, ...], 56 | replay_dir: pathlib.Path = pathlib.Path.cwd() / 'buffer' 57 | ): 58 | """ 59 | data_specs: (obs, action , reward, discount) 60 | meta_specs: any extra e.g. (skill, mode) 61 | """ 62 | 63 | self._data_specs = data_specs 64 | self._meta_specs = meta_specs 65 | self._replay_dir = replay_dir 66 | replay_dir.mkdir(exist_ok=True) 67 | self._current_episode = defaultdict(list) 68 | self._n_episodes, self._n_transitions = _preload(replay_dir) 69 | 70 | def __len__(self): 71 | return self._n_transitions 72 | 73 | def add(self, 74 | time_step: dm_env.TimeStep, 75 | meta: OrderedDict 76 | ): 77 | self._add_meta(meta=meta) 78 | self._add_time_step(time_step=time_step) 79 | if time_step.last(): 80 | self._store_episode() 81 | 82 | def _add_meta(self, meta: OrderedDict): 83 | for spec in self._meta_specs: 84 | value = meta[spec.name] 85 | if np.isscalar(value): 86 | value = np.full(spec.shape, value, spec.dtype) 87 | self._current_episode[spec.name].append(value) 88 | # for key, value in meta.items(): 89 | # self._current_episode[key].append(value) 90 | 91 | def _add_time_step(self, time_step: dm_env.TimeStep): 92 | for spec in self._data_specs: 93 | value = time_step[spec.name] 94 | if np.isscalar(value): 95 | # convert it to a numpy array as shape given by the data specs (reward & discount) 96 | value = np.full(spec.shape, value, spec.dtype) 97 | assert spec.shape == value.shape and spec.dtype == value.dtype 98 | self._current_episode[spec.name].append(value) 99 | 100 | def _store_episode(self): 101 | episode = dict() 102 | 103 | # datas to save as numpy array 104 | for spec in self._data_specs: 105 | value = self._current_episode[spec.name] 106 | episode[spec.name] = np.array(value, spec.dtype) 107 | 108 | # metas to save as numpy array 109 | for spec in self._meta_specs: 110 | value = self._current_episode[spec.name] 111 | episode[spec.name] = np.array(value, spec.dtype) 112 | 113 | # reset current episode content 114 | self._current_episode = defaultdict(list) 115 | 116 | # save episode 117 | eps_idx = self._n_episodes 118 | eps_len = compute_episode_len(episode) 119 | self._n_episodes += 1 120 | self._n_transitions += eps_len 121 | ts = datetime.datetime.now().strftime('%Y%m%dT%H%M%S') 122 | eps_fn = f'{ts}_{eps_idx}_{eps_len}.npz' 123 | save_episode(episode, self._replay_dir / eps_fn) 124 | 125 | 126 | def 
load_episode(fn): 127 | with fn.open('rb') as f: 128 | episode = np.load(f) 129 | episode = {k: episode[k] for k in episode.keys()} 130 | return episode 131 | 132 | 133 | class ReplayBuffer(IterableDataset): 134 | 135 | def __init__(self, 136 | storage: ReplayBufferStorage, 137 | max_size: int, 138 | num_workers: int, 139 | nstep: int, 140 | discount: int, 141 | fetch_every: int, 142 | save_snapshot: bool 143 | ): 144 | self._storage = storage 145 | self._size = 0 146 | self._max_size = max_size 147 | self._num_workers = max(1, num_workers) 148 | self._episode_fns = [] 149 | self._episodes = dict() 150 | self._nstep = nstep 151 | self._discount = discount 152 | self._fetch_every = fetch_every 153 | self._samples_since_last_fetch = fetch_every 154 | self._save_snapshot = save_snapshot 155 | 156 | def __len__(self): 157 | return len(self._storage) 158 | 159 | def _sample_episode(self): 160 | """ Sample a single episode """ 161 | eps_fn = random.choice(self._episode_fns) 162 | return self._episodes[eps_fn] 163 | 164 | def _store_episode(self, eps_fn): 165 | """ 166 | load an episode in memory with dict self._episodes 167 | and self._episode_fns contains the sorted keys 168 | and deletes the file 169 | """ 170 | try: 171 | episode = load_episode(eps_fn) 172 | except: 173 | return False 174 | eps_len = compute_episode_len(episode) 175 | # remove old episodes if max size is reached 176 | while eps_len + self._size > self._max_size: 177 | early_eps_fn = self._episode_fns.pop(0) 178 | early_eps = self._episodes.pop(early_eps_fn) 179 | self._size -= compute_episode_len(early_eps) 180 | early_eps_fn.unlink(missing_ok=True) 181 | # store the episode 182 | self._episode_fns.append(eps_fn) 183 | self._episode_fns.sort() 184 | self._episodes[eps_fn] = episode 185 | self._size += eps_len 186 | 187 | # delete episode if save_snapshot false 188 | if not self._save_snapshot: 189 | eps_fn.unlink(missing_ok=True) 190 | return True 191 | 192 | def _try_fetch(self): 193 | """ 194 | Fetch all episodes, divided between workers 195 | """ 196 | if self._samples_since_last_fetch < self._fetch_every: 197 | return 198 | self._samples_since_last_fetch = 0 199 | try: 200 | worker_id = torch.utils.data.get_worker_info().id 201 | except: 202 | worker_id = 0 203 | 204 | # last created to first created 205 | eps_fns = sorted(self._storage._replay_dir.glob('*.npz'), reverse=True) 206 | fetched_size = 0 207 | # load all episodes 208 | for eps_fn in eps_fns: 209 | eps_idx, eps_len = [int(x) for x in eps_fn.stem.split('_')[1:]] 210 | # each worker load an episode 211 | if eps_idx % self._num_workers != worker_id: 212 | continue 213 | if eps_fn in self._episodes.keys(): 214 | break 215 | if fetched_size + eps_len > self._max_size: 216 | break 217 | fetched_size += eps_len 218 | # stop if fail to load episode 219 | if not self._store_episode(eps_fn): 220 | break 221 | 222 | def _sample(self 223 | ): # -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, ...]: 224 | try: 225 | self._try_fetch() 226 | except: 227 | traceback.print_exc() 228 | self._samples_since_last_fetch += 1 229 | episode = self._sample_episode() 230 | # add +1 for the first dummy transition so only starts above 1 and below max-nstep 231 | # idx to take inside the episode 232 | idx = np.random.randint(0, compute_episode_len(episode) - self._nstep + 1) + 1 233 | meta = dict() 234 | # meta = [] 235 | for spec in self._storage._meta_specs: 236 | meta[spec.name] = episode[spec.name][idx - 1] 237 | # meta.append(episode[spec.name][idx - 1]) 238 | obs 
= episode['observation'][idx - 1] # account for first dummy transition 239 | action = episode['action'][idx] # on first dummy transition action is set to 0 240 | next_obs = episode['observation'][idx + self._nstep - 1]# account for first dummy transition 241 | reward = np.zeros_like(episode['reward'][idx]) 242 | discount = np.ones_like(episode['discount'][idx]) 243 | for i in range(self._nstep): 244 | step_reward = episode['reward'][idx + i] 245 | reward += discount * step_reward 246 | discount *= episode['discount'][idx + i] * self._discount 247 | # noinspection PyRedundantParentheses 248 | data = dict( 249 | observation=obs, 250 | action=action, 251 | reward=reward, 252 | discount=discount, 253 | next_observation=next_obs, 254 | ) 255 | data.update(meta) 256 | return data 257 | # return (obs, action, reward, discount, next_obs, *meta) 258 | 259 | def __iter__(self): 260 | while True: 261 | yield self._sample() 262 | 263 | 264 | class RepeatSampler(object): 265 | """ Sampler that repeats forever. 266 | Args: 267 | sampler (Sampler) 268 | """ 269 | 270 | def __init__(self, sampler): 271 | self.sampler = sampler 272 | 273 | def __iter__(self): 274 | while True: 275 | yield from iter(self.sampler) 276 | 277 | 278 | def numpy_collate(batch: List[Dict], meta_specs): 279 | res = defaultdict(list) 280 | for b in batch: 281 | for k, v in b.items(): 282 | res[k].append(v) 283 | extras = dict() 284 | for spec in meta_specs: 285 | extras[spec.name] = np.stack(res[spec.name]) 286 | return Batch( 287 | observation=np.stack(res['observation']), 288 | action=np.stack(res['action']), 289 | reward=np.stack(res['reward']), 290 | discount=np.stack(res['discount']), 291 | next_observation=np.stack(res['next_observation']), 292 | extras=extras 293 | ) 294 | 295 | def numpy_collate_mode(batch: List[Dict], meta_specs): 296 | res_mode0 = defaultdict(list) 297 | res_mode1 = defaultdict(list) 298 | for b in batch: 299 | 300 | if b['mode'] == 0: 301 | res_mode0['skill'].append(b['skill']) 302 | res_mode0['observation'].append(b['observation']) 303 | res_mode0['next_observation'].append(b['next_observation']) 304 | res_mode0['action'].append(b['action']) 305 | res_mode0['reward'].append(b['reward']) 306 | res_mode0['discount'].append(b['discount']) 307 | elif b['mode'] == 1: 308 | res_mode1['skill'].append(b['skill']) 309 | res_mode1['observation'].append(b['observation']) 310 | res_mode1['next_observation'].append(b['next_observation']) 311 | res_mode1['action'].append(b['action']) 312 | res_mode1['reward'].append(b['reward']) 313 | # res_mode1['discount'].append(b['discount'] * 0.25) 314 | res_mode1['discount'].append(b['discount']) 315 | 316 | extras = dict() 317 | # for spec in meta_specs: 318 | extras['skill'] = [np.stack(res_mode0['skill']), np.stack(res_mode1['skill'])] 319 | # extras['skill'] = [] #[np.stack(res_mode0[spec.name]), np.stack(res_mode1[spec.name])] 320 | # if len(res_mode0['skill']): 321 | # extras['skill'].append(np.stack(res_mode0['skill'])) 322 | # if len(res_mode1['skill']): 323 | # extras['skill'].append(np.stack(res_mode1['skill'])) 324 | 325 | return Batch( 326 | observation=[np.stack(res_mode0['observation']), np.stack(res_mode1['observation'])], 327 | action=[np.stack(res_mode0['action']), np.stack(res_mode1['action'])], 328 | reward=[np.stack(res_mode0['reward']), np.stack(res_mode1['reward'])], 329 | discount=[np.stack(res_mode0['discount']), np.stack(res_mode1['discount'])], 330 | next_observation=[np.stack(res_mode0['next_observation']), np.stack(res_mode1['next_observation'])], 
331 | extras=extras 332 | ) 333 | 334 | 335 | def _worker_init_fn(worker_id): 336 | seed = np.random.get_state()[1][0] + worker_id 337 | np.random.seed(seed) 338 | random.seed(seed) 339 | 340 | 341 | def make_replay_loader(storage, 342 | max_size, 343 | batch_size, 344 | num_workers, 345 | nstep, 346 | discount, 347 | meta_specs, 348 | save_snapshot: bool = False): 349 | 350 | 351 | if 'mode' in [spec.name for spec in meta_specs]: 352 | # collate_fct = functools.partial(numpy_collate, meta_specs=meta_specs) 353 | collate_fct = functools.partial(numpy_collate_mode, meta_specs=meta_specs) 354 | else: 355 | collate_fct = functools.partial(numpy_collate, meta_specs=meta_specs) 356 | 357 | max_size_per_worker = max_size // max(1, num_workers) 358 | iterable = ReplayBuffer(storage, 359 | max_size_per_worker, 360 | num_workers, 361 | nstep, 362 | discount, 363 | fetch_every=1000, 364 | save_snapshot=save_snapshot) 365 | 366 | loader = torch.utils.data.DataLoader(iterable, 367 | batch_size=batch_size, 368 | num_workers=num_workers, 369 | pin_memory=True, 370 | worker_init_fn=_worker_init_fn, 371 | collate_fn=collate_fct 372 | ) 373 | return loader 374 | -------------------------------------------------------------------------------- /core/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple 2 | 3 | import dm_env 4 | 5 | from .dmc import make as make_dmc_env 6 | 7 | 8 | def make_env(action_type: str, cfg: NamedTuple, seed: int) -> dm_env.Environment: 9 | if action_type == 'continuous': 10 | return make_dmc_env(cfg.task, cfg.obs_type, cfg.frame_stack, cfg.action_repeat, seed) 11 | raise NotImplementedError 12 | -------------------------------------------------------------------------------- /core/envs/dmc.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict, deque 2 | from typing import Any, NamedTuple 3 | 4 | import dm_env 5 | import numpy as np 6 | from dm_control import suite 7 | from dm_control.suite.wrappers import action_scale, pixels 8 | from dm_env import StepType, specs 9 | 10 | from core import custom_dmc_tasks as cdmc 11 | from core.envs.wrappers import InformativeTimestepWrapper, DMCTimeWrapper 12 | 13 | 14 | class ExtendedTimeStep(NamedTuple): 15 | step_type: Any 16 | reward: Any 17 | discount: Any 18 | observation: Any 19 | action: Any 20 | 21 | def first(self): 22 | return self.step_type == StepType.FIRST 23 | 24 | def mid(self): 25 | return self.step_type == StepType.MID 26 | 27 | def last(self): 28 | return self.step_type == StepType.LAST 29 | 30 | def __getitem__(self, attr): 31 | return getattr(self, attr) 32 | 33 | 34 | class FlattenJacoObservationWrapper(dm_env.Environment): 35 | def __init__(self, env): 36 | self._env = env 37 | self._obs_spec = OrderedDict() 38 | wrapped_obs_spec = env.observation_spec().copy() 39 | if 'front_close' in wrapped_obs_spec: 40 | spec = wrapped_obs_spec['front_close'] 41 | # drop batch dim 42 | self._obs_spec['pixels'] = specs.BoundedArray(shape=spec.shape[1:], 43 | dtype=spec.dtype, 44 | minimum=spec.minimum, 45 | maximum=spec.maximum, 46 | name='pixels') 47 | wrapped_obs_spec.pop('front_close') 48 | 49 | for key, spec in wrapped_obs_spec.items(): 50 | assert spec.dtype == np.float64 51 | assert type(spec) == specs.Array 52 | dim = np.sum( 53 | np.fromiter((np.int(np.prod(spec.shape)) 54 | for spec in wrapped_obs_spec.values()), np.int32)) 55 | 56 | self._obs_spec['observations'] = 
specs.Array(shape=(dim,), 57 | dtype=np.float32, 58 | name='observations') 59 | 60 | def _transform_observation(self, time_step): 61 | obs = OrderedDict() 62 | 63 | if 'front_close' in time_step.observation: 64 | pixels = time_step.observation['front_close'] 65 | time_step.observation.pop('front_close') 66 | pixels = np.squeeze(pixels) 67 | obs['pixels'] = pixels 68 | 69 | features = [] 70 | for feature in time_step.observation.values(): 71 | features.append(feature.ravel()) 72 | obs['observations'] = np.concatenate(features, axis=0) 73 | return time_step._replace(observation=obs) 74 | 75 | def reset(self): 76 | time_step = self._env.reset() 77 | return self._transform_observation(time_step) 78 | 79 | def step(self, action): 80 | time_step = self._env.step(action) 81 | return self._transform_observation(time_step) 82 | 83 | def observation_spec(self): 84 | return self._obs_spec 85 | 86 | def action_spec(self): 87 | return self._env.action_spec() 88 | 89 | def __getattr__(self, name): 90 | return getattr(self._env, name) 91 | 92 | 93 | class ActionRepeatWrapper(dm_env.Environment): 94 | def __init__(self, env, num_repeats): 95 | self._env = env 96 | self._num_repeats = num_repeats 97 | 98 | def step(self, action): 99 | reward = 0.0 100 | discount = 1.0 101 | for i in range(self._num_repeats): 102 | time_step = self._env.step(action) 103 | reward += (time_step.reward or 0.0) * discount 104 | discount *= time_step.discount 105 | if time_step.last(): 106 | break 107 | 108 | return time_step._replace(reward=reward, discount=discount) 109 | 110 | def observation_spec(self): 111 | return self._env.observation_spec() 112 | 113 | def action_spec(self): 114 | return self._env.action_spec() 115 | 116 | def reset(self): 117 | return self._env.reset() 118 | 119 | def __getattr__(self, name): 120 | return getattr(self._env, name) 121 | 122 | 123 | class FrameStackWrapper(dm_env.Environment): 124 | def __init__(self, env, num_frames, pixels_key='pixels'): 125 | self._env = env 126 | self._num_frames = num_frames 127 | self._frames = deque([], maxlen=num_frames) 128 | self._pixels_key = pixels_key 129 | 130 | wrapped_obs_spec = env.observation_spec() 131 | assert pixels_key in wrapped_obs_spec 132 | 133 | pixels_shape = wrapped_obs_spec[pixels_key].shape 134 | # remove batch dim 135 | if len(pixels_shape) == 4: 136 | pixels_shape = pixels_shape[1:] 137 | self._obs_spec = specs.BoundedArray(shape=np.concatenate( 138 | [[pixels_shape[2] * num_frames], pixels_shape[:2]], axis=0), 139 | dtype=np.uint8, 140 | minimum=0, 141 | maximum=255, 142 | name='observation') 143 | 144 | def _transform_observation(self, time_step): 145 | assert len(self._frames) == self._num_frames 146 | obs = np.concatenate(list(self._frames), axis=0) 147 | return time_step._replace(observation=obs) 148 | 149 | def _extract_pixels(self, time_step): 150 | pixels = time_step.observation[self._pixels_key] 151 | # remove batch dim 152 | if len(pixels.shape) == 4: 153 | pixels = pixels[0] 154 | return pixels.transpose(2, 0, 1).copy() 155 | 156 | def reset(self): 157 | time_step = self._env.reset() 158 | pixels = self._extract_pixels(time_step) 159 | for _ in range(self._num_frames): 160 | self._frames.append(pixels) 161 | return self._transform_observation(time_step) 162 | 163 | def step(self, action): 164 | time_step = self._env.step(action) 165 | pixels = self._extract_pixels(time_step) 166 | self._frames.append(pixels) 167 | return self._transform_observation(time_step) 168 | 169 | def observation_spec(self): 170 | return self._obs_spec 
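    # FrameStackWrapper keeps the last `num_frames` pixel observations in a deque and
    # concatenates them along the channel axis: a wrapped env returning (H, W, C) frames
    # exposes observations of shape (C * num_frames, H, W) after the (2, 0, 1) transpose
    # in _extract_pixels. For example, 84x84 RGB frames with num_frames=3 yield a
    # (9, 84, 84) observation; on reset() the first frame is replicated num_frames times
    # so the stack is always full.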
171 | 172 | def action_spec(self): 173 | return self._env.action_spec() 174 | 175 | def __getattr__(self, name): 176 | return getattr(self._env, name) 177 | 178 | 179 | class ActionDTypeWrapper(dm_env.Environment): 180 | def __init__(self, env, dtype): 181 | self._env = env 182 | wrapped_action_spec = env.action_spec() 183 | self._action_spec = specs.BoundedArray(wrapped_action_spec.shape, 184 | dtype, 185 | wrapped_action_spec.minimum, 186 | wrapped_action_spec.maximum, 187 | 'action') 188 | 189 | def step(self, action): 190 | action = action.astype(self._env.action_spec().dtype) 191 | return self._env.step(action) 192 | 193 | def observation_spec(self): 194 | return self._env.observation_spec() 195 | 196 | def action_spec(self): 197 | return self._action_spec 198 | 199 | def reset(self): 200 | return self._env.reset() 201 | 202 | def __getattr__(self, name): 203 | return getattr(self._env, name) 204 | 205 | 206 | class ObservationDTypeWrapper(dm_env.Environment): 207 | def __init__(self, env, dtype): 208 | self._env = env 209 | self._dtype = dtype 210 | wrapped_obs_spec = env.observation_spec()['observations'] 211 | self._obs_spec = specs.Array(wrapped_obs_spec.shape, dtype, 212 | 'observation') 213 | 214 | def _transform_observation(self, time_step): 215 | obs = time_step.observation['observations'].astype(self._dtype) 216 | return time_step._replace(observation=obs) 217 | 218 | def reset(self): 219 | time_step = self._env.reset() 220 | return self._transform_observation(time_step) 221 | 222 | def step(self, action): 223 | time_step = self._env.step(action) 224 | return self._transform_observation(time_step) 225 | 226 | def observation_spec(self): 227 | return self._obs_spec 228 | 229 | def action_spec(self): 230 | return self._env.action_spec() 231 | 232 | def __getattr__(self, name): 233 | return getattr(self._env, name) 234 | 235 | 236 | class ExtendedTimeStepWrapper(dm_env.Environment): 237 | def __init__(self, env): 238 | self._env = env 239 | 240 | def reset(self): 241 | time_step = self._env.reset() 242 | return self._augment_time_step(time_step) 243 | 244 | def step(self, action): 245 | time_step = self._env.step(action) 246 | return self._augment_time_step(time_step, action) 247 | 248 | def _augment_time_step(self, time_step, action=None): 249 | if action is None: 250 | action_spec = self.action_spec() 251 | action = np.zeros(action_spec.shape, dtype=action_spec.dtype) 252 | return ExtendedTimeStep(observation=time_step.observation, 253 | step_type=time_step.step_type, 254 | action=action, 255 | reward=time_step.reward or 0.0, 256 | discount=time_step.discount or 1.0) 257 | 258 | def observation_spec(self): 259 | return self._env.observation_spec() 260 | 261 | def action_spec(self): 262 | return self._env.action_spec() 263 | 264 | def __getattr__(self, name): 265 | return getattr(self._env, name) 266 | 267 | 268 | def _make_jaco(obs_type, domain, task, frame_stack, action_repeat, seed): 269 | env = cdmc.make_jaco(task, obs_type, seed) 270 | env = ActionDTypeWrapper(env, np.float32) 271 | env = ActionRepeatWrapper(env, action_repeat) 272 | env = FlattenJacoObservationWrapper(env) 273 | return env 274 | 275 | 276 | def _make_dmc(obs_type, domain, task, frame_stack, action_repeat, seed): 277 | visualize_reward = False 278 | if (domain, task) in suite.ALL_TASKS: 279 | env = suite.load(domain, 280 | task, 281 | task_kwargs=dict(random=seed), 282 | environment_kwargs=dict(flat_observation=True), 283 | visualize_reward=visualize_reward) 284 | else: 285 | env = cdmc.make(domain, 286 | 
task, 287 | task_kwargs=dict(random=seed), 288 | environment_kwargs=dict(flat_observation=True), 289 | visualize_reward=visualize_reward) 290 | 291 | env = ActionDTypeWrapper(env, np.float32) 292 | env = ActionRepeatWrapper(env, action_repeat) 293 | if obs_type == 'pixels': 294 | # zoom in camera for quadruped 295 | camera_id = dict(quadruped=2).get(domain, 0) 296 | render_kwargs = dict(height=84, width=84, camera_id=camera_id) 297 | env = pixels.Wrapper(env, 298 | pixels_only=True, 299 | render_kwargs=render_kwargs) 300 | return env 301 | 302 | 303 | def make(name, obs_type, frame_stack, action_repeat, seed): 304 | assert obs_type in ['states', 'pixels'] 305 | domain, task = name.split('_', 1) 306 | domain = dict(cup='ball_in_cup').get(domain, domain) 307 | 308 | make_fn = _make_jaco if domain == 'jaco' else _make_dmc 309 | env = make_fn(obs_type, domain, task, frame_stack, action_repeat, seed) 310 | 311 | if obs_type == 'pixels': 312 | env = FrameStackWrapper(env, frame_stack) 313 | else: 314 | env = ObservationDTypeWrapper(env, np.float32) 315 | 316 | env = action_scale.Wrapper(env, minimum=-1.0, maximum=+1.0) 317 | env = ExtendedTimeStepWrapper(env) 318 | return InformativeTimestepWrapper( 319 | DMCTimeWrapper( 320 | env, 321 | ) 322 | ) 323 | -------------------------------------------------------------------------------- /core/envs/dmc_benchmark.py: -------------------------------------------------------------------------------- 1 | from core.custom_dmc_tasks import quadruped, jaco, cheetah, walker, hopper 2 | 3 | 4 | def make(domain, task, 5 | task_kwargs=None, 6 | environment_kwargs=None, 7 | visualize_reward=False): 8 | if domain == 'cheetah': 9 | return cheetah.make(task, 10 | task_kwargs=task_kwargs, 11 | environment_kwargs=environment_kwargs, 12 | visualize_reward=visualize_reward) 13 | elif domain == 'walker': 14 | return walker.make(task, 15 | task_kwargs=task_kwargs, 16 | environment_kwargs=environment_kwargs, 17 | visualize_reward=visualize_reward) 18 | elif domain == 'hopper': 19 | return hopper.make(task, 20 | task_kwargs=task_kwargs, 21 | environment_kwargs=environment_kwargs, 22 | visualize_reward=visualize_reward) 23 | elif domain == 'quadruped': 24 | return quadruped.make(task, 25 | task_kwargs=task_kwargs, 26 | environment_kwargs=environment_kwargs, 27 | visualize_reward=visualize_reward) 28 | else: 29 | raise NotImplementedError(f'{domain} not found') 30 | 31 | 32 | 33 | 34 | def make_jaco(task, obs_type, seed): 35 | return jaco.make(task, obs_type, seed) -------------------------------------------------------------------------------- /core/envs/wrappers.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, NamedTuple 2 | from collections import deque 3 | 4 | import dm_env 5 | from dm_env import StepType, TimeStep 6 | from jax import numpy as jnp 7 | import numpy as np 8 | 9 | 10 | class InformativeTimeStep(NamedTuple): 11 | step_type: StepType 12 | reward: float 13 | discount: float 14 | observation: jnp.ndarray 15 | action: jnp.ndarray 16 | mode: int 17 | current_timestep: int 18 | max_timestep: int 19 | 20 | def first(self) -> bool: 21 | return self.step_type == StepType.FIRST 22 | 23 | def mid(self) -> bool: 24 | return self.step_type == StepType.MID 25 | 26 | def last(self) -> bool: 27 | return self.step_type == StepType.LAST 28 | 29 | def __getitem__(self, attr): 30 | return getattr(self, attr) 31 | 32 | 33 | def timestep2informative_timestep( 34 | timestep: TimeStep, 35 | action: Optional[jnp.ndarray] = 
None, 36 | mode: Optional[int] = None, 37 | current_timestep: Optional[int] = None, 38 | max_timestep: Optional[int] = None,) -> InformativeTimeStep: 39 | return InformativeTimeStep( 40 | step_type=timestep.step_type, 41 | reward=timestep.reward, 42 | discount=timestep.discount, 43 | observation=timestep.observation, 44 | action=action, 45 | mode=mode, 46 | current_timestep=current_timestep, 47 | max_timestep=max_timestep, 48 | ) 49 | 50 | 51 | class Wrapper(dm_env.Environment): 52 | def __init__(self, env: dm_env.Environment): 53 | self._env = env 54 | # inherent some attributes from env, like time counter, etc 55 | for attr, val in vars(self._env).items(): 56 | if attr not in vars(self): 57 | setattr(self, attr, val) 58 | 59 | def action_spec(self): 60 | return self._env.action_spec() 61 | 62 | @property 63 | def timestep(self): 64 | return self._env._timestep 65 | 66 | @property 67 | def max_timestep(self): 68 | return self._env.max_timestep 69 | 70 | def reset(self): 71 | return self._env.reset() 72 | 73 | def observation_spec(self): 74 | return self._env.observation_spec() 75 | 76 | def step(self, action: np.ndarray) -> dm_env.TimeStep: 77 | return self._env.step(action) 78 | 79 | def __getattr__(self, name): 80 | return getattr(self._env, name) 81 | 82 | 83 | class FrameStacker(Wrapper): 84 | def __init__(self, env: dm_env.Environment, frame_stack: int = 3): 85 | super().__init__(env) 86 | self._observation = deque(maxlen=frame_stack) 87 | self.n_stacks = frame_stack 88 | 89 | def observation_spec(self): 90 | single_observation_spec = self._env.observation_spec() 91 | new_shape = list(single_observation_spec.shape) 92 | new_shape[self._env._channel_axis] = new_shape[self._env._channel_axis] * self.n_stacks 93 | return dm_env.specs.Array( 94 | shape=tuple(new_shape), 95 | dtype=single_observation_spec.dtype, 96 | name=single_observation_spec.name 97 | ) 98 | 99 | def reset(self,) -> dm_env.TimeStep: 100 | timestep = self._env.reset() 101 | # stack n_stacks init frames for first observation 102 | for _ in range(self.n_stacks): 103 | self._observation.append(timestep.observation) 104 | return timestep._replace( 105 | observation=np.concatenate(self._observation, axis=self._env._channel_axis)) 106 | 107 | def step(self, action: np.ndarray) -> dm_env.TimeStep: 108 | timestep = self._env.step(action) 109 | self._observation.append(timestep.observation) 110 | return timestep._replace( 111 | observation=np.concatenate(self._observation, axis=self._env._channel_axis)) 112 | 113 | 114 | class ActionRepeater(Wrapper): 115 | def __init__(self, env: dm_env.Environment, nrepeats: int = 3): 116 | super().__init__(env) 117 | self._nrepeats = nrepeats 118 | 119 | def reset(self,) -> dm_env.TimeStep: 120 | return self._env.reset() 121 | 122 | def step(self, action: np.ndarray) -> dm_env.TimeStep: 123 | for _ in range(self._nrepeats): 124 | timestep = self._env.step(action) 125 | return timestep 126 | 127 | 128 | class InformativeTimestepWrapper(Wrapper): 129 | def __init__(self, env: dm_env.Environment): 130 | super().__init__(env) 131 | 132 | def reset(self,) -> InformativeTimeStep: 133 | timestep = self._env.reset() 134 | action_spec = self.action_spec() 135 | action = np.zeros(action_spec.shape, dtype=action_spec.dtype) 136 | return timestep2informative_timestep( 137 | # this namedtuple contains obs, reward, etc. 
138 | timestep, 139 | action=action, 140 | # this is the time spent in this episode 141 | current_timestep=self._env.timestep, 142 | max_timestep=self._env.max_timestep, 143 | ) 144 | 145 | def step(self, action: np.ndarray) -> InformativeTimeStep: 146 | timestep = self._env.step(action) 147 | return timestep2informative_timestep( 148 | timestep, 149 | action=action, 150 | current_timestep=self._env.timestep, 151 | max_timestep=self._env.max_timestep, 152 | ) 153 | 154 | 155 | class RewardScaler(Wrapper): 156 | def __init__(self, env: dm_env.Environment, reward_scale: float): 157 | super().__init__(env) 158 | self._reward_scale = reward_scale 159 | 160 | def step(self, action: np.ndarray) -> dm_env.TimeStep: 161 | timestep = self._env.step(action) 162 | return dm_env.TimeStep( 163 | step_type=timestep.step_type, reward=timestep.reward * self._reward_scale, 164 | discount=timestep.discount, observation=timestep.observation 165 | ) 166 | 167 | 168 | class DMCTimeWrapper(Wrapper): 169 | def __init__(self, env: dm_env.Environment,): 170 | super().__init__(env) 171 | self._env = env 172 | self._timestep = 0 173 | self.action_shape = self._env.action_spec().shape 174 | 175 | @property 176 | def max_timestep(self,) -> int: 177 | # last step 178 | if hasattr(self._env, '_time_limit'): 179 | return self._env._time_limit / self._env._task.control_timestep 180 | if hasattr(self._env, '_step_limit'): 181 | return self._env._step_limit 182 | 183 | @property 184 | def timestep(self,) -> int: 185 | # current in the episode 186 | return self._timestep 187 | 188 | def step(self, action: np.ndarray) -> dm_env.TimeStep: 189 | self._timestep += 1 190 | return self._env.step(action) 191 | 192 | def reset(self,) -> dm_env.TimeStep: 193 | self._timestep = 0 194 | return self._env.reset() 195 | -------------------------------------------------------------------------------- /core/exp_utils/__init__.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import numpy as np 4 | 5 | from .checkpointing import Checkpointer 6 | from .video import VideoRecorder, TrainVideoRecorder 7 | from .loggers import MetricLogger, log_params_to_wandb, LogParamsEvery, Timer, Until, Every, dict_to_header 8 | 9 | 10 | def set_seed(seed): 11 | torch.manual_seed(seed) 12 | # if torch.cuda.is_available(): 13 | # torch.cuda.manual_seed_all(seed) 14 | np.random.seed(seed) 15 | random.seed(seed) 16 | -------------------------------------------------------------------------------- /core/exp_utils/checkpointing.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from typing import Any, Mapping, Text 4 | from functools import partial 5 | 6 | import dill 7 | import jax 8 | import jax.numpy as jnp 9 | from iopath.common.file_io import PathManager 10 | 11 | # from core.misc.utils import broadcast 12 | 13 | logger = logging.getLogger(__name__) 14 | path_manager = PathManager() 15 | 16 | def tag_last_checkpoint(save_dir: str, 17 | last_filename_basename: str) -> None: 18 | """ save name of the last checkpoint in the file `last_checkpoint` """ 19 | save_file = os.path.join(save_dir, "last_checkpoint") 20 | with path_manager.open(save_file, "w") as f: 21 | f.write(last_filename_basename) 22 | 23 | def save_state(save_dir: str, 24 | name: str, 25 | state: Mapping[Text, jnp.ndarray], 26 | step: int, 27 | rng, 28 | **kwargs: Any) -> None: 29 | n_devices = jax.local_device_count() 30 | if jax.process_index() != 0: # 
only checkpoint the first worker 31 | return 32 | checkpoint_data = dict( 33 | # state=state, 34 | state= jax.tree_map( 35 | lambda x: jax.device_get(x[0]) if n_devices > 1 else jax.device_get(x), state), 36 | step=step, 37 | rng=rng 38 | ) 39 | checkpoint_data.update(kwargs) 40 | basename = "{}.pth".format(name) 41 | save_file = os.path.join(save_dir, basename) 42 | assert os.path.basename(save_file) == basename, basename 43 | logger.info("Saving checkpoint to {}".format(save_file)) 44 | with path_manager.open(save_file, "wb") as f: 45 | dill.dump(checkpoint_data, f) 46 | # tag it for auto resuming 47 | tag_last_checkpoint( 48 | save_dir=save_dir, 49 | last_filename_basename=basename, 50 | ) 51 | 52 | def has_last_checkpoint(save_dir:str) -> bool: 53 | save_dir = os.path.join(save_dir, "last_checkpoint") 54 | return path_manager.exists(save_dir) 55 | 56 | def get_last_checkpoint(save_dir: str) -> str: 57 | save_file = os.path.join(save_dir, "last_checkpoint") 58 | try: 59 | with path_manager.open(save_file, "r") as f: 60 | last_saved = f.read().strip() 61 | except IOError: 62 | # if file doesn't exist, maybe because it has just been 63 | # deleted by a separate process 64 | return "" 65 | return os.path.join(save_dir, last_saved) 66 | 67 | def resume_or_load(path: str, save_dir, *, resume: bool = False): 68 | if resume and has_last_checkpoint(save_dir): 69 | path = get_last_checkpoint(save_dir) 70 | return load_checkpoint(path) 71 | else: 72 | return load_checkpoint(path) 73 | 74 | def load_checkpoint(path: str) -> Mapping[str, Any]: 75 | """ 76 | :param path: 77 | :return: empty dict if checkpoint doesn't exist 78 | """ 79 | if not path: 80 | logger.info("No checkpoint given.") 81 | return dict() 82 | 83 | if not os.path.isfile(path): 84 | path = path_manager.get_local_path(path) 85 | assert os.path.isfile(path), "Checkpoint {} not found!".format(path) 86 | 87 | with path_manager.open(path, 'rb') as checkpoint_file: 88 | checkpoint = dill.load(checkpoint_file) 89 | logger.info('Loading checkpoint from %s', checkpoint_file) 90 | 91 | return checkpoint 92 | 93 | class Checkpointer: 94 | def __init__(self, 95 | save_dir: str = "checkpoints", 96 | ): 97 | self.save_dir = save_dir 98 | os.makedirs(save_dir, exist_ok=True) 99 | self.save_state = partial(save_state, save_dir=save_dir) 100 | self.load_checkpoint = load_checkpoint 101 | self.resume_or_load = partial(resume_or_load, save_dir=save_dir) 102 | -------------------------------------------------------------------------------- /core/exp_utils/loggers.py: -------------------------------------------------------------------------------- 1 | import random 2 | import timeit 3 | import time 4 | import contextlib 5 | import logging 6 | from collections import defaultdict 7 | 8 | import haiku as hk 9 | import csv 10 | import torch 11 | import numpy as np 12 | import jax.numpy as jnp 13 | import wandb 14 | from pathlib import Path 15 | 16 | #TODO remove those 17 | def log_params_to_wandb(params: hk.Params, step: int): 18 | if params: 19 | for module in sorted(params): 20 | if 'w' in params[module]: 21 | wandb.log({ 22 | f'{module}/w': wandb.Histogram(params[module]['w']) 23 | }, step=step) 24 | if 'b' in params[module]: 25 | wandb.log({ 26 | f'{module}/b': wandb.Histogram(params[module]['b']) 27 | }, step=step) 28 | 29 | class LogParamsEvery: 30 | def __init__(self, every, action_repeat=1): 31 | self._every = every 32 | self._action_repeat = action_repeat 33 | 34 | def __call__(self, params: hk.Params, step): 35 | if self._every is None: 36 | 
pass 37 | every = self._every // self._action_repeat 38 | if step % every == 0: 39 | log_params_to_wandb(params, step) 40 | pass 41 | 42 | class Until: 43 | def __init__(self, until, action_repeat=1): 44 | self._until = until 45 | self._action_repeat = action_repeat 46 | 47 | def __call__(self, step): 48 | if self._until is None: 49 | return True 50 | until = self._until // self._action_repeat 51 | return step < until 52 | 53 | 54 | class Every: 55 | def __init__(self, every, action_repeat=1): 56 | self._every = every 57 | self._action_repeat = action_repeat 58 | 59 | def __call__(self, step): 60 | if self._every is None: 61 | return False 62 | every = self._every // self._action_repeat 63 | if step % every == 0: 64 | return True 65 | return False 66 | 67 | 68 | class Timer: 69 | def __init__(self): 70 | self._start_time = time.time() 71 | self._last_time = time.time() 72 | 73 | def reset(self): 74 | elapsed_time = time.time() - self._last_time 75 | self._last_time = time.time() 76 | total_time = time.time() - self._start_time 77 | return elapsed_time, total_time 78 | 79 | def total_time(self): 80 | return time.time() - self._start_time 81 | 82 | @contextlib.contextmanager 83 | def time_activity(activity_name: str): 84 | logger = logging.getLogger(__name__) 85 | start = timeit.default_timer() 86 | yield 87 | duration = timeit.default_timer() - start 88 | logger.info('[Timing] %s finished (Took %.2fs).', activity_name, duration) 89 | 90 | class AverageMeter: 91 | def __init__(self): 92 | self._sum = 0. 93 | self._count = 0 94 | self.fmt = "{value:.4f}" 95 | 96 | def update(self, value, n=1): 97 | self._sum += value 98 | self._count += n 99 | 100 | @property 101 | def value(self): 102 | return self._sum / max(1, self._count) 103 | 104 | def __str__(self): 105 | return self.fmt.format( 106 | value=self.value 107 | ) 108 | 109 | def dict_to_header(data: dict, header=None): 110 | if header is not None: 111 | header = [header] 112 | else: 113 | header = [] 114 | delimiter = '\t' 115 | for name, value in data.items(): 116 | if type(value) == float: 117 | header.append( 118 | '{}: {:.4f}'.format(name, value) 119 | ) 120 | elif type(value) == np.ndarray: # reward is a np.ndarray of shape () 121 | header.append( 122 | '{}: {:.4f}'.format(name, value) 123 | ) 124 | else: 125 | header.append( 126 | '{}: {}'.format(name, value) 127 | ) 128 | return delimiter.join(header) 129 | 130 | class MetricLogger: 131 | def __init__(self, 132 | csv_file_name: Path, 133 | use_wandb: bool, 134 | delimiter= "\t" 135 | ): 136 | self.logger = logging.getLogger(__name__) 137 | self._meters = defaultdict(AverageMeter) # factory 138 | self._csv_writer = None 139 | self._csv_file = None 140 | self._csv_file_name = csv_file_name 141 | self.delimiter = delimiter 142 | self.use_wandb = use_wandb 143 | 144 | def update_metrics(self,**kwargs): 145 | """Log the average of variables that are logged per episode""" 146 | for k, v in kwargs.items(): 147 | if isinstance(v, jnp.DeviceArray): 148 | v = v.item() 149 | assert isinstance(v, (float, int)) 150 | self._meters[k].update(v) 151 | 152 | def log_and_dump_metrics_to_wandb(self, step: int, header=''): 153 | """log and dump to wandb metrics""" 154 | if type(header) == dict: 155 | header = dict_to_header(data=header) 156 | self.logger.info(self._log_meters(header=header)) 157 | if self.use_wandb: 158 | for name, meter in self._meters.items(): 159 | wandb.log({name: np.mean(meter.value).item()}, step=step) 160 | self._clean_meters() 161 | 162 | def _clean_meters(self): 163 | 
self._meters.clear() 164 | 165 | def _remove_old_entries(self, data): 166 | rows = [] 167 | with self._csv_file_name.open('r') as f: 168 | reader = csv.DictReader(f) 169 | for row in reader: 170 | if float(row['episode']) >= data['episode']: # assume episode exist in header of existing file 171 | break 172 | rows.append(row) 173 | with self._csv_file_name.open('w') as f: 174 | writer = csv.DictWriter(f, 175 | fieldnames=sorted(data.keys()), 176 | restval=0.0) 177 | writer.writeheader() 178 | for row in rows: 179 | writer.writerow(row) 180 | 181 | def dump_dict_to_csv(self, data: dict): 182 | """dump to wandb and csv the dict""" 183 | if self._csv_writer is None: 184 | should_write_header = True 185 | if self._csv_file_name.exists(): # if file already exists remove entries 186 | self._remove_old_entries(data) 187 | should_write_header = False 188 | 189 | self._csv_file = self._csv_file_name.open('a') 190 | self._csv_writer = csv.DictWriter( 191 | self._csv_file, 192 | fieldnames=sorted(data.keys()), 193 | restval=0.0 194 | ) 195 | if should_write_header: 196 | self._csv_writer.writeheader() 197 | self._csv_writer.writerow(data) 198 | self._csv_file.flush() 199 | 200 | def dump_dict_to_wandb(self, step: int, data: dict): 201 | for name, value in data.items(): 202 | if self.use_wandb: 203 | wandb.log({name: np.mean(value).item()}, step=step) 204 | 205 | def log_dict(self, header, data): 206 | self.logger.info(dict_to_header(data=data, header=header)) 207 | 208 | def _log_meters(self, header: str): 209 | loss_str = [header] 210 | for name, meter in self._meters.items(): 211 | loss_str.append( 212 | "{}: {}".format(name, str(meter)) 213 | ) 214 | return self.delimiter.join(loss_str) 215 | -------------------------------------------------------------------------------- /core/exp_utils/video.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import imageio 3 | import numpy as np 4 | import wandb 5 | 6 | 7 | class VideoRecorder: 8 | def __init__(self, 9 | root_dir, 10 | render_size=256, 11 | fps=20, 12 | camera_id=0, 13 | use_wandb=False): 14 | if root_dir is not None: 15 | self.save_dir = root_dir / 'eval_video' 16 | self.save_dir.mkdir(exist_ok=True) 17 | else: 18 | self.save_dir = None 19 | 20 | self.render_size = render_size 21 | self.fps = fps 22 | self.frames = [] 23 | self.camera_id = camera_id 24 | self.use_wandb = use_wandb 25 | 26 | def init(self, env, enabled=True): 27 | self.frames = [] 28 | self.enabled = self.save_dir is not None and enabled 29 | self.record(env) 30 | 31 | def record(self, env): 32 | if self.enabled: 33 | if hasattr(env, 'physics'): 34 | frame = env.physics.render(height=self.render_size, 35 | width=self.render_size, 36 | camera_id=self.camera_id) 37 | else: 38 | frame = env.render() 39 | self.frames.append(frame) 40 | 41 | def log_to_wandb(self, step): 42 | frames = np.transpose(np.array(self.frames), (0, 3, 1, 2)) 43 | fps, skip = 6, 8 44 | wandb.log({ 45 | 'eval/video': 46 | wandb.Video(frames[::skip, :, ::2, ::2], fps=fps, format="gif") 47 | }, step=step) 48 | 49 | def save(self, file_name, step): 50 | if self.enabled: 51 | if self.use_wandb: 52 | self.log_to_wandb(step) 53 | path = self.save_dir / file_name 54 | imageio.mimsave(str(path), self.frames, fps=self.fps) 55 | 56 | 57 | class TrainVideoRecorder: 58 | def __init__(self, 59 | root_dir, 60 | render_size=256, 61 | fps=20, 62 | camera_id=0, 63 | use_wandb=False): 64 | if root_dir is not None: 65 | self.save_dir = root_dir / 'train_video' 66 | 
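# Note: train videos are written under root_dir / 'train_video'; when root_dir is None (e.g. state-based runs with no pixel frames to record), init() leaves the recorder disabled and record()/save() become no-ops.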
self.save_dir.mkdir(exist_ok=True) 67 | else: 68 | self.save_dir = None 69 | 70 | self.render_size = render_size 71 | self.fps = fps 72 | self.frames = [] 73 | self.camera_id = camera_id 74 | self.use_wandb = use_wandb 75 | 76 | def init(self, obs, enabled=True): 77 | self.frames = [] 78 | self.enabled = self.save_dir is not None and enabled 79 | self.record(obs) 80 | 81 | def record(self, obs): 82 | if self.enabled: 83 | frame = cv2.resize(obs[-3:].transpose(1, 2, 0), 84 | dsize=(self.render_size, self.render_size), 85 | interpolation=cv2.INTER_CUBIC) 86 | self.frames.append(frame) 87 | 88 | def log_to_wandb(self, step): 89 | frames = np.transpose(np.array(self.frames), (0, 3, 1, 2)) 90 | fps, skip = 6, 8 91 | wandb.log({ 92 | 'train/video': 93 | wandb.Video(frames[::skip, :, ::2, ::2], fps=fps, format="gif") 94 | }, step=step) 95 | 96 | def save(self, file_name, step): 97 | if self.enabled: 98 | if self.use_wandb: 99 | self.log_to_wandb(step) 100 | path = self.save_dir / file_name 101 | imageio.mimsave(str(path), self.frames, fps=self.fps) 102 | -------------------------------------------------------------------------------- /core/intrinsic/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import hydra 4 | 5 | from .intrinsic_reward_base import IntrinsicReward 6 | from .cic import CICReward 7 | from .multimodal_cic import MultimodalCICReward 8 | 9 | def make_intrinsic_reward(cfg): 10 | return hydra.utils.instantiate(cfg) -------------------------------------------------------------------------------- /core/intrinsic/cic.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Mapping, Any, NamedTuple, Tuple, Dict, Union 2 | from functools import partial 3 | import logging 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import haiku as hk 8 | import optax 9 | 10 | from core import intrinsic 11 | from core import calculations 12 | from core import data 13 | 14 | class CICnetwork(hk.Module): 15 | def __init__(self, 16 | hidden_dim: int, 17 | skill_dim: int, 18 | project_skill: bool 19 | ): 20 | super(CICnetwork, self).__init__() 21 | 22 | self.state_net = calculations.mlp(hidden_dim, skill_dim, name='state_net') 23 | # self.state_net = calculations.mlp(hidden_dim, skill_dim//2, name='state_net') 24 | # self.next_state_net = calculations.mlp(hidden_dim, skill_dim, name='next_state_net') 25 | self.pred_net = calculations.mlp(hidden_dim, skill_dim, name='pred_net') 26 | 27 | if project_skill: 28 | self.skill_net = calculations.mlp(hidden_dim, skill_dim, name='skill_net') 29 | else: 30 | self.skill_net = calculations.Identity() 31 | 32 | def __call__(self, state, next_state, skill, is_training=True): # input is obs_dim - skill_dim 33 | state = self.state_net(state) 34 | next_state = self.state_net(next_state) 35 | # next_state = self.next_state_net(next_state) 36 | if is_training: 37 | query = self.skill_net(skill) 38 | key = self.pred_net(jnp.concatenate([state, next_state], axis=-1)) 39 | return query, key 40 | else: 41 | return state, next_state 42 | 43 | def cic_foward(state: jnp.ndarray, 44 | next_state: jnp.ndarray, 45 | skill: jnp.ndarray, 46 | is_training: bool, 47 | network_cfg: Mapping[str, Any] 48 | ): 49 | model: Callable = CICnetwork( 50 | hidden_dim=network_cfg['hidden_dim'], 51 | skill_dim=network_cfg['skill_dim'], 52 | project_skill=network_cfg['project_skill'] 53 | ) 54 | return model(state, next_state, skill, is_training) 55 | 56 | 57 
| class CICState(NamedTuple): 58 | cic_params: hk.Params 59 | cic_opt_params: optax.OptState 60 | running_mean: Union[float, None] 61 | running_std: Union[float, None] 62 | running_num: Union[float, None] 63 | 64 | 65 | class CICReward(intrinsic.IntrinsicReward): 66 | def __init__(self, 67 | to_jit: bool, 68 | network_cfg: Mapping[str, Any], 69 | lr: float, 70 | knn_entropy_config, 71 | temperature: float, 72 | name: str = 'cic', 73 | ): 74 | self.cic = hk.without_apply_rng( 75 | hk.transform( 76 | partial( 77 | cic_foward, 78 | network_cfg=network_cfg, 79 | ) 80 | ) 81 | ) 82 | self._cpc_loss = partial(self._cpc_loss, temperature=temperature) 83 | self.init_params = partial(self.init_params, skill_dim=network_cfg['skill_dim']) 84 | self.cic_optimizer = optax.adam(learning_rate=lr) 85 | self.entropy_estimator = partial(calculations.particle_based_entropy, 86 | **knn_entropy_config) 87 | if to_jit: 88 | self.update_batch = jax.jit(self.update_batch) 89 | 90 | def init_params(self, 91 | init_key: jax.random.PRNGKey, 92 | dummy_obs: jnp.ndarray, 93 | skill_dim: int, 94 | summarize: bool = True 95 | ): 96 | # batch_size = dummy_obs.shape[0] 97 | dummy_skill = jax.random.uniform(key=init_key, shape=(skill_dim, ), minval=0, maxval=1) 98 | cic_init = self.cic.init(rng=init_key, state=dummy_obs, next_state=dummy_obs, skill=dummy_skill, is_training=True) 99 | cic_opt_init = self.cic_optimizer.init(cic_init) 100 | if summarize: 101 | logger = logging.getLogger(__name__) 102 | summarize_cic_forward = partial(self.cic.apply, is_training=True) # somehow only works this way 103 | logger.info(hk.experimental.tabulate(summarize_cic_forward)(cic_init, dummy_obs, dummy_obs, dummy_skill)) 104 | return CICState( 105 | cic_params=cic_init, 106 | cic_opt_params=cic_opt_init, 107 | running_mean=jnp.zeros((1,)), 108 | running_std=jnp.ones((1,)), 109 | running_num=1e-4 110 | ) 111 | 112 | def _cpc_loss(self, 113 | cic_params: hk.Params, 114 | obs: jnp.ndarray, 115 | next_obs: jnp.ndarray, 116 | skill: jnp.ndarray, 117 | temperature: float 118 | ): 119 | query, key = self.cic.apply(cic_params, obs, next_obs, skill, is_training=True) #(b, c) 120 | # loss = calculations.noise_contrastive_loss(query, key, temperature=temperature) 121 | loss = calculations.cpc_loss(query=query, key=key) 122 | logs = dict( 123 | cpc_loss=loss 124 | ) 125 | return loss, logs 126 | # return noise_contrastive_loss(query, key) 127 | 128 | def _update_cic(self, 129 | cic_params: hk.Params, 130 | cic_opt_params: optax.OptState, 131 | obs: jnp.ndarray, 132 | next_obs: jnp.ndarray, 133 | skill: jnp.ndarray 134 | ): 135 | grad_fn = jax.grad(self._cpc_loss, has_aux=True) 136 | grads, logs = grad_fn(cic_params, obs, next_obs, skill) 137 | deltas, cic_opt_params = self.cic_optimizer.update(grads, cic_opt_params) 138 | cic_params = optax.apply_updates(cic_params, deltas) 139 | return (cic_params, cic_opt_params), logs 140 | 141 | def compute_reward(self, cic_params, obs, next_obs, skill, running_mean, running_std, running_num): 142 | source, target = self.cic.apply(cic_params, obs, next_obs, skill, is_training=False) 143 | reward, running_mean, running_std, running_num = self.entropy_estimator( 144 | source=source, 145 | target=target, 146 | num=running_num, 147 | mean=running_mean, 148 | std=running_std) 149 | return reward, running_mean, running_std, running_num 150 | 151 | def update_batch(self, 152 | state: CICState, 153 | batch: data.Batch, 154 | step: int, 155 | ) -> Tuple[CICState, NamedTuple, Dict]: 156 | obs = batch.observation 157 | 
extrinsic_reward = batch.reward 158 | next_obs = batch.next_observation 159 | meta = batch.extras 160 | skill = meta['skill'] 161 | """ Updates CIC and batch""" 162 | logs = dict() 163 | # TODO add aug for pixel based 164 | (cic_params, cic_opt_params), cic_logs = self._update_cic( 165 | cic_params=state.cic_params, 166 | cic_opt_params=state.cic_opt_params, 167 | obs=obs, 168 | next_obs=next_obs, 169 | skill=skill) 170 | logs.update(cic_logs) 171 | 172 | intrinsic_reward, running_mean, running_std, running_num = self.compute_reward( 173 | cic_params=state.cic_params, 174 | obs=obs, 175 | next_obs=next_obs, 176 | running_num=state.running_num, 177 | skill=skill, 178 | running_mean=state.running_mean, 179 | running_std=state.running_std) 180 | 181 | logs['intrinsic_reward'] = jnp.mean(intrinsic_reward) 182 | logs['extrinsic_reward'] = jnp.mean(extrinsic_reward) 183 | 184 | return CICState( 185 | cic_params=cic_params, 186 | cic_opt_params=cic_opt_params, 187 | running_mean=running_mean, 188 | running_std=running_std, 189 | running_num=running_num 190 | ), batch._replace(reward=intrinsic_reward), logs 191 | -------------------------------------------------------------------------------- /core/intrinsic/intrinsic_reward_base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 4 | 5 | class IntrinsicReward(ABC): 6 | 7 | @abstractmethod 8 | def init_params(self, *args, **kwargs): 9 | raise NotImplementedError 10 | 11 | @abstractmethod 12 | def compute_reward(self, *args, **kwargs): 13 | raise NotImplementedError 14 | 15 | @abstractmethod 16 | def update_batch(self, *args, **kwargs): 17 | raise NotImplementedError 18 | -------------------------------------------------------------------------------- /core/intrinsic/multimodal_cic.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple, Tuple, Dict 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import haiku as hk 6 | import optax 7 | import numpy as np 8 | 9 | from core import intrinsic 10 | from core import data 11 | 12 | 13 | class MultimodalCICState(NamedTuple): 14 | cic_params: hk.Params 15 | cic_opt_params: optax.OptState 16 | mode_0_running_mean: jnp.ndarray 17 | mode_0_running_std: jnp.ndarray 18 | mode_0_running_num: float 19 | 20 | 21 | class RunningStatistics(NamedTuple): 22 | mode_0_running_mean: jnp.ndarray 23 | mode_0_running_std: jnp.ndarray 24 | mode_0_running_num: float 25 | 26 | 27 | class MultimodalCICReward(intrinsic.CICReward): 28 | def __init__(self, to_jit, *args, **kwargs): 29 | # only jit the update fn 30 | super().__init__(False, *args, **kwargs) 31 | # # rewrite or will inherent 32 | if to_jit: 33 | self._update_cic = jax.jit(self._update_cic) 34 | self.entropy_estimator = jax.jit(self.entropy_estimator) 35 | 36 | def init_params(self, 37 | init_key: jax.random.PRNGKey, 38 | dummy_obs: jnp.ndarray, 39 | skill_dim: int, 40 | summarize: bool = True, 41 | ): 42 | cic_state = super().init_params(init_key, dummy_obs, skill_dim, summarize) 43 | return MultimodalCICState( 44 | cic_params=cic_state.cic_params, 45 | cic_opt_params=cic_state.cic_opt_params, 46 | mode_0_running_mean=jnp.zeros((1,)), 47 | mode_0_running_std=jnp.ones((1,)), 48 | mode_0_running_num=1e-4, 49 | ) 50 | 51 | def compute_reward(self, 52 | cic_params, 53 | obs, 54 | next_obs, 55 | skill, 56 | statistics, 57 | **kwargs 58 | ): 59 | source_0, target_0 = self.cic.apply(cic_params, 60 | obs, 61 | next_obs, 62 | 
skill, 63 | is_training=False) 64 | reward, running_mean_0, running_std_0, running_num_0 = self.entropy_estimator( 65 | source=source_0, 66 | target=target_0, 67 | mean=statistics.mode_0_running_mean, 68 | std=statistics.mode_0_running_std, 69 | num=statistics.mode_0_running_num, 70 | ) 71 | 72 | return reward, RunningStatistics( 73 | running_mean_0, 74 | running_std_0, 75 | running_num_0, 76 | ) 77 | 78 | def update_batch(self, 79 | state: MultimodalCICState, 80 | batch: data.Batch, 81 | step: int, 82 | ) -> Tuple[MultimodalCICState, data.Batch, Dict]: 83 | """ Updates CIC and batch""" 84 | obs = batch.observation 85 | extrinsic_reward = batch.reward 86 | next_obs = batch.next_observation 87 | meta = batch.extras 88 | skill = meta['skill'] 89 | logs = dict() 90 | # TODO add aug for pixel baseds 91 | (cic_params, cic_opt_params), cic_logs = self._update_cic( 92 | cic_params=state.cic_params, 93 | cic_opt_params=state.cic_opt_params, 94 | obs=jnp.concatenate(obs), 95 | next_obs=jnp.concatenate(next_obs), 96 | skill=jnp.concatenate(skill)) 97 | logs.update(cic_logs) 98 | 99 | intrinsic_reward, statistics = self.compute_reward(cic_params=state.cic_params, 100 | obs=jnp.concatenate(obs), 101 | next_obs=jnp.concatenate(next_obs), 102 | skill=jnp.concatenate(meta['skill']), 103 | statistics=state) 104 | # todo do we care about logging? put before to prevent moving out of gpu and putting back 105 | logs['intrinsic_reward'] = jnp.mean(intrinsic_reward) 106 | logs['extrinsic_reward'] = jnp.mean(jnp.concatenate(extrinsic_reward)) # don't mean on a list 107 | 108 | intrinsic_reward = np.array(intrinsic_reward) 109 | intrinsic_reward[len(obs[0]):, :] *= -1 110 | 111 | 112 | return MultimodalCICState( 113 | cic_params=cic_params, 114 | cic_opt_params=cic_opt_params, 115 | mode_0_running_mean=statistics.mode_0_running_mean, 116 | mode_0_running_std=statistics.mode_0_running_std, 117 | mode_0_running_num=statistics.mode_0_running_num, 118 | ), data.Batch( 119 | observation=jnp.concatenate(obs), 120 | action=jnp.concatenate(batch.action), 121 | reward=intrinsic_reward, 122 | discount=jnp.concatenate(batch.discount), 123 | next_observation=jnp.concatenate(next_obs), 124 | extras=dict(skill=jnp.concatenate(skill)) 125 | ), logs 126 | -------------------------------------------------------------------------------- /figures/MOSS_robot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeapLabTHU/MOSS/534820b5cb5713389f416f60c867de505e791166/figures/MOSS_robot.png -------------------------------------------------------------------------------- /figures/fraction_rliable.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeapLabTHU/MOSS/534820b5cb5713389f416f60c867de505e791166/figures/fraction_rliable.png -------------------------------------------------------------------------------- /figures/rliable.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeapLabTHU/MOSS/534820b5cb5713389f416f60c867de505e791166/figures/rliable.png -------------------------------------------------------------------------------- /helpers.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple, Any, Dict 2 | import time 3 | 4 | import wandb 5 | import haiku as hk 6 | 7 | from core.calculations import skill_utils 8 | 9 | 10 | def log_params_to_wandb(params: hk.Params, 
step: int): 11 | if params: 12 | for module in sorted(params): 13 | if 'w' in params[module]: 14 | wandb.log({ 15 | f'{module}/w': wandb.Histogram(params[module]['w']) 16 | }, step=step) 17 | if 'b' in params[module]: 18 | wandb.log({ 19 | f'{module}/b': wandb.Histogram(params[module]['b']) 20 | }, step=step) 21 | 22 | 23 | class LogParamsEvery: 24 | def __init__(self, every, action_repeat=1): 25 | self._every = every 26 | self._action_repeat = action_repeat 27 | 28 | def __call__(self, params: hk.Params, step): 29 | if self._every is None: 30 | pass 31 | every = self._every // self._action_repeat 32 | if step % every == 0: 33 | log_params_to_wandb(params, step) 34 | pass 35 | 36 | 37 | class Until: 38 | def __init__(self, until, action_repeat=1): 39 | self._until = until 40 | self._action_repeat = action_repeat 41 | 42 | def __call__(self, step): 43 | if self._until is None: 44 | return True 45 | until = self._until // self._action_repeat 46 | return step < until 47 | 48 | 49 | class Every: 50 | def __init__(self, every, action_repeat=1): 51 | self._every = every 52 | self._action_repeat = action_repeat 53 | 54 | def __call__(self, step): 55 | if self._every is None: 56 | return False 57 | every = self._every // self._action_repeat 58 | if step % every == 0: 59 | return True 60 | return False 61 | 62 | 63 | class Timer: 64 | def __init__(self): 65 | self._start_time = time.time() 66 | self._last_time = time.time() 67 | 68 | def reset(self): 69 | elapsed_time = time.time() - self._last_time 70 | self._last_time = time.time() 71 | total_time = time.time() - self._start_time 72 | return elapsed_time, total_time 73 | 74 | def total_time(self): 75 | return time.time() - self._start_time 76 | 77 | 78 | class CsvData(NamedTuple): 79 | episode_reward: float 80 | episode_length: int 81 | episode: int 82 | step: int 83 | total_time: float 84 | fps: float 85 | 86 | 87 | class LoopVar(NamedTuple): 88 | global_step: int 89 | global_episode: int 90 | episode_step: int 91 | episode_reward: float 92 | total_reward: float 93 | pointer: int 94 | 95 | 96 | class LoopsLength(NamedTuple): 97 | eval_until_episode: Until 98 | train_until_step: Until 99 | seed_until_step: Until 100 | eval_every_step: Every 101 | 102 | 103 | def increment_step(x: LoopVar, 104 | reward: float, 105 | n: int = 1 106 | ) -> LoopVar: 107 | return LoopVar( 108 | global_step=x.global_step + n, 109 | global_episode=x.global_episode, 110 | episode_step=x.episode_step + n, 111 | episode_reward=x.episode_reward + reward, 112 | total_reward=x.episode_reward + reward, 113 | pointer=x.pointer, 114 | ) 115 | 116 | 117 | def increment_episode(x: LoopVar, 118 | n: int = 1 119 | ) -> LoopVar: 120 | return LoopVar( 121 | global_step=x.global_step, 122 | global_episode=x.global_episode + n, 123 | episode_step=x.episode_step, 124 | episode_reward=x.episode_reward, 125 | total_reward=x.episode_reward, 126 | pointer=x.pointer, 127 | ) 128 | 129 | 130 | def reset_episode(x: LoopVar, 131 | ) -> LoopVar: 132 | return LoopVar( 133 | global_step=x.global_step, 134 | global_episode=x.global_episode, 135 | episode_step=0, 136 | episode_reward=0., 137 | total_reward=x.episode_reward, 138 | pointer=x.pointer, 139 | ) 140 | 141 | 142 | def update_skilltracker( 143 | x: skill_utils.SkillRewardTracker, 144 | reward: float 145 | ) -> skill_utils.SkillRewardTracker: 146 | # for pretrain, we dont need skill tracker 147 | if x is None: 148 | return 149 | return x._replace( 150 | score_sum=x.score_sum + reward, 151 | score_step=x.score_step + 1, 152 | ) 153 | 154 | 155 
| def parse_skilltracker( 156 | x: skill_utils.SkillRewardTracker, 157 | meta: Dict[str, Any], 158 | ) -> skill_utils.SkillRewardTracker: 159 | if not meta or 'tracker' not in meta: 160 | return x 161 | return meta['tracker'] 162 | 163 | 164 | def init_skilltracker( 165 | search_steps: int, 166 | change_interval: int, 167 | low: float, 168 | ) -> skill_utils.SkillRewardTracker: 169 | return skill_utils.SkillRewardTracker( 170 | best_skill=None, 171 | best_score=-float('inf'), 172 | score_sum=0., 173 | score_step=0, 174 | current_skill=None, 175 | search_steps=search_steps, 176 | change_interval=change_interval, 177 | low=low, 178 | update=True, 179 | ) 180 | 181 | 182 | def skilltracker_update_on( 183 | x: skill_utils.SkillRewardTracker, 184 | ) -> skill_utils.SkillRewardTracker: 185 | if x is None: 186 | return 187 | return x._replace(update=True) 188 | 189 | 190 | def skilltracker_update_off( 191 | x: skill_utils.SkillRewardTracker, 192 | ) -> skill_utils.SkillRewardTracker: 193 | if x is None: 194 | return 195 | return x._replace(update=False) 196 | -------------------------------------------------------------------------------- /pretrain_multimodal.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['MKL_SERVICE_FORCE_INTEL'] = '1' 3 | os.environ['MUJOCO_GL'] = 'egl' 4 | os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' 5 | import random 6 | import logging 7 | 8 | import wandb 9 | import jax 10 | import torch 11 | import hydra 12 | import numpy as np 13 | from omegaconf import OmegaConf, DictConfig 14 | import pathlib 15 | from dm_env import specs 16 | 17 | from core import envs 18 | from core import agents 19 | from core import exp_utils 20 | import helpers 21 | 22 | 23 | class PretrainLoop: 24 | def __init__(self, cfg): 25 | self.logger = logging.getLogger(__name__) 26 | self.logger.info(OmegaConf.to_yaml(cfg)) 27 | self.init_rng = jax.random.PRNGKey(cfg.seed) 28 | torch.manual_seed(cfg.seed) 29 | np.random.seed(cfg.seed) 30 | random.seed(cfg.seed) 31 | 32 | self.work_dir = pathlib.Path.cwd() 33 | self.use_wandb = cfg.use_wandb 34 | if cfg.use_wandb: 35 | WANDB_NOTES = cfg.wandb_note 36 | os.environ["WANDB_MODE"] = "offline" 37 | wandb.init(project=cfg.wandb_project_name, 38 | name='pretrain_multimodal_'+cfg.benchmark.task+WANDB_NOTES+str(cfg.seed), 39 | config=OmegaConf.to_container(cfg, resolve=True)) 40 | 41 | # init env 42 | self.train_env = envs.make_env(cfg.agent.action_type, cfg.benchmark, seed=cfg.seed) 43 | self.eval_env = envs.make_env(cfg.agent.action_type, cfg.benchmark, seed=cfg.seed) 44 | self.action_repeat = cfg.benchmark.action_repeat 45 | 46 | # init agent 47 | self.agent = agents.make_agent( 48 | cfg.benchmark.obs_type, 49 | self.train_env.action_spec().shape, 50 | cfg.agent 51 | ) 52 | self.intrinsic_reward = hydra.utils.instantiate(cfg.intrinsic) 53 | self.state = None 54 | self.intrinsic_state = None 55 | data_specs = ( 56 | self.train_env.observation_spec(), 57 | self.train_env.action_spec(), 58 | specs.Array((1,), np.float32, 'reward'), 59 | specs.Array((1,), np.float32, 'discount'), 60 | ) 61 | self.agent.init_replay_buffer( 62 | replay_buffer_cfg=cfg.agent.replay_buffer_cfg, 63 | replay_dir=self.work_dir / 'buffer', 64 | environment_spec=data_specs, 65 | ) 66 | self.update_agent_every = cfg.update_every_steps 67 | self.checkpointer = exp_utils.Checkpointer(save_dir=cfg.save_dir) 68 | self.snapshot_steps = cfg.snapshots 69 | 70 | # init exp_utils 71 | self.train_video_recorder = 
exp_utils.TrainVideoRecorder( 72 | self.work_dir if cfg.save_train_video else None # if state based no pixels to save 73 | ) 74 | self.video_recorder = exp_utils.VideoRecorder( 75 | self.work_dir if cfg.save_video else None 76 | ) 77 | 78 | # init loop 79 | eval_until_episode = helpers.Until(cfg.num_eval_episodes, cfg.benchmark.action_repeat) 80 | seed_until_step = helpers.Until(cfg.num_seed_frames, cfg.benchmark.action_repeat) 81 | eval_every_step = helpers.Every(cfg.eval_every_frames, cfg.benchmark.action_repeat) 82 | train_until_step = helpers.Until(cfg.num_pretrain_frames, cfg.benchmark.action_repeat) 83 | self.loops_length = helpers.LoopsLength( 84 | eval_until_episode=eval_until_episode, 85 | train_until_step=train_until_step, 86 | seed_until_step=seed_until_step, 87 | eval_every_step=eval_every_step, 88 | ) 89 | self.global_loop_var = helpers.LoopVar( 90 | global_step=0, 91 | global_episode=0, 92 | episode_step=0, 93 | episode_reward=0., 94 | total_reward=0., 95 | pointer=0) 96 | 97 | @property 98 | def global_frame(self): 99 | return self.global_loop_var.global_step * self.action_repeat 100 | 101 | def _exploration_loop(self, rng): 102 | time_step = self.train_env.reset() 103 | meta_rng, rng = jax.random.split(key=rng, num=2) 104 | meta = self.agent.init_meta(key=meta_rng, time_step=time_step)[0] 105 | 106 | self.agent.store_timestep(time_step=time_step, meta=meta) 107 | while self.loops_length.seed_until_step(self.global_loop_var.global_step): 108 | if time_step.last(): 109 | self.global_loop_var = helpers.increment_episode(self.global_loop_var) 110 | time_step = self.train_env.reset() 111 | meta_rng, rng = jax.random.split(key=rng, num=2) 112 | meta = self.agent.init_meta(key=meta_rng, time_step=time_step)[0] 113 | self.agent.store_timestep(time_step=time_step, meta=meta) 114 | self.global_loop_var = helpers.reset_episode(self.global_loop_var) 115 | 116 | meta_rng, action_rng, rng = tuple(jax.random.split(key=rng, num=3)) 117 | meta = self.agent.update_meta(key=meta_rng, 118 | meta=meta, 119 | step=self.global_loop_var.global_step, time_step=time_step)[0] 120 | action = jax.random.uniform(key=action_rng, shape=self.train_env.action_spec().shape, minval=-1.0, maxval=1.0) 121 | action = np.array(action) 122 | 123 | # take env step 124 | time_step = self.train_env.step(action) 125 | self.agent.store_timestep(time_step=time_step, meta=meta) 126 | # increment loop_vars and skill tracker 127 | self.global_loop_var = helpers.increment_step(self.global_loop_var, reward=time_step.reward) 128 | 129 | return time_step, meta 130 | 131 | def train_loop(self): 132 | 133 | metric_logger = exp_utils.MetricLogger(csv_file_name=self.work_dir / 'train.csv', use_wandb=self.use_wandb) 134 | timer = exp_utils.Timer() 135 | time_step = self.train_env.reset() 136 | 137 | self.logger.info("Pretraining from scratch") 138 | self.state = self.agent.init_params( 139 | init_key=self.init_rng, 140 | dummy_obs=time_step.observation 141 | ) 142 | self.intrinsic_state = self.intrinsic_reward.init_params( 143 | init_key=self.init_rng, 144 | dummy_obs=time_step.observation, 145 | ) 146 | 147 | step_rng, eval_rng, rng = jax.random.split(self.init_rng, num=3) 148 | # self.evaluate(eval_rng) 149 | self.logger.info("Exploration loop") 150 | time_step, meta = self._exploration_loop(step_rng) 151 | self.logger.info("Starting training at episode: {}, step: {}".format(self.global_loop_var.global_episode, 152 | self.global_loop_var.global_step)) 153 | 154 | metrics = None 155 | while 
self.loops_length.train_until_step(self.global_loop_var.global_step): 156 | if time_step.last(): 157 | self.global_loop_var = helpers.increment_episode(self.global_loop_var) 158 | 159 | # log metrics 160 | if metrics is not None: 161 | elapsed_time, total_time = timer.reset() 162 | episode_frame = self.global_loop_var.episode_step * self.action_repeat 163 | data = helpers.CsvData( 164 | step=self.global_loop_var.global_step, 165 | episode=self.global_loop_var.global_episode, 166 | episode_length=episode_frame, 167 | episode_reward=self.global_loop_var.episode_reward, # not a float type 168 | total_time=total_time, 169 | fps=episode_frame / elapsed_time 170 | ) 171 | data = data._asdict() 172 | metric_logger.dump_dict_to_csv(data=data) 173 | metric_logger.dump_dict_to_wandb(step=self.global_loop_var.global_step, data=data) 174 | data.update(buffer_size=len(self.agent)) 175 | metric_logger.log_and_dump_metrics_to_wandb(step=self.global_loop_var.global_step, header=data) 176 | 177 | # reset env 178 | time_step = self.train_env.reset() 179 | step_rng, rng = tuple(jax.random.split(rng, num=2)) 180 | meta = self.agent.init_meta(step_rng, time_step=time_step)[0] 181 | # no need to parse because not updating it duing finetune loop 182 | self.agent.store_timestep(time_step=time_step, meta=meta) 183 | # train_video_recorder.init(time_step.observation) 184 | self.global_loop_var = helpers.reset_episode(self.global_loop_var) 185 | 186 | chkpt_pointer = min(self.global_loop_var.pointer, len(self.snapshot_steps) - 1) 187 | if (self.global_loop_var.global_step + 1) >= self.snapshot_steps[chkpt_pointer]: 188 | self.checkpointer.save_state( 189 | name=str(self.global_loop_var.global_step + 1), 190 | state=self.state, 191 | step=self.snapshot_steps[chkpt_pointer], 192 | rng=rng, 193 | ) 194 | self.checkpointer.save_state( 195 | name=str(self.global_loop_var.global_step + 1) + '_cic', 196 | state=self.intrinsic_state, 197 | step=self.snapshot_steps[chkpt_pointer], 198 | rng=rng, 199 | ) 200 | self.global_loop_var = self.global_loop_var._replace(pointer=self.global_loop_var.pointer + 1) 201 | 202 | if self.loops_length.eval_every_step(self.global_loop_var.global_step): 203 | eval_rng, rng = tuple(jax.random.split(rng, num=2)) 204 | 205 | # agent step 206 | meta_rng, step_rng, update_rng, rng = tuple(jax.random.split(rng, num=4)) 207 | meta = self.agent.update_meta( 208 | key=meta_rng, meta=meta, step=self.global_loop_var.global_step, time_step=time_step)[0] 209 | action = self.agent.select_action( 210 | state=self.state, 211 | obs=time_step.observation, 212 | meta=meta, 213 | step=self.global_loop_var.global_step, 214 | key=step_rng, 215 | greedy=False 216 | ) 217 | if self.global_loop_var.global_step % self.update_agent_every == 0: 218 | batch = self.agent.sample_timesteps() 219 | self.intrinsic_state, batch, intrinsic_metrics = self.intrinsic_reward.update_batch( 220 | state=self.intrinsic_state, 221 | batch=batch, 222 | step=self.global_loop_var.global_step, 223 | ) 224 | metric_logger.update_metrics(**intrinsic_metrics) 225 | self.state, metrics = self.agent.update( 226 | state=self.state, 227 | key=update_rng, 228 | step=self.global_loop_var.global_step, 229 | batch=batch 230 | ) 231 | metric_logger.update_metrics(**metrics) 232 | 233 | # step on env 234 | time_step = self.train_env.step(action) 235 | self.agent.store_timestep(time_step=time_step, meta=meta) 236 | self.global_loop_var = helpers.increment_step(self.global_loop_var, reward=time_step.reward) 237 | 238 | eval_rng, rng = jax.random.split(rng, 
num=2) 239 | self.evaluate(eval_rng=eval_rng) 240 | 241 | def evaluate(self, eval_rng): 242 | metric_logger = exp_utils.MetricLogger(csv_file_name=pathlib.Path.cwd() / 'eval.csv', use_wandb=self.use_wandb) 243 | timer = exp_utils.Timer() 244 | local_loop_var = helpers.LoopVar( 245 | global_step=0, 246 | global_episode=0, 247 | episode_step=0, 248 | episode_reward=0., 249 | total_reward=0., 250 | pointer=0, 251 | ) 252 | while self.loops_length.eval_until_episode(local_loop_var.global_episode): 253 | step_rng, rng = jax.random.split(key=eval_rng, num=2) 254 | time_step = self.eval_env.reset() 255 | meta = self.agent.init_meta(key=step_rng, time_step=time_step)[0] 256 | self.video_recorder.init(self.eval_env, enabled=(local_loop_var.global_episode == 0)) 257 | while not time_step.last(): 258 | action = self.agent.select_action( 259 | state=self.state, 260 | obs=time_step.observation, 261 | meta=meta, 262 | step=self.global_loop_var.global_step, 263 | key=step_rng, 264 | greedy=True 265 | ) 266 | time_step = self.eval_env.step(action) 267 | self.video_recorder.record(self.eval_env) 268 | local_loop_var = helpers.increment_step(local_loop_var, reward=time_step.reward) 269 | 270 | # episode += 1 271 | local_loop_var = helpers.increment_episode(local_loop_var) 272 | self.video_recorder.save(f'{self.global_loop_var.global_step * self.action_repeat}.mp4', 273 | step=self.global_loop_var.global_step) 274 | 275 | n_frame = local_loop_var.global_step * self.action_repeat 276 | total_time = timer.total_time() 277 | data = helpers.CsvData( 278 | # episode_reward=total_reward / episode, 279 | episode_reward=local_loop_var.total_reward / local_loop_var.global_episode, 280 | episode_length=int(local_loop_var.global_step * self.action_repeat / local_loop_var.global_episode), 281 | episode=self.global_loop_var.global_episode, # must name it episode otherwise the csv cannot clean it 282 | step=self.global_loop_var.global_step, 283 | total_time=total_time, 284 | fps=n_frame / total_time 285 | ) 286 | data = data._asdict() 287 | metric_logger.dump_dict_to_csv(data=data) 288 | metric_logger.dump_dict_to_wandb(step=self.global_loop_var.global_step, data=data) 289 | metric_logger.log_dict(data=data, header="Evaluation results: ") 290 | return data 291 | 292 | 293 | @hydra.main(config_path='conf/', config_name='config') 294 | def main(cfg: DictConfig): 295 | trainer = PretrainLoop(cfg) 296 | trainer.train_loop() 297 | 298 | 299 | if __name__ == '__main__': 300 | import warnings 301 | warnings.filterwarnings('ignore', category=DeprecationWarning) # dmc version 302 | import tensorflow as tf 303 | 304 | tf.config.set_visible_devices([], "GPU") # resolves tf/jax concurrent use conflict 305 | main() 306 | --------------------------------------------------------------------------------
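A minimal, hypothetical usage sketch (not part of the repository) of the environment factory above: it assumes the repository's dependencies (dm_control, dm_env, jax, numpy) are installed and configured as in pretrain_multimodal.py, that the standard walker_run task is available, and that a plain namedtuple carrying the fields read by core.envs.make_env (task, obs_type, frame_stack, action_repeat) stands in for the Hydra benchmark config used by the training scripts.

# Hypothetical sketch: build a state-based DMC environment through the wrapper
# stack in core/envs and inspect the InformativeTimeStep it produces.
from collections import namedtuple

from core.envs import make_env

# make_env reads cfg.task, cfg.obs_type, cfg.frame_stack and cfg.action_repeat.
EnvCfg = namedtuple('EnvCfg', ['task', 'obs_type', 'frame_stack', 'action_repeat'])
cfg = EnvCfg(task='walker_run', obs_type='states', frame_stack=1, action_repeat=1)

env = make_env('continuous', cfg, seed=0)
timestep = env.reset()
# InformativeTimeStep exposes the episode clock (current_timestep, max_timestep)
# in addition to the usual dm_env fields.
print(timestep.observation.shape, timestep.current_timestep, timestep.max_timestep)

action = env.action_spec().generate_value()  # an all-zeros action inside the bounds
timestep = env.step(action)
print(timestep.reward, timestep.current_timestep)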