├── .gitignore
├── LICENSE
├── README.md
├── conf
│   ├── agent
│   │   └── ddpg_multimodal_skill_torch.yaml
│   ├── benchmark
│   │   └── dmc.yaml
│   ├── config.yaml
│   ├── hydra
│   │   └── job_logging
│   │       └── custom.yaml
│   └── intrinsic
│       └── multimodal_cic.yaml
├── core
│   ├── __init__.py
│   ├── agents
│   │   ├── __init__.py
│   │   ├── agent_base.py
│   │   ├── ddpg.py
│   │   ├── ddpg_multimodal_skill_torch.py
│   │   └── ddpg_skill.py
│   ├── calculations
│   │   ├── __init__.py
│   │   ├── augmentations.py
│   │   ├── distributions.py
│   │   ├── layers.py
│   │   ├── losses.py
│   │   ├── misc.py
│   │   ├── params_utils.py
│   │   └── skill_utils.py
│   ├── custom_dmc_tasks
│   │   ├── __init__.py
│   │   ├── cheetah.py
│   │   ├── cheetah.xml
│   │   ├── hopper.py
│   │   ├── hopper.xml
│   │   ├── jaco.py
│   │   ├── quadruped.py
│   │   ├── quadruped.xml
│   │   ├── walker.py
│   │   └── walker.xml
│   ├── data
│   │   ├── __init__.py
│   │   ├── replay_buffer.py
│   │   └── replay_buffer_torch.py
│   ├── envs
│   │   ├── __init__.py
│   │   ├── dmc.py
│   │   ├── dmc_benchmark.py
│   │   └── wrappers.py
│   ├── exp_utils
│   │   ├── __init__.py
│   │   ├── checkpointing.py
│   │   ├── loggers.py
│   │   └── video.py
│   └── intrinsic
│       ├── __init__.py
│       ├── cic.py
│       ├── intrinsic_reward_base.py
│       └── multimodal_cic.py
├── figures
│   ├── MOSS_robot.png
│   ├── fraction_rliable.png
│   └── rliable.png
├── finetune_multimodal.py
├── helpers.py
└── pretrain_multimodal.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # General
132 | .DS_Store
133 | .AppleDouble
134 | .LSOverride
135 |
136 | # Icon must end with two \r
137 | Icon
138 |
139 | # Thumbnails
140 | ._*
141 |
142 | # Files that might appear in the root of a volume
143 | .DocumentRevisions-V100
144 | .fseventsd
145 | .Spotlight-V100
146 | .TemporaryItems
147 | .Trashes
148 | .VolumeIcon.icns
149 | .com.apple.timemachine.donotpresent
150 |
151 | # Directories potentially created on remote AFP share
152 | .AppleDB
153 | .AppleDesktop
154 | Network Trash Folder
155 | Temporary Items
156 | .apdisk
157 |
158 | # Created by .ignore support plugin (hsz.mobi)
159 | ### Python template
160 | # Byte-compiled / optimized / DLL files
161 | __pycache__/
162 | *.py[cod]
163 | *$py.class
164 |
165 | # C extensions
166 | *.so
167 |
168 | # Distribution / packaging
169 | .Python
170 | build/
171 | develop-eggs/
172 | dist/
173 | downloads/
174 | eggs/
175 | .eggs/
176 | lib/
177 | lib64/
178 | parts/
179 | sdist/
180 | var/
181 | *.egg-info/
182 | .installed.cfg
183 | *.egg
184 |
185 | # PyInstaller
186 | # Usually these files are written by a python script from a template
187 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
188 | *.manifest
189 | *.spec
190 |
191 | # Installer logs
192 | pip-log.txt
193 | pip-delete-this-directory.txt
194 |
195 | # Unit test / coverage reports
196 | htmlcov/
197 | .tox/
198 | .coverage
199 | .coverage.*
200 | .cache
201 | nosetests.xml
202 | coverage.xml
203 | *,cover
204 | .hypothesis/
205 |
206 | # Translations
207 | *.mo
208 | *.pot
209 |
210 | # Django stuff:
211 | *.log
212 | local_settings.py
213 |
214 | # Flask stuff:
215 | instance/
216 | .webassets-cache
217 |
218 | # Scrapy stuff:
219 | .scrapy
220 |
221 | # Sphinx documentation
222 | docs/_build/
223 |
224 | # PyBuilder
225 | target/
226 |
227 | # IPython Notebook
228 | .ipynb_checkpoints
229 |
230 | # pyenv
231 | .python-version
232 |
233 | # celery beat schedule file
234 | celerybeat-schedule
235 |
236 | # dotenv
237 | #.env
238 |
239 | # virtualenv
240 | #venv/
241 | #ENV/
242 |
243 | # Spyder project settings
244 | .spyderproject
245 |
246 | # Rope project settings
247 | .ropeproject
248 | ### VirtualEnv template
249 | # Virtualenv
250 | # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
251 | .Python
252 | [Bb]in
253 | [Ii]nclude
254 | [Ll]ib
255 | [Ll]ib64
256 | [Ll]ocal
257 | [Ss]cripts
258 | pyvenv.cfg
259 | .venv
260 | pip-selfcheck.json
261 | ### JetBrains template
262 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
263 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
264 |
265 | # User-specific stuff:
266 | .idea/workspace.xml
267 | .idea/tasks.xml
268 | .idea/dictionaries
269 | .idea/vcs.xml
270 | .idea/jsLibraryMappings.xml
271 |
272 | # Sensitive or high-churn files:
273 | .idea/dataSources.ids
274 | .idea/dataSources.xml
275 | .idea/dataSources.local.xml
276 | .idea/sqlDataSources.xml
277 | .idea/dynamic.xml
278 | .idea/uiDesigner.xml
279 |
280 | # Gradle:
281 | .idea/gradle.xml
282 | .idea/libraries
283 |
284 | # Mongo Explorer plugin:
285 | .idea/mongoSettings.xml
286 |
287 | .idea/
288 |
289 | ## File-based project format:
290 | *.iws
291 |
292 | ## Plugin-specific files:
293 |
294 | # IntelliJ
295 | /out/
296 |
297 | # mpeltonen/sbt-idea plugin
298 | .idea_modules/
299 |
300 | # JIRA plugin
301 | atlassian-ide-plugin.xml
302 |
303 | # Crashlytics plugin (for Android Studio and IntelliJ)
304 | com_crashlytics_export_strings.xml
305 | crashlytics.properties
306 | crashlytics-build.properties
307 | fabric.properties
308 |
309 | outputs/
310 | commands.md
311 | testing.ipynb
312 | testing_wrapper.py
313 | is_pretrain*/
314 | .vscode*
315 | script/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Mixture Of SurpriseS (MOSS)
2 | This repo contains the official [**JAX/Haiku**](https://github.com/google/jax) implementation of *A Mixture of Surprises for Unsupervised Reinforcement Learning*. [[arXiv]](https://arxiv.org/abs/2210.06702)
3 |
4 | ## Introduction
5 |
6 | 
7 |
8 | We investigate a method that uses a mixture of skills to alleviate the assumptions on the environment required for unsupervised reinforcement learning.
9 |
10 | ## Results
11 |
12 | [RLiable](https://github.com/google-research/rliable) Aggregated Results on the [Unsupervised Reinforcement Learning Benchmark](https://github.com/rll-research/url_benchmark)
13 |
14 | 
15 |
16 | ## Numerical Results
17 |
18 | | Domain | | Walker | | | | Quadruped | | | | Jaco | |
19 | |----------------------------------------------------|-----------------|-----------------|----------------|----------------|-----------------|----------------|-----------------|-----------------|----------------|-----------------|----------------|
20 | | Method\Task | Flip | Run | Stand | Walk | Jump | Run | Stand | Walk | Bottom Left | Bottom Right | Top Left |
21 | | ICM | 381±10 | 180±15 | 868±30 | 568±38 | 337±18 | 221±14 | 452±15 | 234±18 | 112±7 | 94±5 | 90±6 |
22 | | Disagreement | 313±8 | 166±9 | 658±33 | 453±37 | 512±14 | 395±12 | 686±30 | 358±25 | 120±7 | 132±5 | 111±10 |
23 | | RND | 412±18 | 267±18 | 842±19 | 694±26 | **681±11** | 455±7 | 875±25 | 581±42 | 106±6 | 111±6 | 83±7 |
24 | | ICM APT | 596±24 | 491±18 | 949±3 | 850±22 | 508±44 | 390±24 | 676±44 | 464±52 | 114±5 | 120±3 | 116±4 |
25 | | IND APT | 576±20 | 467±21 | 947±4 | 888±19 | 542±34 | 328±18 | 605±32 | 367±24 | 126±5 | 131±4 | 109±6 |
26 | | Proto | 378±4 | 225±16 | 828±24 | 610±40 | 426±32 | 310±22 | 702±59 | 348±55 | 130±12 | 131±11 | 134±12 |
27 | | AS-Bob | 475±16 | 247±23 | 917±36 | 675±21 | 449±24 | 285±23 | 594±37 | 353±39 | 116±21 | **166±12** | 143±12 |
28 | | AS-Alice | 491±20 | 211±9 | 868±47 | 655±36 | 415±20 | 296±18 | 590±41 | 337±17 | 109±20 | 141±19 | 140±17 |
29 | | SMM | 428±8 | 345±31 | 924±9 | 731±43 | 271±35 | 222±23 | 388±51 | 167±20 | 52±5 | 55±2 | 53±2 |
30 | | DIAYN | 306±12 | 146±7 | 631±46 | 394±22 | 491±38 | 325±21 | 662±38 | 273±19 | 35±5 | 35±6 | 23±3 |
31 | | APS | 355±18 | 166±15 | 667±56 | 500±40 | 283±22 | 206±16 | 379±31 | 192±17 | 61±6 | 79±12 | 51±5 |
32 | | CIC | 715±40 | **535±25** | **968±2** | 914±12 | 541±31 | 376±19 | 717±46 | 460±36 | 147±8 | 150±6 | 145±9 |
33 | | MOSS (Ours) | **729±40** | 531±20 | 962±3 | **942±5** | 674±11 | **485±6** | **911±11** | **635±36** | **151±5** | 150±5 | **150±5** |
34 |
35 |
36 | ## Get Started
37 | ### Pretraining
38 | ```
39 | # example for pretraining on the jaco domain
40 | python pretrain_multimodal.py \
41 | reward_free=true \
42 | agent=ddpg_multimodal_skill_torch \
43 | agent.skill_mode=sign \
44 | agent.partitions=1.5 \
45 | agent.skills_cfg.update_skill_every=50 \
46 | intrinsic=multimodal_cic \
47 | intrinsic.temperature=0.5 \
48 | intrinsic.network_cfg.skill_dim=64\
49 | intrinsic.knn_entropy_config.minus_mean=true \
50 | benchmark=dmc \
51 | benchmark.task=jaco_reach_top_left \
52 | seed=0 \
53 | wandb_note=moss_pretrain_base_sign
54 | ```
55 | ### Finetuning
56 | ```
57 | # example for finetuning on the jaco domain
58 | python finetune_multimodal.py \
59 | reward_free=false \
60 | agent=ddpg_multimodal_skill_torch \
61 | intrinsic.network_cfg.skill_dim=64 \
62 | agent.search_mode=constant \
63 | benchmark=dmc \
64 | benchmark.task=jaco_reach_top_left \
65 | seed=0 \
66 | checkpoint=../../../../is_pretrain_True/jaco_reach_top_left/0/moss_pretrain_base_sign/checkpoints/2000000.pth \
67 | num_finetune_frames=100000 \
68 | wandb_note=moss_finetune_base_sign
69 | ```
70 |
71 | ## Contact
72 |
73 | If you have any questions, please feel free to contact the authors. Andrew Zhao: [zqc21@mails.tsinghua.edu.cn](mailto:zqc21@mails.tsinghua.edu.cn).
74 |
75 | ## Acknowledgment
76 |
77 | Our code is based on [Contrastive Intrinsic Control](https://github.com/rll-research/cic) and [URL Benchmark](https://github.com/rll-research/url_benchmark).
78 |
79 | ## Citation
80 |
81 | If you find our work useful in your research, please consider citing:
82 |
83 | ```bibtex
84 | @article{zhao2022mixture,
85 | title={A Mixture of Surprises for Unsupervised Reinforcement Learning},
86 | author={Zhao, Andrew and Lin, Matthieu Gaetan and Li, Yangguang and Liu, Yong-Jin and Huang, Gao},
87 | journal={arXiv preprint arXiv:2210.06702},
88 | year={2022}
89 | }
90 | ```
--------------------------------------------------------------------------------
/conf/agent/ddpg_multimodal_skill_torch.yaml:
--------------------------------------------------------------------------------
1 | _target_: core.agents.ddpg_multimodal_skill_torch.DDPGAgentMultiModalSkill
2 |
3 | action_type: continuous # [continuous, discrete]
4 | to_jit: true
5 | stddev_schedule: 0.2
6 | stddev_clip: 0.3
7 | critic_target_tau: 0.01
8 | l2_weight: 0.0
9 | #lr_encoder: 1e-4
10 | lr_actor: 1e-4
11 | lr_critic: 1e-4
12 | network_cfg:
13 | obs_type: ${benchmark.obs_type}
14 | action_shape: ???
15 | feature_dim: 50
16 | hidden_dim: 1024
17 | ln_config:
18 | axis: -1
19 | create_scale: True
20 | create_offset: True
21 |
22 | # replay buffer
23 | replay_buffer_cfg:
24 | nstep: 3
25 | replay_buffer_size: 1000000
26 | batch_size: 1024 #2048 #
27 | discount: 0.99
28 | num_workers: 4
29 | skill_dim: ${intrinsic.network_cfg.skill_dim}
30 |
31 | search_mode: grid_search
32 | skill_mode: half
33 | partitions: 2 # 4
34 | reward_free: ${reward_free}
35 | # additional for skill based DDPG
36 | skills_cfg:
37 | update_skill_every: 50
38 | skill_dim: ${intrinsic.network_cfg.skill_dim}
39 |
--------------------------------------------------------------------------------
/conf/benchmark/dmc.yaml:
--------------------------------------------------------------------------------
1 | task: quadruped_walk
2 | obs_type: states # [states, pixels]
3 | frame_stack: 3
4 | action_repeat: 1
5 | seed: ${seed}
6 | reward_scale: 1.0
7 |
--------------------------------------------------------------------------------
/conf/config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - agent: ddpg_multimodal_skill_torch
4 | - intrinsic: multimodal_cic
5 | - benchmark: dmc
6 | - override hydra/job_logging: custom
7 |
8 |
9 | # mode
10 | reward_free: false
11 |
12 | # train settings
13 | num_pretrain_frames: 2000000
14 | num_finetune_frames: 4001 #100000
15 | num_seed_frames: 4000
16 | update_every_steps: 2
17 |
18 | # eval
19 | eval_every_frames: 10000
20 | num_eval_episodes: 10
21 |
22 | # wandb
23 | log_params_to_wandb_every: 100000
24 | run_id: '1' # used for resuming
25 | resume: allow
26 | use_wandb: false
27 | wandb_note: 'entropy_calc'
28 | wandb_project_name: unsupervisedRL
29 |
30 | # misc
31 | seed: 0
32 | save_video: true
33 | save_train_video: false
34 | checkpoint: null
35 | save_dir: checkpoints
36 | snapshots: [100000, 500000, 1000000, 2000000]
37 |
38 |
39 | hydra:
40 | run:
41 | dir: is_pretrain_${reward_free}/${benchmark.task}/${seed}/${wandb_note}
42 |
--------------------------------------------------------------------------------
/conf/hydra/job_logging/custom.yaml:
--------------------------------------------------------------------------------
1 | version: 1
2 | formatters:
3 | simple:
4 | format: "[%(asctime)s](%(filename)s %(lineno)d): %(message)s"
5 | colored:
6 | (): colorlog.ColoredFormatter
7 | # format: "[%(cyan)s%(asctime)s%(reset)s][%(purple)sHYDRA%(reset)s] %(message)s"
8 | format: '%(green)s[%(asctime)s](%(filename)s %(lineno)d): %(white)s%(message)s'
9 | handlers:
10 | console: # console handler
11 | class: logging.StreamHandler
12 | level: INFO
13 | formatter: colored
14 | stream: ext://sys.stdout
15 | file: # file handler
16 | class: logging.FileHandler
17 | formatter: colored
18 | level: INFO
19 | filename: output.log
20 | loggers: # parents
21 | finetune:
22 | level: INFO
23 | handlers: [console, file]
24 | propagate: no
25 | root: # default one
26 | level: INFO #DEBUG
27 | handlers: [console, file]
28 |
29 | disable_existing_loggers: false
--------------------------------------------------------------------------------
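
The `loggers` section above registers a named `finetune` logger with both the console and file handlers. As a minimal sketch (assuming Hydra has already installed this job_logging config when the script starts), application code can pick it up by name with the standard library:

```
# Minimal sketch: use the 'finetune' logger configured in custom.yaml above.
import logging

log = logging.getLogger('finetune')                 # the named logger from the config
log.info('finetune step %d: reward %.2f', 0, 0.0)   # goes to stdout and to output.log
```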
/conf/intrinsic/multimodal_cic.yaml:
--------------------------------------------------------------------------------
1 | name: multimodal_cic
2 | _target_: core.intrinsic.MultimodalCICReward
3 | lr: 1e-4
4 | to_jit: true
5 | temperature: 0.5
6 | network_cfg:
7 | hidden_dim: 1024
8 | skill_dim: 32
9 | project_skill: true
10 | knn_entropy_config:
11 | knn_clip: 0.0005
12 | knn_k: 16
13 | knn_avg: true
14 | knn_rm: true
15 | minus_mean: true
--------------------------------------------------------------------------------
/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeapLabTHU/MOSS/534820b5cb5713389f416f60c867de505e791166/core/__init__.py
--------------------------------------------------------------------------------
/core/agents/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, NamedTuple
2 |
3 | import hydra
4 |
5 | from .agent_base import Agent
6 | from .ddpg import DDPGAgent, DDPGTrainState
7 | from .ddpg_skill import DDPGAgentSkill
8 | from .ddpg_multimodal_skill_torch import DDPGAgentMultiModalSkill
9 |
10 |
11 | def make_agent(obs_type, action_shape, agent_cfg):
12 | if agent_cfg.action_type == 'continuous':
13 | return make_continuous_agent(action_shape, agent_cfg)
14 | elif agent_cfg.action_type == 'discrete':
15 | return make_discrete_agent(obs_type, action_shape, agent_cfg)
16 | else:
17 | raise NotImplementedError
18 |
19 | def make_continuous_agent(action_shape, agent_cfg):
20 | agent_cfg.network_cfg.action_shape = action_shape
21 | return hydra.utils.instantiate(agent_cfg)
22 |
23 | def make_discrete_agent(obs_type: str, action_shape: Tuple[int], cfg):
24 | cfg.network_cfg.obs_type = obs_type
25 | cfg.network_cfg.action_shape = action_shape
26 | return hydra.utils.instantiate(cfg)
27 |
--------------------------------------------------------------------------------
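
As a rough sketch of how `make_agent` ties the `conf/` tree to the agent class: compose the config and pass the agent node through the factory. The `config_path`, the explicit `compose` call, and the `action_shape=(6,)` value are illustrative assumptions, not taken from the repo's entry scripts, which presumably do the equivalent through their `@hydra.main` entry points.

```
# Rough sketch (assumptions: run from the repo root, Hydra >= 1.1; action_shape is illustrative).
import hydra
from core.agents import make_agent

with hydra.initialize(config_path='conf'):
    cfg = hydra.compose(config_name='config')

# make_continuous_agent fills the `???` placeholder network_cfg.action_shape, then
# instantiates the _target_ class from conf/agent/ddpg_multimodal_skill_torch.yaml.
agent = make_agent(obs_type=cfg.benchmark.obs_type, action_shape=(6,), agent_cfg=cfg.agent)
```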
/core/agents/agent_base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | class Agent(ABC):
4 |
5 | @abstractmethod
6 | def init_params(self,
7 | init_key,
8 | dummy_obs,
9 | summarize = True
10 | ):
11 | raise NotImplementedError
12 |
13 | @abstractmethod
14 | def select_action(self, *args, **kwargs):
15 | """act function"""
16 | raise NotImplementedError
17 |
18 | @abstractmethod
19 | def update(self, *args, **kwargs):
20 | raise NotImplementedError
21 |
22 | @abstractmethod
23 | def get_meta_specs(self, *args, **kwargs):
24 | raise NotImplementedError
25 |
26 | @abstractmethod
27 | def init_meta(self, *args, **kwargs):
28 | raise NotImplementedError
29 |
30 | @abstractmethod
31 | def update_meta(self, *args, **kwargs):
32 | raise NotImplementedError
33 |
34 | @abstractmethod
35 | def init_replay_buffer(self, *args, **kwargs):
36 | raise NotImplementedError
37 |
38 | @abstractmethod
39 | def store_timestep(self, *args, **kwargs):
40 | raise NotImplementedError
41 |
42 | @abstractmethod
43 | def sample_timesteps(self, *args, **kwargs):
44 | raise NotImplementedError
45 |
--------------------------------------------------------------------------------
/core/agents/ddpg_multimodal_skill_torch.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Any
2 | from collections import OrderedDict
3 | from functools import partial
4 |
5 | import dm_env
6 | import jax
7 | import jax.numpy as jnp
8 | import numpy as np
9 |
10 | from core import agents
11 | from core.envs import wrappers
12 | from core import calculations
13 |
14 |
15 | def episode_partition_mode_selector(
16 | time_step: wrappers.InformativeTimeStep,
17 | partitions: int = 2,
18 | ) -> bool:
19 | """
20 |     Returns True for control mode,
21 |     False for explore mode.
22 | """
23 | max_time_step = time_step.max_timestep
24 | interval = max_time_step // partitions
25 | current_timestep = time_step.current_timestep
26 | return bool(current_timestep // interval % 2)
27 |
28 |
29 | def get_meta_specs(skill_dim: int,
30 | reward_free: bool
31 | ) -> Tuple:
32 | # noinspection PyRedundantParentheses
33 | if reward_free:
34 | return (
35 | dm_env.specs.Array((skill_dim,), np.float32, 'skill'),
36 |             dm_env.specs.Array((), bool, 'mode')  # for pytorch replay buffer (the np.bool alias is removed in recent NumPy)
37 | )
38 | else:
39 | return (
40 | dm_env.specs.Array((skill_dim,), np.float32, 'skill'),
41 | )
42 |
43 | def init_meta(key,
44 | time_step: wrappers.InformativeTimeStep,
45 | reward_free: bool,
46 | skill_dim: int,
47 | partitions: int,
48 | search_mode = 'random_grid_search',
49 | skill_mode = 'half',
50 | skill_tracker: calculations.skill_utils.SkillRewardTracker=None,
51 | step: int = None,
52 | ) -> Tuple[OrderedDict, Any]:
53 | """
54 | :param key: only parameter needed in forward pass
55 | :param reward_free: defined as a constant during init of ddpg skill
56 | :param step: global step, at a certain step it only outputs the best skill
57 | :param skill_dim: defined as a constant during init of ddpg skill
58 | :param time_step: used to get current step in the episode for mode only
59 | :param skill_tracker: keep track in a NamedTuple of the best skill
60 |     :return: during pretraining, meta containing the skill and mode; during finetuning, the best skill tracked in skill_tracker
61 | """
62 | meta = OrderedDict()
63 | if reward_free:
64 | # mode_key, skill_key = jax.random.split(key)
65 | skill_key = key
66 | # mode = bool(jax.random.bernoulli(key=mode_key, p=0.4))
67 | mode = episode_partition_mode_selector(time_step=time_step, partitions=partitions)
68 | if skill_mode == 'half':
69 | first_half_dim = int(skill_dim / 2)
70 | second_half_dim = skill_dim - first_half_dim
71 | zero = jnp.zeros(shape=(first_half_dim,), dtype=jnp.float32)
72 | uniform = jax.random.uniform(skill_key, shape=(second_half_dim,), minval=0., maxval=1.)
73 | if mode:
74 | skill = jnp.concatenate([zero, uniform])
75 | else:
76 | skill = jnp.concatenate([uniform, zero])
77 | elif skill_mode == 'sign':
78 | sign = -1. if mode else 1.
79 | skill = jax.random.uniform(skill_key, shape=(skill_dim,), minval=0., maxval=1.) * sign
80 | elif skill_mode == 'same':
81 | skill = jax.random.uniform(skill_key, shape=(skill_dim,), minval=0., maxval=1.)
82 | elif skill_mode == 'discrete':
83 | sign = -1. if mode else 1.
84 | skill = jnp.ones((skill_dim,)) * sign
85 | # sign = 0. if mode else 1.
86 | # skill = jnp.ones(shape=(skill_dim,), dtype=jnp.float32) * sign
87 | meta['mode'] = mode
88 | else:
89 | # outputs best skill after exploration loop
90 | # use constant skill function for baseline
91 | if search_mode == 'random_grid_search':
92 | skill = calculations.skill_utils.random_grid_search_skill(
93 | skill_dim=skill_dim,
94 | global_timestep=step,
95 | skill_tracker=skill_tracker,
96 | key=key
97 | )
98 | elif search_mode == 'grid_search':
99 | skill = calculations.skill_utils.grid_search_skill(
100 | skill_dim=skill_dim,
101 | global_timestep=step,
102 | skill_tracker=skill_tracker,
103 | )
104 | elif search_mode == 'random_search':
105 | skill = calculations.skill_utils.random_search_skill(
106 | skill_dim=skill_dim,
107 | global_timestep=step,
108 | skill_tracker=skill_tracker,
109 | key=key
110 | )
111 | elif search_mode == 'constant':
112 | skill = calculations.skill_utils.constant_fixed_skill(
113 | skill_dim=skill_dim,
114 | )
115 | elif search_mode == 'explore':
116 | skill = jnp.ones((skill_dim,))
117 |
118 | elif search_mode == 'control':
119 | skill = -jnp.ones((skill_dim,))
120 |
121 | if skill_tracker.update:
122 | # first step
123 | if skill_tracker.score_step == 0:
124 | pass
125 | elif skill_tracker.score_sum / skill_tracker.score_step > skill_tracker.best_score:
126 | skill_tracker = skill_tracker._replace(
127 | best_skill=skill_tracker.current_skill,
128 | best_score=skill_tracker.score_sum / skill_tracker.score_step
129 | )
130 | skill_tracker = skill_tracker._replace(
131 | score_sum=0.,
132 | score_step=0
133 | )
134 | # skill = jnp.ones(skill_dim, dtype=jnp.float32) * 0.5
135 | skill_tracker = skill_tracker._replace(current_skill=skill)
136 |
137 | meta['skill'] = skill
138 |
139 | return meta, skill_tracker
140 |
141 | class DDPGAgentMultiModalSkill(agents.DDPGAgentSkill):
142 |
143 | """Implement DDPG with skills"""
144 | def __init__(self,
145 | skills_cfg,
146 | reward_free: bool,
147 | search_mode,
148 | skill_mode,
149 | partitions,
150 | **kwargs
151 | ):
152 | super().__init__(
153 | skills_cfg,
154 | reward_free,
155 | **kwargs
156 | )
157 | # init in exploration mode
158 | self._mode = bool(0)
159 |
160 | to_jit = jax.jit if kwargs['to_jit'] else lambda x: x
161 |
162 | self.get_meta_specs = partial(
163 | get_meta_specs, skill_dim=skills_cfg.skill_dim, reward_free=reward_free
164 | )
165 | self.init_meta = partial(
166 | init_meta,
167 | partitions=partitions,
168 | reward_free=reward_free,
169 | skill_dim=skills_cfg.skill_dim,
170 | search_mode=search_mode,
171 | skill_mode=skill_mode
172 | )
173 |
174 | def update_meta(self,
175 | key: jax.random.PRNGKey,
176 | meta: OrderedDict,
177 | step: int,
178 | update_skill_every: int,
179 | time_step,
180 | skill_tracker=None,
181 | ) -> Tuple[OrderedDict, Any]:
182 | if step % update_skill_every == 0:
183 | return self.init_meta(key, step=step, skill_tracker=skill_tracker, time_step=time_step)
184 | return meta, skill_tracker
--------------------------------------------------------------------------------
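
To make the explore/control split concrete, here is a small worked example of the arithmetic in `episode_partition_mode_selector`, computed on plain integers instead of a `wrappers.InformativeTimeStep` (the 1000-step episode length is illustrative):

```
# Worked example of the explore/control split above (True = control, False = explore).

def mode_at(current_timestep, max_timestep=1000, partitions=2):
    # mirrors episode_partition_mode_selector, but on plain ints
    interval = max_timestep // partitions
    return bool(current_timestep // interval % 2)

# partitions=2: the first half of a 1000-step episode explores, the second half controls
assert [mode_at(t) for t in (0, 499, 500, 999)] == [False, False, True, True]

# partitions=1.5 (as in the README pretraining command): roughly the first two thirds
# of the episode explore and the last third controls
assert [mode_at(t, partitions=1.5) for t in (0, 665, 666, 999)] == [False, False, True, True]
```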
/core/agents/ddpg_skill.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Any
2 | from collections import OrderedDict
3 | from functools import partial
4 |
5 | import dm_env
6 | import jax
7 | import jax.numpy as jnp
8 | import numpy as np
9 |
10 | from core import agents
11 | from core import calculations
12 |
13 | def init_meta(key,
14 | reward_free: bool,
15 | skill_dim: int,
16 | skill_tracker: calculations.skill_utils.SkillRewardTracker=None,
17 | step: int = None) -> Tuple[OrderedDict, Any]:
18 |
19 | meta = OrderedDict()
20 | if reward_free:
21 | skill = jax.random.uniform(key, shape=(skill_dim, ), minval=0., maxval=1.)
22 |
23 | else:
24 | # outputs best skill after exploration loop
25 | # use constant skill function for baseline
26 | skill = calculations.skill_utils.grid_search_skill(
27 | skill_dim=skill_dim,
28 | global_timestep=step,
29 | skill_tracker=skill_tracker,
30 | )
31 | if skill_tracker.update:
32 | # first step
33 | if skill_tracker.score_step == 0:
34 | pass
35 | elif skill_tracker.score_sum / skill_tracker.score_step > skill_tracker.best_score:
36 | skill_tracker = skill_tracker._replace(
37 | best_skill=skill_tracker.current_skill,
38 | best_score=skill_tracker.score_sum / skill_tracker.score_step
39 | )
40 | skill_tracker = skill_tracker._replace(
41 | score_sum=0.,
42 | score_step=0
43 | )
44 | # skill = jnp.ones(skill_dim, dtype=jnp.float32) * 0.5
45 | skill_tracker = skill_tracker._replace(current_skill=skill)
46 |
47 | meta['skill'] = skill
48 | return meta, skill_tracker
49 |
50 |
51 | def get_meta_specs(skill_dim: int) -> Tuple:
52 | """
53 |     Each element of the tuple represents the spec for one meta entry
54 | """
55 | # noinspection PyRedundantParentheses
56 | return (dm_env.specs.Array((skill_dim,), np.float32, 'skill'),)
57 |
58 |
59 | class DDPGAgentSkill(agents.DDPGAgent):
60 |
61 | """Implement DDPG with skills"""
62 | def __init__(self,
63 | skills_cfg,
64 | reward_free: bool,
65 | **kwargs
66 | ):
67 | super(DDPGAgentSkill, self).__init__(**kwargs)
68 | self.get_meta_specs = partial(get_meta_specs, skill_dim=skills_cfg.skill_dim)
69 | self.init_meta = partial(
70 | init_meta,
71 | reward_free=reward_free,
72 | skill_dim=skills_cfg.skill_dim,
73 | )
74 | self.update_meta = partial(self.update_meta, update_skill_every=skills_cfg.update_skill_every)
75 | self.init_params = partial(
76 | self.init_params,
77 | obs_type=kwargs['network_cfg'].obs_type
78 | )
79 |
80 | def init_params(self,
81 | init_key: jax.random.PRNGKey,
82 | dummy_obs: jnp.ndarray,
83 | summarize: bool = True,
84 | checkpoint_state = None,
85 | **kwargs
86 | ):
87 | """
88 | :param init_key:
89 | :param dummy_obs:
90 | :param summarize:
91 | :param checkpoint_state:
92 | :return:
93 | """
94 | skill = jnp.empty(self.get_meta_specs()[0].shape)
95 | dummy_obs = jnp.concatenate([dummy_obs, skill], axis=-1)
96 | state = super().init_params(init_key=init_key,
97 | dummy_obs=dummy_obs,
98 | summarize=summarize,
99 | checkpoint_state=checkpoint_state)
100 | return state
101 |
102 | def update_meta(self,
103 | key: jax.random.PRNGKey,
104 | meta: OrderedDict,
105 | step: int,
106 | update_skill_every: int,
107 | time_step=None,
108 | skill_tracker=None,
109 | ) -> Tuple[OrderedDict, Any]:
110 |
111 | if step % update_skill_every == 0:
112 | return self.init_meta(key, step=step, skill_tracker=skill_tracker)
113 | return meta, skill_tracker
114 |
115 |
--------------------------------------------------------------------------------
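
The best-skill bookkeeping in `init_meta` only touches a handful of fields on the tracker. The sketch below uses a hypothetical stand-in for `calculations.skill_utils.SkillRewardTracker` (not the repo's definition) with exactly those fields, to show how the tracked score promotes the current skill to the best skill:

```
# Hypothetical stand-in for calculations.skill_utils.SkillRewardTracker; only the
# fields used by the update logic above are modeled, and the defaults are assumptions.
from typing import Any, NamedTuple

class FakeSkillTracker(NamedTuple):
    update: bool = True
    score_sum: float = 0.0
    score_step: int = 0
    best_score: float = float('-inf')
    best_skill: Any = None
    current_skill: Any = None

# If the tracked average score under current_skill beats best_score, the current
# skill becomes the best skill; the running score is then reset, as in init_meta.
tracker = FakeSkillTracker(score_sum=120.0, score_step=2, best_score=50.0, current_skill='z_a')
if tracker.score_step and tracker.score_sum / tracker.score_step > tracker.best_score:
    tracker = tracker._replace(best_skill=tracker.current_skill,
                               best_score=tracker.score_sum / tracker.score_step)
tracker = tracker._replace(score_sum=0.0, score_step=0)
assert tracker.best_skill == 'z_a' and tracker.best_score == 60.0
```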
/core/calculations/__init__.py:
--------------------------------------------------------------------------------
1 | from .losses import l2_loss, particle_based_entropy, noise_contrastive_loss, cpc_loss, l2_loss_without_bias, softmax_probabilities
2 | from .layers import Identity, trunk, linear_relu, default_linear_init, feature_extractor, mlp, mlp_bottlneck
3 | from .distributions import TruncNormal
4 | from .params_utils import polyak_averaging
5 | from .misc import schedule
6 | from .skill_utils import random_search_skill, constant_fixed_skill, grid_search_skill, random_grid_search_skill
--------------------------------------------------------------------------------
/core/calculations/augmentations.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | import jax
5 | import jax.numpy as jnp
6 |
7 |
8 | def _random_flip_single_image(image, rng):
9 | _, flip_rng = jax.random.split(rng)
10 | should_flip_lr = jax.random.uniform(flip_rng, shape=()) <= 0.5
11 | image = jax.lax.cond(should_flip_lr, image, jnp.fliplr, image, lambda x: x)
12 | return image
13 |
14 |
15 | def random_flip(images, rng):
16 | rngs = jax.random.split(rng, images.shape[0])
17 | return jax.vmap(_random_flip_single_image)(images, rngs)
18 |
19 |
20 | def random_shift_aug(x: jnp.ndarray):
21 | """x: [N, H, W, C]"""
22 | x = x.astype(dtype=jnp.float32)
23 | n, h, w, c = x.shape
24 | assert h == w
25 |
26 | return jax.lax.stop_gradient(x)
27 |
28 | class RandomShiftsAug(nn.Module):
29 | def __init__(self, pad):
30 | super().__init__()
31 | self.pad = pad
32 |
33 | def forward(self, x):
34 | x = x.float()
35 | n, c, h, w = x.size()
36 | assert h == w
37 | padding = tuple([self.pad] * 4)
38 | x = F.pad(x, padding, 'replicate')
39 | eps = 1.0 / (h + 2 * self.pad)
40 | arange = torch.linspace(-1.0 + eps,
41 | 1.0 - eps,
42 | h + 2 * self.pad,
43 | device=x.device,
44 | dtype=x.dtype)[:h]
45 | arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
46 | base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
47 | base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)
48 |
49 | shift = torch.randint(0,
50 | 2 * self.pad + 1,
51 | size=(n, 1, 1, 2),
52 | device=x.device,
53 | dtype=x.dtype)
54 | shift *= 2.0 / (h + 2 * self.pad)
55 |
56 | grid = base_grid + shift
57 | return F.grid_sample(x,
58 | grid,
59 | padding_mode='zeros',
60 | align_corners=False)
--------------------------------------------------------------------------------
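
A short usage sketch for `RandomShiftsAug`: the batch size, frame stack, and 84x84 resolution below are illustrative, and `pad=4` is the value DrQ-v2-style pipelines typically use.

```
# Usage sketch (batch size, frame stack, and resolution are illustrative).
import torch
from core.calculations.augmentations import RandomShiftsAug

aug = RandomShiftsAug(pad=4)
frames = torch.rand(8, 9, 84, 84)    # [N, C, H, W], e.g. 3 stacked RGB frames
shifted = aug(frames)                # each image randomly shifted by up to +/- 4 pixels
assert shifted.shape == frames.shape
```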
/core/calculations/distributions.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 |
4 |
5 | class Distribution:
6 | """
7 | Abstract base class for probability distribution
8 | """
9 | def __init__(self, batch_shape, event_shape):
10 | self._batch_shape = batch_shape
11 | self._event_shape = event_shape
12 |
13 | def sample(self, sample_shape):
14 | pass
15 |
16 | class TruncNormal:
17 | def __init__(self, loc, scale, low=-1.0, high=1.0):
18 | """ Trunc from -1 to 1 for DMC action space
19 | :param loc: mean (N, action_dim)
20 | :param scale: stddev ()
21 | :param low: clamp to low
22 | :param high: clamp to high
23 | """
24 | self.low = low
25 | self.high = high
26 | self.loc = loc
27 | self.scale = scale
28 | self.eps = 1e-6
29 |
30 | def mean(self):
31 | return self.loc
32 |
33 | def sample(self,
34 | clip=None,
35 | *,
36 | seed: jax.random.PRNGKey,
37 | # sample_shape: Sequence[int] = (),
38 | ):
39 | """Samples an event.
40 |
41 | Args:
42 | clip: implements clipped noise in DrQ-v2
43 | seed: PRNG key or integer seed.
44 |
45 | Returns:
46 | A sample of shape `sample_shape` + `batch_shape` + `event_shape`.
47 | """
48 | sample_shape = self.loc.shape
49 |         noise = jax.random.normal(seed, sample_shape)  # must have the same shape as loc, which gives the mean of each individual Gaussian
50 | noise *= self.scale
51 |
52 | if clip is not None:
53 | # clip N(0, var) of exploration schedule in DrQ-v2
54 | noise = jnp.clip(noise, a_min=-clip, a_max=clip)
55 | x = self.loc + noise
56 | # return jnp.clip(x, a_min=self.low, a_max=self.high)
57 | clamped_x = jnp.clip(x, a_min=self.low + self.eps, a_max=self.high - self.eps)
58 | x = x - jax.lax.stop_gradient(x) + jax.lax.stop_gradient(clamped_x) # trick to backprop on x without clamping affecting it
59 | return x
60 | #
61 | # class TruncNormal(distrax.Normal):
62 | # def __init__(self, loc, scale, low=-1.0, high=1.0):
63 | # """ Trunc from -1 to 1 for DMC action space
64 | # :param loc: mean
65 | # :param scale: stddev
66 | # :param low:
67 | # :param high:
68 | # :param eps:
69 | # """
70 | # super(TruncNormal, self).__init__(loc=loc, scale=scale)
71 | #
72 | # self.low = low
73 | # self.high = high
74 | # # self.eps = eps
75 | #
76 | # def _clamp(self, x):
77 | # """ Clamping method for TruncNormal"""
78 | # clamped_x = jnp.clip(x, self.low, self.high)
79 | # x = x - jax.lax.stop_gradient(x) + jax.lax.stop_gradient(clamped_x)
80 | # return x
81 | #
82 | # def sample(self,
83 | # clip=None,
84 | # *,
85 | # seed, #: Union[IntLike, PRNGKey],
86 | # sample_shape = (),#: Union[IntLike, Sequence[IntLike]] = ()
87 | # ):
88 | # """Samples an event.
89 | #
90 | # Args:
91 | # clip: implements clipped noise in DrQ-v2
92 | # seed: PRNG key or integer seed.
93 | # sample_shape: Additional leading dimensions for sample.
94 | #
95 | # Returns:
96 | # A sample of shape `sample_shape` + `batch_shape` + `event_shape`.
97 | # """
98 | # # this line check if rng is a PRNG key and sample_shape a tuple if not it converts them.
99 | # # rng, sample_shape = convert_seed_and_sample_shape(seed, sample_shape)
100 | # num_samples = functools.reduce(operator.mul, sample_shape, 1) # product
101 | #
102 | # eps = self._sample_from_std_normal(seed, num_samples)
103 | # scale = jnp.expand_dims(self._scale, range(eps.ndim - self._scale.ndim))
104 | # loc = jnp.expand_dims(self._loc, range(eps.ndim - self._loc.ndim))
105 | #
106 | # eps *= scale
107 | # if clip is not None:
108 | # # clip N(0, var) of exploration schedule in DrQ-v2
109 | # eps = jnp.clip(eps, a_min=-clip, a_max=clip)
110 | # samples = loc + eps
111 | # samples = self._clamp(samples)
112 | # return samples.reshape(sample_shape + samples.shape[1:])
113 |
114 | #
115 | # import torch
116 | # from torch import distributions as pyd
117 | # from torch.distributions.utils import _standard_normal
118 | #
119 | #
120 | # class TruncatedNormal(pyd.Normal):
121 | # def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
122 | # super().__init__(loc, scale, validate_args=False)
123 | # self.low = low
124 | # self.high = high
125 | # self.eps = eps
126 | #
127 | # def _clamp(self, x):
128 | # clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps)
129 | # x = x - x.detach() + clamped_x.detach()
130 | # return x
131 | #
132 | # def sample(self, clip=None, sample_shape=torch.Size()):
133 | # shape = self._extended_shape(sample_shape)
134 | # eps = _standard_normal(shape,
135 | # dtype=self.loc.dtype,
136 | # device=self.loc.device)
137 | # eps *= self.scale
138 | # if clip is not None:
139 | # eps = torch.clamp(eps, -clip, clip)
140 | # x = self.loc + eps
141 | # return self._clamp(x)
142 | #
143 | # if __name__ == "__main__":
144 | # truncNormal = TruncNormal(jnp.ones((3,)), 1.)
145 | # samples_jax = truncNormal.sample(clip=2, seed=jax.random.PRNGKey(666))
146 | #
147 | # torchtruncNormal = TruncatedNormal(torch.ones(3), 1.)
148 | # samples_torch = torchtruncNormal.sample(clip=2)
149 | #
150 | # print(samples_jax, samples_torch)
151 | # [[0.96648777 1.]
152 | # [0.4025777 1.]
153 | # [-0.59399736
154 | # 1.]]
--------------------------------------------------------------------------------
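
A usage sketch for `TruncNormal` (the loc/scale values and shapes are illustrative): sample clipped exploration noise around a batch of action means, in the style of a DrQ-v2 actor.

```
# Usage sketch: clipped-noise sampling around action means (values are illustrative).
import jax
import jax.numpy as jnp
from core.calculations.distributions import TruncNormal

dist = TruncNormal(loc=jnp.zeros((4, 6)), scale=0.2)          # 4 action means of dim 6
actions = dist.sample(clip=0.3, seed=jax.random.PRNGKey(0))   # noise clipped to [-0.3, 0.3]
assert actions.shape == (4, 6)
assert bool(jnp.all(jnp.abs(actions) < 1.0))                  # clamped to the DMC action range
```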
/core/calculations/layers.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Mapping, Union
2 |
3 | import haiku as hk
4 | import jax
5 | import jax.numpy as jnp
6 |
7 | FloatStrOrBool = Union[str, float, bool]
8 | default_linear_init = hk.initializers.Orthogonal()
9 |
10 | class Identity(hk.Module):
11 | def __init__(self, name = 'identity'):
12 | super(Identity, self).__init__(name=name)
13 |
14 | def __call__(self, inputs):
15 | return inputs
16 |
17 | def trunk(ln_config: Mapping[str, FloatStrOrBool], feature_dim: int, name='trunk') -> Callable:
18 | """Layer"""
19 | return hk.Sequential([
20 | hk.Linear(output_size=feature_dim, w_init=default_linear_init, name='trunk_linear'),
21 | hk.LayerNorm(**ln_config, name='trunk_ln'),
22 | jax.nn.tanh
23 | ], name=name)
24 |
25 | def linear_relu(dim: int, name='linear_relu') -> Callable:
26 | """Layer"""
27 | return hk.Sequential([
28 | hk.Linear(output_size=dim, w_init=default_linear_init), #TODO pass it as argument
29 | jax.nn.relu
30 | ], name=name)
31 |
32 | def mlp(dim: int, out_dim: int, name='mlp') -> Callable:
33 | return hk.Sequential(
34 | [linear_relu(dim=dim),
35 | linear_relu(dim=dim),
36 | hk.Linear(out_dim, w_init=default_linear_init)
37 | ],
38 | name=name
39 | )
40 |
41 | def mlp_bottlneck(dim: int, out_dim: int, name='mlp') -> Callable:
42 | return hk.Sequential(
43 | [linear_relu(dim=dim // 2),
44 | linear_relu(dim=dim),
45 | hk.Linear(out_dim, w_init=default_linear_init)
46 | ],
47 | name=name
48 | )
49 |
50 | def feature_extractor(obs: jnp.ndarray, obs_type: str, name='encoder') -> jnp.ndarray:
51 | """encoder"""
52 | if obs_type == 'pixels':
53 | encoder = hk.Sequential([
54 | lambda x: x / 255.0 - 0.5, #FIXME put on GPU instead of CPU
55 | hk.Conv2D(output_channels=32, kernel_shape=3, stride=2, padding='VALID'),
56 | jax.nn.relu,
57 | hk.Conv2D(output_channels=32, kernel_shape=3, stride=1, padding='VALID'),
58 | jax.nn.relu,
59 | hk.Conv2D(output_channels=32, kernel_shape=3, stride=1, padding='VALID'),
60 | jax.nn.relu,
61 | hk.Conv2D(output_channels=32, kernel_shape=3, stride=1, padding='VALID'),
62 | jax.nn.relu,
63 | hk.Flatten(preserve_dims=-3) # [N, H, W, C] -> [N, -1]
64 | ], name=name)
65 | else:
66 | encoder = Identity()
67 |
68 | return encoder(inputs=obs)
69 |
70 |
71 | if __name__ == "__main__":
72 | def network(obs):
73 | def make_q(name):
74 | return hk.Sequential([
75 | linear_relu(10),
76 | hk.Linear(1, w_init=default_linear_init)
77 | ], name)
78 |
79 | q1 = make_q(name='q1')
80 | q2 = make_q(name='q2') # q1 neq q2
81 | return q1(obs), q2(obs)
82 |
83 | forward = hk.without_apply_rng(hk.transform(network))
84 | key = jax.random.PRNGKey(2)
85 | obs = jnp.ones((1, 10))
86 | state = forward.init(rng=key, obs=obs) # state = state2
87 | state_2 = forward.init(rng=key, obs=obs)
88 | print(state)
--------------------------------------------------------------------------------
/core/calculations/losses.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Tuple
3 |
4 | import jax.numpy as jnp
5 | import jax
6 | import torch
7 | import chex
8 | import haiku as hk
9 | import tree
10 |
11 |
12 | def l2_loss(preds: jnp.ndarray,
13 | targets: jnp.ndarray = None
14 | ) -> jnp.ndarray:
15 | """Compute l2 loss if target not provided computes l2 loss with target 0"""
16 | if targets is None:
17 | targets = jnp.zeros_like(preds)
18 | chex.assert_type([preds, targets], float)
19 | return 0.5 * (preds - targets)**2
20 |
21 | def l2_loss_without_bias(params: hk.Params):
22 | l2_params = [p for ((module_name, x), p) in tree.flatten_with_path(params) if x == 'w']
23 | return 0.5 * sum(jnp.sum(jnp.square(p)) for p in l2_params)
24 |
25 |
26 | def running_stats(
27 | mean: jnp.ndarray,
28 | std: jnp.ndarray,
29 | x: jnp.ndarray,
30 | num: float,
31 | ):
32 | bs = x.shape[0]
33 | delta = jnp.mean(x, axis=0) - mean
34 | new_mean = mean + delta * bs / (num + bs)
35 | new_std = (std * num + jnp.var(x, axis=0) * bs +
36 | (delta**2) * num * bs / (num + bs)) / (num + bs)
37 | return new_mean, new_std, num + bs
38 |
39 |
40 | def particle_based_entropy(source: jnp.ndarray,
41 | target: jnp.ndarray,
42 | knn_clip: float = 0.0005, # todo remove for minimization
43 | knn_k: int = 16,
44 | knn_avg: bool = True,
45 | knn_rm: bool = True,
46 | minus_mean: bool = True,
47 | mean: jnp.ndarray = None,
48 | std: jnp.ndarray = None,
49 | num: float = None,
50 | ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, float]:
51 | """ Implement Particle Based Entropy Estimator as in APT
52 | :param knn_rm:
53 | :param mean: mean for running mean
54 | :param knn_clip:
55 | :param knn_k: hyperparameter k
56 | :param knn_avg: whether to take the average over k nearest neighbors
57 | :param source: value to compute entropy over [b1, c]
58 | :param target: value to compute entropy over [b1, c]
59 | :return: entropy of rep # (b1, 1)
60 | """
61 | # source = target = rep #[b1, c] [b2, c]
62 |
63 | b1, b2 = source.shape[0], target.shape[0]
64 | # (b1, 1, c) - (1, b2, c) -> (b1, b2, c) -> (b1, b2)
65 | sim_matrix = jnp.linalg.norm(
66 | source[:, None, :].reshape(b1, 1, -1) - target[None, :, :].reshape(1, b2, -1),
67 | axis=-1,
68 | ord=2
69 | )
70 | # take the min of the sim_matrix to get largest=False
71 | reward, _ = jax.lax.top_k(
72 | operand=-sim_matrix, #(b1, b2)
73 | k=knn_k
74 | )
75 | reward = -reward
76 |
77 | if not knn_avg: # only keep k-th nearest neighbor
78 | reward = reward[:, -1]
79 |         reward = reward.reshape(-1, 1) # (b1, 1)
80 | if knn_rm:
81 | mean, std, num = running_stats(mean, std, reward, num)
82 | if minus_mean:
83 | reward = (reward - mean) / std
84 | else:
85 | reward = reward/ std
86 | reward = jnp.maximum(
87 | reward - knn_clip,
88 | jnp.zeros_like(reward)
89 | )
90 | else: # average over all k nearest neigbors
91 | reward = reward.reshape(-1, 1) #(b1 * k, 1)
92 | if knn_rm:
93 | mean, std, num = running_stats(mean, std, reward, num)
94 | if minus_mean:
95 | reward = (reward - mean) / std
96 | else:
97 | reward = reward / std
98 | if knn_clip >= 0.0:
99 | reward = jnp.maximum(
100 | reward - knn_clip,
101 | jnp.zeros_like(reward)
102 | )
103 | reward = reward.reshape((b1, knn_k))
104 | reward = jnp.mean(reward, axis=1, keepdims=True) # (b1, 1)
105 |
106 | reward = jnp.log(reward + 1.0)
107 | return reward, mean, std, num
108 |
109 |
110 |
111 | def log_sum_exp(logits: jnp.ndarray):
112 | return jnp.log(
113 | jnp.sum(
114 | jnp.exp(logits),# [N, C]
115 | axis=-1
116 | ) # [N]
117 | )
118 |
119 | def normalize(x):
120 | return x / (jnp.linalg.norm(x=x, ord=2, axis=-1, keepdims=True) + 1e-12)
121 | # jnp.sqrt(jnp.sum(jnp.square(normalize(a))))
122 |
123 | def noise_contrastive_loss(
124 | query,
125 | key,
126 | temperature = 0.5
127 | ):
128 | """
129 |     InfoNCE: -mean_i( s_ii - log(sum_j exp(s_ij)) ), with s the temperature-scaled cosine similarities
130 | """
131 | query = normalize(query)
132 | key = normalize(key)
133 | logits = query @ key.T #(N, N) positive pairs on the diagonal
134 | logits = logits / temperature
135 | shifted_cov = logits - jax.lax.stop_gradient(logits.max(axis=-1, keepdims=True)) # [N, N]
136 | diag_indexes = jnp.arange(shifted_cov.shape[0])[:, None]# [N, 1]
137 | pos = jnp.take_along_axis(arr=shifted_cov, indices=diag_indexes, axis=-1) # [N, 1]
138 | neg = log_sum_exp(shifted_cov)
139 | return -jnp.mean(pos.reshape(-1) - neg.reshape(-1))
140 |
141 |
142 | def softmax_probabilities(query, key, temperature=0.5):
143 | query = normalize(query)
144 | key = normalize(key)
145 | logits = query @ key.T
146 | logits = logits / temperature
147 | shifted_cov = logits - jax.lax.stop_gradient(logits.max(axis=-1, keepdims=True)) # [N, N]
148 | diag_indexes = jnp.arange(shifted_cov.shape[0])[:, None] # [N, 1]
149 | pos = jnp.take_along_axis(arr=shifted_cov, indices=diag_indexes, axis=-1) # [N, 1]
150 | pos = jnp.exp(pos)
151 | neg = jnp.sum(jnp.exp(logits), axis=-1, keepdims=True) # [N, 1]
152 | return pos / neg
153 |
154 |
155 | def cpc_loss(
156 | query,
157 | key,
158 | temperature = 0.5
159 | ):
160 |
161 | query = normalize(query)
162 | key = normalize(key)
163 | cov = query @ key.T # (N, N) positive pairs on the diagonal
164 | sim = jnp.exp(cov / temperature)
165 | neg = sim.sum(axis=-1) # b
166 | row_sub = jnp.ones_like(neg) * math.exp(1/temperature)
167 | neg = jnp.clip(neg - row_sub, a_min=1e-6)
168 |
169 | pos = jnp.exp(jnp.sum(query * key, axis=-1) / temperature) # b
170 | loss = -jnp.log(pos / (neg + 1e-6))
171 | return loss.mean()
172 |
173 | if __name__ == "__main__":
174 | # x = jax.random.normal(key=jax.random.PRNGKey(5), shape=(15, 5))
175 | # 10, 5
176 | jax_input = jnp.array([[ 0.61735314, 0.65116936, 0.37252188, 0.01196358,
177 | -1.0840642 ],
178 | [ 0.40633643, -0.3350711 , 0.433196 , 1.8324155 ,
179 | 1.2233032 ],
180 | [ 0.6076932 , 0.62271905, -0.5155139 , -0.8686952 ,
181 | 1.3694043 ],
182 | [ 1.5686233 , -1.0647503 , 1.0048455 , 1.4000669 ,
183 | 0.30719075],
184 | [ 1.6678249 , -0.5851507 , -1.420454 , -0.05948697,
185 | -1.5111905 ],
186 | [ 1.8621138 , -0.6911869 , -0.94851583, 1.159258 ,
187 | 1.5931036 ],
188 | [ 1.9720763 , -1.0973446 , 1.1731594 , 0.0780869 ,
189 | 0.143219 ],
190 | [-1.0157285 , 0.50870734, 0.39398482, 1.1644812 ,
191 | -0.26890013],
192 | [ 1.6161795 , 1.644653 , -1.0968473 , 1.0495588 ,
193 | 0.47088355],
194 | [-0.13400784, 0.5755616 , 0.4617284 , 0.08174139,
195 | -1.0918598 ]])
196 |
197 | torch_input = torch.tensor([[ 0.61735314, 0.65116936, 0.37252188, 0.01196358,
198 | -1.0840642 ],
199 | [ 0.40633643, -0.3350711 , 0.433196 , 1.8324155 ,
200 | 1.2233032 ],
201 | [ 0.6076932 , 0.62271905, -0.5155139 , -0.8686952 ,
202 | 1.3694043 ],
203 | [ 1.5686233 , -1.0647503 , 1.0048455 , 1.4000669 ,
204 | 0.30719075],
205 | [ 1.6678249 , -0.5851507 , -1.420454 , -0.05948697,
206 | -1.5111905 ],
207 | [ 1.8621138 , -0.6911869 , -0.94851583, 1.159258 ,
208 | 1.5931036 ],
209 | [ 1.9720763 , -1.0973446 , 1.1731594 , 0.0780869 ,
210 | 0.143219 ],
211 | [-1.0157285 , 0.50870734, 0.39398482, 1.1644812 ,
212 | -0.26890013],
213 | [ 1.6161795 , 1.644653 , -1.0968473 , 1.0495588 ,
214 | 0.47088355],
215 | [-0.13400784, 0.5755616 , 0.4617284 , 0.08174139,
216 | -1.0918598 ]])
217 |
218 | ## TEST particle
219 | # knn_k = 3
220 | # knn_clip = 0.0
221 | # mean = 0.0
222 | # knn_avg = True
223 | # knn_rm = True
224 | # particle_based_entropy = partial(particle_based_entropy, knn_k=knn_k, knn_clip=knn_clip, knn_rm=knn_rm,
225 | # knn_avg=knn_avg)
226 | # value = particle_based_entropy(rep=jax_input, mean=mean, step=1)
227 | # print(value)
228 | # rms = RMS('cpu')
229 | # pbe = PBE(rms, knn_clip, knn_k, knn_avg, knn_rm, 'cpu')
230 | # value_torch = pbe(torch_input)
231 | # print(value_torch)
232 |
233 | ## TEST nce
234 | # out = noise_contrastive_loss(jax_input, jax_input)
235 | # out = cpc_loss(jax_input, jax_input)
236 | # print(out)
237 | # out_torch = torch_nce(torch_input, torch_input)
238 | # print(out_torch)
239 | # print("Sanity Check value should be close to log(1/N): {}".format(math.log(jax_input.shape[0])))
--------------------------------------------------------------------------------
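
A quick sanity check for the contrastive losses above (a sketch, not repository code; it assumes these helpers live in core/calculations/losses.py). With independent random query/key pairs there is nothing to contrast, so the softmax over each row is roughly uniform and the loss should land near -log(1/N) = log(N):

    import math
    import jax
    from core.calculations.losses import noise_contrastive_loss  # assumed module path

    k1, k2 = jax.random.split(jax.random.PRNGKey(0))
    query = jax.random.normal(k1, shape=(128, 32))   # N = 128 embeddings of dim 32
    key = jax.random.normal(k2, shape=(128, 32))     # independent keys: no real positives
    loss = noise_contrastive_loss(query, key, temperature=0.5)
    print(float(loss), math.log(query.shape[0]))     # both should be close to ~4.85
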
/core/calculations/misc.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | import numpy as np
4 |
5 |
6 |
7 | def schedule(schdl, step):
8 | try:
9 | return float(schdl)
10 | except ValueError:
11 | match = re.match(r'linear\((.+),(.+),(.+)\)', schdl)
12 | if match:
13 | init, final, duration = [float(g) for g in match.groups()]
14 | mix = np.clip(step / duration, 0.0, 1.0)
15 | return (1.0 - mix) * init + mix * final
16 | match = re.match(r'step_linear\((.+),(.+),(.+),(.+),(.+)\)', schdl)
17 | if match:
18 | init, final1, duration1, final2, duration2 = [
19 | float(g) for g in match.groups()
20 | ]
21 | if step <= duration1:
22 | mix = np.clip(step / duration1, 0.0, 1.0)
23 | return (1.0 - mix) * init + mix * final1
24 | else:
25 | mix = np.clip((step - duration1) / duration2, 0.0, 1.0)
26 | return (1.0 - mix) * final1 + mix * final2
27 | raise NotImplementedError(schdl)
28 |
29 |
30 |
--------------------------------------------------------------------------------
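
For reference, a minimal sketch of the schedule strings that the regular expressions above accept (step values are illustrative):

    from core.calculations.misc import schedule

    schedule('0.2', step=10_000)                              # plain float -> 0.2 forever
    schedule('linear(1.0,0.1,100000)', step=50_000)           # halfway through the decay -> 0.55
    schedule('step_linear(1.0,0.5,50000,0.1,50000)', step=75_000)
    # first leg 1.0 -> 0.5 over 50k steps, then 0.5 -> 0.1 over the next 50k, so this returns 0.3
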
/core/calculations/params_utils.py:
--------------------------------------------------------------------------------
1 | import haiku as hk
2 | import jax
3 | import tree
4 |
5 | def count_param(params: hk.Params):
6 | params_count_list = [p.size for ((mod_name, x), p) in tree.flatten_with_path(params)]
7 | return sum(params_count_list)
8 |
9 |
10 | def polyak_averaging(params: hk.Params,
11 | target_params: hk.Params,
12 | tau: float
13 | ):
14 | return jax.tree_multimap(
15 | lambda x, y: tau * x + (1 - tau) * y,
16 | params, target_params
17 | )
18 |
--------------------------------------------------------------------------------
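
A minimal sketch of how polyak_averaging applies a soft target update leaf-wise, on toy parameter trees (it assumes the JAX version pinned by the project still exposes jax.tree_multimap, which polyak_averaging relies on):

    from jax import numpy as jnp
    from core.calculations.params_utils import count_param, polyak_averaging

    params        = {'linear': {'w': jnp.ones((2, 2)),  'b': jnp.ones((2,))}}
    target_params = {'linear': {'w': jnp.zeros((2, 2)), 'b': jnp.zeros((2,))}}

    print(count_param(params))                                  # 6 learnable scalars
    # leaf-wise: tau * online + (1 - tau) * target, so every target entry becomes 0.01
    target_params = polyak_averaging(params, target_params, tau=0.01)
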
/core/calculations/skill_utils.py:
--------------------------------------------------------------------------------
1 | from typing import NamedTuple
2 | import jax
3 | from jax import numpy as jnp
4 | import numpy as np
5 |
6 |
7 | class SkillRewardTracker(NamedTuple):
8 | best_skill: jnp.ndarray
9 | best_score: np.float32
10 | score_sum: np.float32
11 | score_step: int
12 | current_skill: jnp.ndarray
13 | search_steps: int
14 | change_interval: int
15 | low: float
16 | update: bool
17 |
18 |
19 | def constant_fixed_skill(skill_dim: int,) -> jnp.ndarray:
20 | return jnp.ones((skill_dim,), dtype=jnp.float32) * 0.5
21 |
22 |
23 | def random_search_skill(
24 | skill_dim: int,
25 | global_timestep: int,
26 | skill_tracker: SkillRewardTracker,
27 | key: jax.random.PRNGKey
28 | ) -> jnp.ndarray:
29 | if global_timestep >= skill_tracker.search_steps:
30 | return skill_tracker.best_skill
31 | return jax.random.uniform(key, shape=(skill_dim, ), minval=0., maxval=1.)
32 |
33 |
34 | def random_grid_search_skill(key: jax.random.PRNGKey,
35 | skill_dim: int,
36 | global_timestep: int,
37 | skill_tracker: SkillRewardTracker,
38 | **kwargs) -> jnp.ndarray:
39 | if global_timestep >= skill_tracker.search_steps:
40 | return skill_tracker.best_skill
41 | increment = (1 - skill_tracker.low) / (skill_tracker.search_steps // skill_tracker.change_interval)
42 | start = global_timestep // skill_tracker.change_interval * increment
43 | end = (global_timestep // skill_tracker.change_interval + 1) * increment
44 | return jax.random.uniform(key, shape=(skill_dim,), minval=start, maxval=end)
45 |
46 |
47 | def grid_search_skill(skill_dim: int, global_timestep: int, skill_tracker: SkillRewardTracker) -> jnp.ndarray:
48 | if global_timestep >= skill_tracker.search_steps:
49 | return skill_tracker.best_skill
50 | return jnp.ones((skill_dim,)) * jnp.linspace(
51 | -1.,
52 | 0.,
53 | num=skill_tracker.search_steps // skill_tracker.change_interval
54 | )[global_timestep // skill_tracker.change_interval]
55 |
56 |
--------------------------------------------------------------------------------
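
A hypothetical call into the grid-search helper above, only to show how the tracker fields interact (all values here are made up):

    import jax
    from jax import numpy as jnp
    from core.calculations.skill_utils import SkillRewardTracker, random_grid_search_skill

    tracker = SkillRewardTracker(
        best_skill=jnp.zeros((64,)), best_score=-jnp.inf, score_sum=0.0, score_step=0,
        current_skill=jnp.zeros((64,)), search_steps=4_000, change_interval=100,
        low=0.0, update=True)

    # 4_000 // 100 = 40 grid cells of width 0.025; at step 250 we sample from [0.05, 0.075)
    skill = random_grid_search_skill(key=jax.random.PRNGKey(0), skill_dim=64,
                                     global_timestep=250, skill_tracker=tracker)
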
/core/custom_dmc_tasks/__init__.py:
--------------------------------------------------------------------------------
1 | from core.custom_dmc_tasks import quadruped, jaco, cheetah, walker, hopper
2 |
3 |
4 | def make(domain, task,
5 | task_kwargs=None,
6 | environment_kwargs=None,
7 | visualize_reward=False):
8 |
9 | if domain == 'cheetah':
10 | return cheetah.make(task,
11 | task_kwargs=task_kwargs,
12 | environment_kwargs=environment_kwargs,
13 | visualize_reward=visualize_reward)
14 | elif domain == 'walker':
15 | return walker.make(task,
16 | task_kwargs=task_kwargs,
17 | environment_kwargs=environment_kwargs,
18 | visualize_reward=visualize_reward)
19 | elif domain == 'hopper':
20 | return hopper.make(task,
21 | task_kwargs=task_kwargs,
22 | environment_kwargs=environment_kwargs,
23 | visualize_reward=visualize_reward)
24 | elif domain == 'quadruped':
25 | return quadruped.make(task,
26 | task_kwargs=task_kwargs,
27 | environment_kwargs=environment_kwargs,
28 | visualize_reward=visualize_reward)
29 | else:
30 | raise NotImplementedError(f'{domain} not found')
31 |
32 |
33 |
34 |
35 | def make_jaco(task, obs_type, seed):
36 | return jaco.make(task, obs_type, seed)
--------------------------------------------------------------------------------
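
A minimal sketch of building the custom tasks through this factory (task names follow the SUITE registrations in the files below; the jaco observation type is assumed to be 'states' or 'pixels'):

    from core import custom_dmc_tasks as cdmc

    env = cdmc.make('cheetah', 'run_backward')                        # dm_control control.Environment
    jaco_env = cdmc.make_jaco('reach_top_left', obs_type='states', seed=1)
    first_timestep = env.reset()
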
/core/custom_dmc_tasks/cheetah.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Cheetah Domain."""
16 |
17 | import collections
18 | import os
19 |
20 | from dm_control import mujoco
21 | from dm_control.rl import control
22 | from dm_control.suite import base
23 | from dm_control.suite import common
24 | from dm_control.utils import containers
25 | from dm_control.utils import rewards
26 | from dm_control.utils import io as resources
27 |
28 | # How long the simulation will run, in seconds.
29 | _DEFAULT_TIME_LIMIT = 10
30 |
31 | # Running speed above which reward is 1.
32 | _RUN_SPEED = 10
33 | _SPIN_SPEED = 5
34 |
35 | SUITE = containers.TaggedTasks()
36 |
37 |
38 | def make(task,
39 | task_kwargs=None,
40 | environment_kwargs=None,
41 | visualize_reward=False):
42 | task_kwargs = task_kwargs or {}
43 | if environment_kwargs is not None:
44 | task_kwargs = task_kwargs.copy()
45 | task_kwargs['environment_kwargs'] = environment_kwargs
46 | env = SUITE[task](**task_kwargs)
47 | env.task.visualize_reward = visualize_reward
48 | return env
49 |
50 |
51 | def get_model_and_assets():
52 | """Returns a tuple containing the model XML string and a dict of assets."""
53 | root_dir = os.path.dirname(os.path.dirname(__file__))
54 | xml = resources.GetResource(
55 | os.path.join(root_dir, 'custom_dmc_tasks', 'cheetah.xml'))
56 | return xml, common.ASSETS
57 |
58 |
59 |
60 | @SUITE.add('benchmarking')
61 | def run_backward(time_limit=_DEFAULT_TIME_LIMIT,
62 | random=None,
63 | environment_kwargs=None):
64 | """Returns the run task."""
65 | physics = Physics.from_xml_string(*get_model_and_assets())
66 | task = Cheetah(forward=False, flip=False, random=random)
67 | environment_kwargs = environment_kwargs or {}
68 | return control.Environment(physics,
69 | task,
70 | time_limit=time_limit,
71 | **environment_kwargs)
72 |
73 |
74 | @SUITE.add('benchmarking')
75 | def flip(time_limit=_DEFAULT_TIME_LIMIT,
76 | random=None,
77 | environment_kwargs=None):
78 | """Returns the run task."""
79 | physics = Physics.from_xml_string(*get_model_and_assets())
80 | task = Cheetah(forward=True, flip=True, random=random)
81 | environment_kwargs = environment_kwargs or {}
82 | return control.Environment(physics,
83 | task,
84 | time_limit=time_limit,
85 | **environment_kwargs)
86 |
87 |
88 | @SUITE.add('benchmarking')
89 | def flip_backward(time_limit=_DEFAULT_TIME_LIMIT,
90 | random=None,
91 | environment_kwargs=None):
92 | """Returns the run task."""
93 | physics = Physics.from_xml_string(*get_model_and_assets())
94 | task = Cheetah(forward=False, flip=True, random=random)
95 | environment_kwargs = environment_kwargs or {}
96 | return control.Environment(physics,
97 | task,
98 | time_limit=time_limit,
99 | **environment_kwargs)
100 |
101 |
102 | class Physics(mujoco.Physics):
103 | """Physics simulation with additional features for the Cheetah domain."""
104 | def speed(self):
105 | """Returns the horizontal speed of the Cheetah."""
106 | return self.named.data.sensordata['torso_subtreelinvel'][0]
107 |
108 | def angmomentum(self):
109 | """Returns the angular momentum of torso of the Cheetah about Y axis."""
110 | return self.named.data.subtree_angmom['torso'][1]
111 |
112 |
113 | class Cheetah(base.Task):
114 | """A `Task` to train a running Cheetah."""
115 | def __init__(self, forward=True, flip=False, random=None):
116 | self._forward = 1 if forward else -1
117 | self._flip = flip
118 | super(Cheetah, self).__init__(random=random)
119 |
120 | def initialize_episode(self, physics):
121 | """Sets the state of the environment at the start of each episode."""
122 | # The indexing below assumes that all joints have a single DOF.
123 | assert physics.model.nq == physics.model.njnt
124 | is_limited = physics.model.jnt_limited == 1
125 | lower, upper = physics.model.jnt_range[is_limited].T
126 | physics.data.qpos[is_limited] = self.random.uniform(lower, upper)
127 |
128 | # Stabilize the model before the actual simulation.
129 | for _ in range(200):
130 | physics.step()
131 |
132 | physics.data.time = 0
133 | self._timeout_progress = 0
134 | super().initialize_episode(physics)
135 |
136 | def get_observation(self, physics):
137 | """Returns an observation of the state, ignoring horizontal position."""
138 | obs = collections.OrderedDict()
139 | # Ignores horizontal position to maintain translational invariance.
140 | obs['position'] = physics.data.qpos[1:].copy()
141 | obs['velocity'] = physics.velocity()
142 | return obs
143 |
144 | def get_reward(self, physics):
145 | """Returns a reward to the agent."""
146 | if self._flip:
147 | reward = rewards.tolerance(self._forward * physics.angmomentum(),
148 | bounds=(_SPIN_SPEED, float('inf')),
149 | margin=_SPIN_SPEED,
150 | value_at_margin=0,
151 | sigmoid='linear')
152 |
153 | else:
154 | reward = rewards.tolerance(self._forward * physics.speed(),
155 | bounds=(_RUN_SPEED, float('inf')),
156 | margin=_RUN_SPEED,
157 | value_at_margin=0,
158 | sigmoid='linear')
159 | return reward
160 |
--------------------------------------------------------------------------------
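
The rewards above are shaped by dm_control's rewards.tolerance with a linear sigmoid and value_at_margin=0, so the reward ramps linearly from 0 to 1 as speed (or angular momentum) approaches the bound. A small worked example with illustrative numbers:

    from dm_control.utils import rewards

    # speed 5 against a (10, inf) bound with margin 10 sits halfway into the
    # margin, so the linear sigmoid returns 0.5
    r = rewards.tolerance(5.0, bounds=(10, float('inf')),
                          margin=10, value_at_margin=0, sigmoid='linear')
    print(r)  # ~0.5
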
/core/custom_dmc_tasks/cheetah.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/core/custom_dmc_tasks/hopper.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Hopper domain."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import os
23 |
24 | from dm_control import mujoco
25 | from dm_control.rl import control
26 | from dm_control.suite import base
27 | from dm_control.suite import common
28 | from dm_control.suite.utils import randomizers
29 | from dm_control.utils import containers
30 | from dm_control.utils import rewards
31 | from dm_control.utils import io as resources
32 | import numpy as np
33 |
34 | SUITE = containers.TaggedTasks()
35 |
36 | _CONTROL_TIMESTEP = .02 # (Seconds)
37 |
38 | # Default duration of an episode, in seconds.
39 | _DEFAULT_TIME_LIMIT = 20
40 |
41 | # Minimal height of torso over foot above which stand reward is 1.
42 | _STAND_HEIGHT = 0.6
43 |
44 | # Hopping speed above which hop reward is 1.
45 | _HOP_SPEED = 2
46 | _SPIN_SPEED = 5
47 |
48 |
49 | def make(task,
50 | task_kwargs=None,
51 | environment_kwargs=None,
52 | visualize_reward=False):
53 | task_kwargs = task_kwargs or {}
54 | if environment_kwargs is not None:
55 | task_kwargs = task_kwargs.copy()
56 | task_kwargs['environment_kwargs'] = environment_kwargs
57 | env = SUITE[task](**task_kwargs)
58 | env.task.visualize_reward = visualize_reward
59 | return env
60 |
61 | def get_model_and_assets():
62 | """Returns a tuple containing the model XML string and a dict of assets."""
63 | root_dir = os.path.dirname(os.path.dirname(__file__))
64 | xml = resources.GetResource(
65 | os.path.join(root_dir, 'custom_dmc_tasks', 'hopper.xml'))
66 | return xml, common.ASSETS
67 |
68 |
69 |
70 | @SUITE.add('benchmarking')
71 | def hop_backward(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
72 | """Returns a Hopper that strives to hop forward."""
73 | physics = Physics.from_xml_string(*get_model_and_assets())
74 | task = Hopper(hopping=True, forward=False, flip=False, random=random)
75 | environment_kwargs = environment_kwargs or {}
76 | return control.Environment(physics,
77 | task,
78 | time_limit=time_limit,
79 | control_timestep=_CONTROL_TIMESTEP,
80 | **environment_kwargs)
81 |
82 | @SUITE.add('benchmarking')
83 | def flip(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
84 | """Returns a Hopper that strives to hop forward."""
85 | physics = Physics.from_xml_string(*get_model_and_assets())
86 | task = Hopper(hopping=True, forward=True, flip=True, random=random)
87 | environment_kwargs = environment_kwargs or {}
88 | return control.Environment(physics,
89 | task,
90 | time_limit=time_limit,
91 | control_timestep=_CONTROL_TIMESTEP,
92 | **environment_kwargs)
93 |
94 | @SUITE.add('benchmarking')
95 | def flip_backward(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
96 | """Returns a Hopper that strives to hop forward."""
97 | physics = Physics.from_xml_string(*get_model_and_assets())
98 | task = Hopper(hopping=True, forward=False, flip=True, random=random)
99 | environment_kwargs = environment_kwargs or {}
100 | return control.Environment(physics,
101 | task,
102 | time_limit=time_limit,
103 | control_timestep=_CONTROL_TIMESTEP,
104 | **environment_kwargs)
105 |
106 |
107 | class Physics(mujoco.Physics):
108 | """Physics simulation with additional features for the Hopper domain."""
109 | def height(self):
110 | """Returns height of torso with respect to foot."""
111 | return (self.named.data.xipos['torso', 'z'] -
112 | self.named.data.xipos['foot', 'z'])
113 |
114 | def speed(self):
115 | """Returns horizontal speed of the Hopper."""
116 | return self.named.data.sensordata['torso_subtreelinvel'][0]
117 |
118 | def touch(self):
119 | """Returns the signals from two foot touch sensors."""
120 | return np.log1p(self.named.data.sensordata[['touch_toe',
121 | 'touch_heel']])
122 |
123 | def angmomentum(self):
124 | """Returns the angular momentum of torso of the Cheetah about Y axis."""
125 | return self.named.data.subtree_angmom['torso'][1]
126 |
127 |
128 |
129 | class Hopper(base.Task):
130 | """A Hopper's `Task` to train a standing and a jumping Hopper."""
131 | def __init__(self, hopping, forward=True, flip=False, random=None):
132 | """Initialize an instance of `Hopper`.
133 |
134 | Args:
135 | hopping: Boolean, if True the task is to hop forwards, otherwise it is to
136 | balance upright.
137 | random: Optional, either a `numpy.random.RandomState` instance, an
138 | integer seed for creating a new `RandomState`, or None to select a seed
139 | automatically (default).
140 | """
141 | self._hopping = hopping
142 | self._forward = 1 if forward else -1
143 | self._flip = flip
144 | super(Hopper, self).__init__(random=random)
145 |
146 | def initialize_episode(self, physics):
147 | """Sets the state of the environment at the start of each episode."""
148 | randomizers.randomize_limited_and_rotational_joints(
149 | physics, self.random)
150 | self._timeout_progress = 0
151 | super(Hopper, self).initialize_episode(physics)
152 |
153 | def get_observation(self, physics):
154 | """Returns an observation of positions, velocities and touch sensors."""
155 | obs = collections.OrderedDict()
156 | # Ignores horizontal position to maintain translational invariance:
157 | obs['position'] = physics.data.qpos[1:].copy()
158 | obs['velocity'] = physics.velocity()
159 | obs['touch'] = physics.touch()
160 | return obs
161 |
162 | def get_reward(self, physics):
163 | """Returns a reward applicable to the performed task."""
164 | standing = rewards.tolerance(physics.height(), (_STAND_HEIGHT, 2))
165 | assert self._hopping
166 | if self._flip:
167 | hopping = rewards.tolerance(self._forward * physics.angmomentum(),
168 | bounds=(_SPIN_SPEED, float('inf')),
169 | margin=_SPIN_SPEED,
170 | value_at_margin=0,
171 | sigmoid='linear')
172 | else:
173 | hopping = rewards.tolerance(self._forward * physics.speed(),
174 | bounds=(_HOP_SPEED, float('inf')),
175 | margin=_HOP_SPEED / 2,
176 | value_at_margin=0.5,
177 | sigmoid='linear')
178 | return standing * hopping
--------------------------------------------------------------------------------
/core/custom_dmc_tasks/hopper.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/core/custom_dmc_tasks/jaco.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """A task where the goal is to move the hand close to a target prop or site."""
17 |
18 | import collections
19 |
20 | from dm_control import composer
21 | from dm_control.composer import initializers
22 | from dm_control.composer.observation import observable
23 | from dm_control.composer.variation import distributions
24 | from dm_control.entities import props
25 | from dm_control.manipulation.shared import arenas
26 | from dm_control.manipulation.shared import cameras
27 | from dm_control.manipulation.shared import constants
28 | from dm_control.manipulation.shared import observations
29 | from dm_control.manipulation.shared import registry
30 | from dm_control.manipulation.shared import robots
31 | from dm_control.manipulation.shared import tags
32 | from dm_control.manipulation.shared import workspaces
33 | from dm_control.utils import rewards
34 | import numpy as np
35 |
36 |
37 | _ReachWorkspace = collections.namedtuple(
38 | '_ReachWorkspace', ['target_bbox', 'tcp_bbox', 'arm_offset'])
39 |
40 | # Ensures that the props are not touching the table before settling.
41 | _PROP_Z_OFFSET = 0.001
42 |
43 | _DUPLO_WORKSPACE = _ReachWorkspace(
44 | target_bbox=workspaces.BoundingBox(
45 | lower=(-0.1, -0.1, _PROP_Z_OFFSET),
46 | upper=(0.1, 0.1, _PROP_Z_OFFSET)),
47 | tcp_bbox=workspaces.BoundingBox(
48 | lower=(-0.1, -0.1, 0.2),
49 | upper=(0.1, 0.1, 0.4)),
50 | arm_offset=robots.ARM_OFFSET)
51 |
52 | _SITE_WORKSPACE = _ReachWorkspace(
53 | target_bbox=workspaces.BoundingBox(
54 | lower=(-0.2, -0.2, 0.02),
55 | upper=(0.2, 0.2, 0.4)),
56 | tcp_bbox=workspaces.BoundingBox(
57 | lower=(-0.2, -0.2, 0.02),
58 | upper=(0.2, 0.2, 0.4)),
59 | arm_offset=robots.ARM_OFFSET)
60 |
61 | _TARGET_RADIUS = 0.05
62 | _TIME_LIMIT = 10.
63 |
64 | TASKS = {
65 | 'reach_top_left': workspaces.BoundingBox(
66 | lower=(-0.09, 0.09, _PROP_Z_OFFSET),
67 | upper=(-0.09, 0.09, _PROP_Z_OFFSET)),
68 | 'reach_top_right': workspaces.BoundingBox(
69 | lower=(0.09, 0.09, _PROP_Z_OFFSET),
70 | upper=(0.09, 0.09, _PROP_Z_OFFSET)),
71 | 'reach_bottom_left': workspaces.BoundingBox(
72 | lower=(-0.09, -0.09, _PROP_Z_OFFSET),
73 | upper=(-0.09, -0.09, _PROP_Z_OFFSET)),
74 | 'reach_bottom_right': workspaces.BoundingBox(
75 | lower=(0.09, -0.09, _PROP_Z_OFFSET),
76 | upper=(0.09, -0.09, _PROP_Z_OFFSET)),
77 | }
78 |
79 |
80 | def make(task_id, obs_type, seed):
81 | obs_settings = observations.VISION if obs_type == 'pixels' else observations.PERFECT_FEATURES
82 | task = _reach(task_id, obs_settings=obs_settings, use_site=False)
83 | return composer.Environment(task, time_limit=_TIME_LIMIT, random_state=seed)
84 |
85 |
86 |
87 | class MTReach(composer.Task):
88 | """Bring the hand close to a target prop or site."""
89 |
90 | def __init__(
91 | self, task_id, arena, arm, hand, prop, obs_settings, workspace, control_timestep):
92 | """Initializes a new `Reach` task.
93 |
94 | Args:
95 | arena: `composer.Entity` instance.
96 | arm: `robot_base.RobotArm` instance.
97 | hand: `robot_base.RobotHand` instance.
98 | prop: `composer.Entity` instance specifying the prop to reach to, or None
99 | in which case the target is a fixed site whose position is specified by
100 | the workspace.
101 | obs_settings: `observations.ObservationSettings` instance.
102 | workspace: `_ReachWorkspace` specifying the placement of the prop and TCP.
103 | control_timestep: Float specifying the control timestep in seconds.
104 | """
105 | self._arena = arena
106 | self._arm = arm
107 | self._hand = hand
108 | self._arm.attach(self._hand)
109 | self._arena.attach_offset(self._arm, offset=workspace.arm_offset)
110 | self.control_timestep = control_timestep
111 | self._tcp_initializer = initializers.ToolCenterPointInitializer(
112 | self._hand, self._arm,
113 | position=distributions.Uniform(*workspace.tcp_bbox),
114 | quaternion=workspaces.DOWN_QUATERNION)
115 |
116 | # Add custom camera observable.
117 | self._task_observables = cameras.add_camera_observables(
118 | arena, obs_settings, cameras.FRONT_CLOSE)
119 |
120 | target_pos_distribution = distributions.Uniform(*TASKS[task_id])
121 | self._prop = prop
122 | if prop:
123 | # The prop itself is used to visualize the target location.
124 | self._make_target_site(parent_entity=prop, visible=False)
125 | self._target = self._arena.add_free_entity(prop)
126 | self._prop_placer = initializers.PropPlacer(
127 | props=[prop],
128 | position=target_pos_distribution,
129 | quaternion=workspaces.uniform_z_rotation,
130 | settle_physics=True)
131 | else:
132 | self._target = self._make_target_site(parent_entity=arena, visible=True)
133 | self._target_placer = target_pos_distribution
134 |
135 | obs = observable.MJCFFeature('pos', self._target)
136 | obs.configure(**obs_settings.prop_pose._asdict())
137 | self._task_observables['target_position'] = obs
138 |
139 | # Add sites for visualizing the prop and target bounding boxes.
140 | workspaces.add_bbox_site(
141 | body=self.root_entity.mjcf_model.worldbody,
142 | lower=workspace.tcp_bbox.lower, upper=workspace.tcp_bbox.upper,
143 | rgba=constants.GREEN, name='tcp_spawn_area')
144 | workspaces.add_bbox_site(
145 | body=self.root_entity.mjcf_model.worldbody,
146 | lower=workspace.target_bbox.lower, upper=workspace.target_bbox.upper,
147 | rgba=constants.BLUE, name='target_spawn_area')
148 |
149 | def _make_target_site(self, parent_entity, visible):
150 | return workspaces.add_target_site(
151 | body=parent_entity.mjcf_model.worldbody,
152 | radius=_TARGET_RADIUS, visible=visible,
153 | rgba=constants.RED, name='target_site')
154 |
155 | @property
156 | def root_entity(self):
157 | return self._arena
158 |
159 | @property
160 | def arm(self):
161 | return self._arm
162 |
163 | @property
164 | def hand(self):
165 | return self._hand
166 |
167 | @property
168 | def task_observables(self):
169 | return self._task_observables
170 |
171 | def get_reward(self, physics):
172 | hand_pos = physics.bind(self._hand.tool_center_point).xpos
173 | target_pos = physics.bind(self._target).xpos
174 | distance = np.linalg.norm(hand_pos - target_pos)
175 | return rewards.tolerance(
176 | distance, bounds=(0, _TARGET_RADIUS), margin=_TARGET_RADIUS)
177 |
178 | def initialize_episode(self, physics, random_state):
179 | self._hand.set_grasp(physics, close_factors=random_state.uniform())
180 | self._tcp_initializer(physics, random_state)
181 | if self._prop:
182 | self._prop_placer(physics, random_state)
183 | else:
184 | physics.bind(self._target).pos = (
185 | self._target_placer(random_state=random_state))
186 |
187 |
188 | def _reach(task_id, obs_settings, use_site):
189 | """Configure and instantiate a `Reach` task.
190 |
191 | Args:
192 | obs_settings: An `observations.ObservationSettings` instance.
193 | use_site: Boolean, if True then the target will be a fixed site, otherwise
194 | it will be a moveable Duplo brick.
195 |
196 | Returns:
197 | An instance of `reach.Reach`.
198 | """
199 | arena = arenas.Standard()
200 | arm = robots.make_arm(obs_settings=obs_settings)
201 | hand = robots.make_hand(obs_settings=obs_settings)
202 | if use_site:
203 | workspace = _SITE_WORKSPACE
204 | prop = None
205 | else:
206 | workspace = _DUPLO_WORKSPACE
207 | prop = props.Duplo(observable_options=observations.make_options(
208 | obs_settings, observations.FREEPROP_OBSERVABLES))
209 | task = MTReach(task_id, arena=arena, arm=arm, hand=hand, prop=prop,
210 | obs_settings=obs_settings,
211 | workspace=workspace,
212 | control_timestep=constants.CONTROL_TIMESTEP)
213 | return task
--------------------------------------------------------------------------------
/core/custom_dmc_tasks/quadruped.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/core/custom_dmc_tasks/walker.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Planar Walker Domain."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import os
23 |
24 | from dm_control import mujoco
25 | from dm_control.rl import control
26 | from dm_control.suite import base
27 | from dm_control.suite import common
28 | from dm_control.suite.utils import randomizers
29 | from dm_control.utils import containers
30 | from dm_control.utils import rewards
31 | from dm_control.utils import io as resources
32 | from dm_control import suite
33 |
34 | _DEFAULT_TIME_LIMIT = 25
35 | _CONTROL_TIMESTEP = .025
36 |
37 | # Minimal height of torso over foot above which stand reward is 1.
38 | _STAND_HEIGHT = 1.2
39 |
40 | # Horizontal speeds (meters/second) above which move reward is 1.
41 | _WALK_SPEED = 1
42 | _RUN_SPEED = 8
43 | _SPIN_SPEED = 5
44 |
45 | SUITE = containers.TaggedTasks()
46 |
47 | def make(task,
48 | task_kwargs=None,
49 | environment_kwargs=None,
50 | visualize_reward=False):
51 | task_kwargs = task_kwargs or {}
52 | if environment_kwargs is not None:
53 | task_kwargs = task_kwargs.copy()
54 | task_kwargs['environment_kwargs'] = environment_kwargs
55 | env = SUITE[task](**task_kwargs)
56 | env.task.visualize_reward = visualize_reward
57 | return env
58 |
59 | def get_model_and_assets():
60 | """Returns a tuple containing the model XML string and a dict of assets."""
61 | root_dir = os.path.dirname(os.path.dirname(__file__))
62 | xml = resources.GetResource(os.path.join(root_dir, 'custom_dmc_tasks',
63 | 'walker.xml'))
64 | return xml, common.ASSETS
65 |
66 |
67 |
68 |
69 |
70 |
71 | @SUITE.add('benchmarking')
72 | def flip(time_limit=_DEFAULT_TIME_LIMIT,
73 | random=None,
74 | environment_kwargs=None):
75 | """Returns the Run task."""
76 | physics = Physics.from_xml_string(*get_model_and_assets())
77 | task = PlanarWalker(move_speed=_RUN_SPEED,
78 | forward=True,
79 | flip=True,
80 | random=random)
81 | environment_kwargs = environment_kwargs or {}
82 | return control.Environment(physics,
83 | task,
84 | time_limit=time_limit,
85 | control_timestep=_CONTROL_TIMESTEP,
86 | **environment_kwargs)
87 |
88 |
89 | class Physics(mujoco.Physics):
90 | """Physics simulation with additional features for the Walker domain."""
91 | def torso_upright(self):
92 | """Returns projection from z-axes of torso to the z-axes of world."""
93 | return self.named.data.xmat['torso', 'zz']
94 |
95 | def torso_height(self):
96 | """Returns the height of the torso."""
97 | return self.named.data.xpos['torso', 'z']
98 |
99 | def horizontal_velocity(self):
100 | """Returns the horizontal velocity of the center-of-mass."""
101 | return self.named.data.sensordata['torso_subtreelinvel'][0]
102 |
103 | def orientations(self):
104 | """Returns planar orientations of all bodies."""
105 | return self.named.data.xmat[1:, ['xx', 'xz']].ravel()
106 |
107 | def angmomentum(self):
108 | """Returns the angular momentum of torso of the Cheetah about Y axis."""
109 | return self.named.data.subtree_angmom['torso'][1]
110 |
111 |
112 | class PlanarWalker(base.Task):
113 | """A planar walker task."""
114 | def __init__(self, move_speed, forward=True, flip=False, random=None):
115 | """Initializes an instance of `PlanarWalker`.
116 |
117 | Args:
118 | move_speed: A float. If this value is zero, reward is given simply for
119 | standing up. Otherwise this specifies a target horizontal velocity for
120 | the walking task.
121 | random: Optional, either a `numpy.random.RandomState` instance, an
122 | integer seed for creating a new `RandomState`, or None to select a seed
123 | automatically (default).
124 | """
125 | self._move_speed = move_speed
126 | self._forward = 1 if forward else -1
127 | self._flip = flip
128 | super(PlanarWalker, self).__init__(random=random)
129 |
130 | def initialize_episode(self, physics):
131 | """Sets the state of the environment at the start of each episode.
132 |
133 | In 'standing' mode, use initial orientation and small velocities.
134 | In 'random' mode, randomize joint angles and let fall to the floor.
135 |
136 | Args:
137 | physics: An instance of `Physics`.
138 |
139 | """
140 | randomizers.randomize_limited_and_rotational_joints(
141 | physics, self.random)
142 | super(PlanarWalker, self).initialize_episode(physics)
143 |
144 | def get_observation(self, physics):
145 | """Returns an observation of body orientations, height and velocites."""
146 | obs = collections.OrderedDict()
147 | obs['orientations'] = physics.orientations()
148 | obs['height'] = physics.torso_height()
149 | obs['velocity'] = physics.velocity()
150 | return obs
151 |
152 | def get_reward(self, physics):
153 | """Returns a reward to the agent."""
154 | standing = rewards.tolerance(physics.torso_height(),
155 | bounds=(_STAND_HEIGHT, float('inf')),
156 | margin=_STAND_HEIGHT / 2)
157 | upright = (1 + physics.torso_upright()) / 2
158 | stand_reward = (3 * standing + upright) / 4
159 |
160 | if self._flip:
161 | move_reward = rewards.tolerance(self._forward *
162 | physics.angmomentum(),
163 | bounds=(_SPIN_SPEED, float('inf')),
164 | margin=_SPIN_SPEED,
165 | value_at_margin=0,
166 | sigmoid='linear')
167 | else:
168 | move_reward = rewards.tolerance(
169 | self._forward * physics.horizontal_velocity(),
170 | bounds=(self._move_speed, float('inf')),
171 | margin=self._move_speed / 2,
172 | value_at_margin=0.5,
173 | sigmoid='linear')
174 |
175 | return stand_reward * (5 * move_reward + 1) / 6
176 |
--------------------------------------------------------------------------------
/core/custom_dmc_tasks/walker.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/core/data/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import jax.numpy as jnp
4 | from dm_env import specs
5 |
6 | from .replay_buffer import get_reverb_replay_components, Batch, ReverbReplay, IntraEpisodicBuffer
7 | from .replay_buffer_torch import make_replay_loader, ReplayBufferStorage, ReplayBuffer
--------------------------------------------------------------------------------
/core/data/replay_buffer.py:
--------------------------------------------------------------------------------
1 | from typing import Iterator, List, Optional, NamedTuple, Dict, Any
2 | from collections import deque
3 | import dataclasses
4 |
5 | import reverb
6 | import numpy as np
7 | import dm_env
8 | from dm_env import specs
9 | from acme import adders, specs, types
10 | from acme.adders import reverb as adders_reverb
11 | from acme.datasets import reverb as datasets
12 |
13 |
14 | #################
15 | ### From Acme ###
16 | #################
17 |
18 | class Batch(NamedTuple):
19 | observation: np.ndarray
20 | action: np.ndarray
21 | reward: np.ndarray
22 | discount: np.ndarray
23 | next_observation: np.ndarray
24 | extras: Dict
25 |
26 | @dataclasses.dataclass
27 | class ReverbReplay:
28 | server: reverb.Server
29 | adder: adders.Adder
30 | data_iterator: Iterator[reverb.ReplaySample]
31 | client: Optional[reverb.Client] = None
32 |
33 | def make_replay_tables(
34 | environment_spec: specs.EnvironmentSpec,
35 | replay_table_name: str = 'replay buffer',
36 | max_replay_size: int = 2_000_000,
37 | min_replay_size: int = 100,
38 | extras_spec: types.NestedSpec = ()
39 | ) -> List[reverb.Table]:
40 | """Creates reverb tables for the algorithm."""
41 | return [reverb.Table(
42 | name=replay_table_name,
43 | sampler=reverb.selectors.Uniform(),
44 | remover=reverb.selectors.Fifo(),
45 | max_size=max_replay_size,
46 | rate_limiter=reverb.rate_limiters.MinSize(min_replay_size),
47 | signature=adders_reverb.NStepTransitionAdder.signature(
48 | environment_spec, extras_spec))]
49 |
50 | def make_dataset_iterator(
51 | replay_client: reverb.Client, batch_size: int,
52 | prefetch_size: int = 4, replay_table_name: str = 'replay buffer',
53 | ) -> Iterator[reverb.ReplaySample]:
54 | """Creates a dataset iterator to use for learning."""
55 | dataset = datasets.make_reverb_dataset(
56 | table=replay_table_name,
57 | server_address=replay_client.server_address,
58 | batch_size=batch_size,
59 | prefetch_size=prefetch_size)
60 | return dataset.as_numpy_iterator()
61 |
62 | def make_adder(
63 | replay_client: reverb.Client,
64 | n_step: int, discount: float,
65 | replay_table_name: str = 'replay buffer',) -> adders.Adder:
66 | """Creates an adder which handles observations."""
67 | return adders_reverb.NStepTransitionAdder(
68 | priority_fns={replay_table_name: None},
69 | client=replay_client,
70 | n_step=n_step,
71 | discount=discount
72 | )
73 |
74 | def get_reverb_replay_components(
75 | environment_spec: specs.EnvironmentSpec,
76 | n_step: int, discount: float, batch_size: int,
77 | max_replay_size: int = 2_000_000,
78 | min_replay_size: int = 100,
79 | replay_table_name: str = 'replay buffer',
80 | extras_spec: Optional[types.NestedSpec] = ()
81 | ) -> ReverbReplay:
82 | replay_table = make_replay_tables(environment_spec,
83 | replay_table_name, max_replay_size,
84 | min_replay_size=min_replay_size, extras_spec=extras_spec)
85 | server = reverb.Server(replay_table, port=None)
86 | address = f'localhost:{server.port}'
87 | client = reverb.Client(address)
88 | adder = make_adder(client, n_step, discount, replay_table_name)
89 | data_iterator = make_dataset_iterator(
90 | client, batch_size, replay_table_name=replay_table_name)
91 | return ReverbReplay(
92 | server, adder, data_iterator, client
93 | )
94 |
95 |
96 | class IntraEpisodicBuffer:
97 | def __init__(self, maxlen: int = 1001, full_method: str = 'episodic') -> None:
98 | self.timesteps = deque(maxlen=maxlen)
99 | self.extras = deque(maxlen=maxlen)
100 | self._maxlen = maxlen
101 | self.full_method = full_method
102 | self._last_timestep = None
103 |
104 | def add(self, timestep: dm_env.TimeStep, extra: Dict[str, Any]):
105 | self.timesteps.append(timestep)
106 | self.extras.append(extra)
107 | self._last_timestep = timestep
108 |
109 | def reset(self):
110 | self.timesteps = deque(maxlen=self._maxlen)
111 | self.extras = deque(maxlen=self._maxlen)
112 | self._last_timestep = None
113 |
114 | def __len__(self) -> int:
115 | return len(self.timesteps)
116 |
117 | def is_full(self):
118 | if self.full_method == 'episodic':
119 | # buffer is not full when it has just been initialized or reset
120 | if self._last_timestep is None:
121 | return False
122 | return self._last_timestep.last()
123 | if self.full_method == 'step':
124 | return len(self.timesteps) == self._maxlen
125 | raise NotImplementedError
126 |
--------------------------------------------------------------------------------
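
A rough sketch of wiring the reverb components together (not from the repository; the environment here is only a stand-in, and the data iterator blocks until min_replay_size transitions have been added):

    from acme import specs as acme_specs
    from dm_control import suite
    from core.data import get_reverb_replay_components

    env = suite.load('cheetah', 'run')                  # any dm_env.Environment would do
    replay = get_reverb_replay_components(
        acme_specs.make_environment_spec(env),
        n_step=3, discount=0.99, batch_size=256)
    replay.adder.add_first(env.reset())                 # stream transitions through the adder
    # ... after enough replay.adder.add(action, next_timestep) calls:
    # batch = next(replay.data_iterator)                # reverb.ReplaySample with stacked arrays
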
/core/data/replay_buffer_torch.py:
--------------------------------------------------------------------------------
1 | from typing import NamedTuple, OrderedDict, List, Tuple, Dict, Union
2 | import datetime
3 | import io
4 | import random
5 | import traceback
6 | from collections import defaultdict
7 | import pathlib
8 | import functools
9 |
10 | import numpy as np
11 | import dm_env
12 | import torch
13 | from torch.utils.data import IterableDataset
14 |
15 |
16 | class Batch(NamedTuple):
17 | observation: Union[np.ndarray, List]
18 | action: Union[np.ndarray, List]
19 | reward: Union[np.ndarray, List]
20 | discount: Union[np.ndarray, List]
21 | next_observation: Union[np.ndarray, List]
22 | extras: Dict # List
23 |
24 |
25 | def compute_episode_len(episode):
26 | # subtract 1 to account for the dummy first transition
27 | return next(iter(episode.values())).shape[0] - 1
28 |
29 |
30 | def save_episode(episode, fn):
31 | with io.BytesIO() as bs:
32 | np.savez_compressed(bs, **episode)
33 | bs.seek(0)
34 | with fn.open('wb') as f:
35 | f.write(bs.read())
36 |
37 |
38 | def _preload(replay_dir: pathlib.Path) -> Tuple[int, int]:
39 | """
40 | Returns the number of episodes and transitions in `replay_dir`.
41 | Assumes each episode file is named '{timestamp}_{episode_idx}_{episode_len}.npz'.
42 | """
43 | n_episodes, n_transitions = 0, 0
44 | for file in replay_dir.glob('*.npz'):
45 | _, _, episode_len = file.stem.split('_')
46 | n_episodes += 1
47 | n_transitions += int(episode_len)
48 |
49 | return n_episodes, n_transitions
50 |
51 |
52 | class ReplayBufferStorage:
53 | def __init__(self,
54 | data_specs, #: Tuple[specs, ...],
55 | meta_specs, #: Tuple[specs, ...],
56 | replay_dir: pathlib.Path = pathlib.Path.cwd() / 'buffer'
57 | ):
58 | """
59 | data_specs: (obs, action , reward, discount)
60 | meta_specs: any extra e.g. (skill, mode)
61 | """
62 |
63 | self._data_specs = data_specs
64 | self._meta_specs = meta_specs
65 | self._replay_dir = replay_dir
66 | replay_dir.mkdir(exist_ok=True)
67 | self._current_episode = defaultdict(list)
68 | self._n_episodes, self._n_transitions = _preload(replay_dir)
69 |
70 | def __len__(self):
71 | return self._n_transitions
72 |
73 | def add(self,
74 | time_step: dm_env.TimeStep,
75 | meta: OrderedDict
76 | ):
77 | self._add_meta(meta=meta)
78 | self._add_time_step(time_step=time_step)
79 | if time_step.last():
80 | self._store_episode()
81 |
82 | def _add_meta(self, meta: OrderedDict):
83 | for spec in self._meta_specs:
84 | value = meta[spec.name]
85 | if np.isscalar(value):
86 | value = np.full(spec.shape, value, spec.dtype)
87 | self._current_episode[spec.name].append(value)
88 | # for key, value in meta.items():
89 | # self._current_episode[key].append(value)
90 |
91 | def _add_time_step(self, time_step: dm_env.TimeStep):
92 | for spec in self._data_specs:
93 | value = time_step[spec.name]
94 | if np.isscalar(value):
95 | # convert it to a numpy array as shape given by the data specs (reward & discount)
96 | value = np.full(spec.shape, value, spec.dtype)
97 | assert spec.shape == value.shape and spec.dtype == value.dtype
98 | self._current_episode[spec.name].append(value)
99 |
100 | def _store_episode(self):
101 | episode = dict()
102 |
103 | # datas to save as numpy array
104 | for spec in self._data_specs:
105 | value = self._current_episode[spec.name]
106 | episode[spec.name] = np.array(value, spec.dtype)
107 |
108 | # metas to save as numpy array
109 | for spec in self._meta_specs:
110 | value = self._current_episode[spec.name]
111 | episode[spec.name] = np.array(value, spec.dtype)
112 |
113 | # reset current episode content
114 | self._current_episode = defaultdict(list)
115 |
116 | # save episode
117 | eps_idx = self._n_episodes
118 | eps_len = compute_episode_len(episode)
119 | self._n_episodes += 1
120 | self._n_transitions += eps_len
121 | ts = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
122 | eps_fn = f'{ts}_{eps_idx}_{eps_len}.npz'
123 | save_episode(episode, self._replay_dir / eps_fn)
124 |
125 |
126 | def load_episode(fn):
127 | with fn.open('rb') as f:
128 | episode = np.load(f)
129 | episode = {k: episode[k] for k in episode.keys()}
130 | return episode
131 |
132 |
133 | class ReplayBuffer(IterableDataset):
134 |
135 | def __init__(self,
136 | storage: ReplayBufferStorage,
137 | max_size: int,
138 | num_workers: int,
139 | nstep: int,
140 | discount: int,
141 | fetch_every: int,
142 | save_snapshot: bool
143 | ):
144 | self._storage = storage
145 | self._size = 0
146 | self._max_size = max_size
147 | self._num_workers = max(1, num_workers)
148 | self._episode_fns = []
149 | self._episodes = dict()
150 | self._nstep = nstep
151 | self._discount = discount
152 | self._fetch_every = fetch_every
153 | self._samples_since_last_fetch = fetch_every
154 | self._save_snapshot = save_snapshot
155 |
156 | def __len__(self):
157 | return len(self._storage)
158 |
159 | def _sample_episode(self):
160 | """ Sample a single episode """
161 | eps_fn = random.choice(self._episode_fns)
162 | return self._episodes[eps_fn]
163 |
164 | def _store_episode(self, eps_fn):
165 | """
166 | Load an episode into memory: store it in self._episodes,
167 | keep its filename in the sorted list self._episode_fns,
168 | and delete the file on disk unless save_snapshot is set.
169 | """
170 | try:
171 | episode = load_episode(eps_fn)
172 | except:
173 | return False
174 | eps_len = compute_episode_len(episode)
175 | # remove old episodes if max size is reached
176 | while eps_len + self._size > self._max_size:
177 | early_eps_fn = self._episode_fns.pop(0)
178 | early_eps = self._episodes.pop(early_eps_fn)
179 | self._size -= compute_episode_len(early_eps)
180 | early_eps_fn.unlink(missing_ok=True)
181 | # store the episode
182 | self._episode_fns.append(eps_fn)
183 | self._episode_fns.sort()
184 | self._episodes[eps_fn] = episode
185 | self._size += eps_len
186 |
187 | # delete episode if save_snapshot false
188 | if not self._save_snapshot:
189 | eps_fn.unlink(missing_ok=True)
190 | return True
191 |
192 | def _try_fetch(self):
193 | """
194 | Fetch recently saved episodes from disk, divided between data-loader workers
195 | """
196 | if self._samples_since_last_fetch < self._fetch_every:
197 | return
198 | self._samples_since_last_fetch = 0
199 | try:
200 | worker_id = torch.utils.data.get_worker_info().id
201 | except:
202 | worker_id = 0
203 |
204 | # last created to first created
205 | eps_fns = sorted(self._storage._replay_dir.glob('*.npz'), reverse=True)
206 | fetched_size = 0
207 | # load all episodes
208 | for eps_fn in eps_fns:
209 | eps_idx, eps_len = [int(x) for x in eps_fn.stem.split('_')[1:]]
210 | # each worker load an episode
211 | if eps_idx % self._num_workers != worker_id:
212 | continue
213 | if eps_fn in self._episodes.keys():
214 | break
215 | if fetched_size + eps_len > self._max_size:
216 | break
217 | fetched_size += eps_len
218 | # stop if fail to load episode
219 | if not self._store_episode(eps_fn):
220 | break
221 |
222 | def _sample(self
223 | ): # -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, ...]:
224 | try:
225 | self._try_fetch()
226 | except:
227 | traceback.print_exc()
228 | self._samples_since_last_fetch += 1
229 | episode = self._sample_episode()
230 | # shift by +1 to skip the first dummy transition, so idx lies in [1, episode_len - nstep + 1]
231 | # idx to take inside the episode
232 | idx = np.random.randint(0, compute_episode_len(episode) - self._nstep + 1) + 1
233 | meta = dict()
234 | # meta = []
235 | for spec in self._storage._meta_specs:
236 | meta[spec.name] = episode[spec.name][idx - 1]
237 | # meta.append(episode[spec.name][idx - 1])
238 | obs = episode['observation'][idx - 1] # account for first dummy transition
239 | action = episode['action'][idx] # on first dummy transition action is set to 0
240 | next_obs = episode['observation'][idx + self._nstep - 1]# account for first dummy transition
241 | reward = np.zeros_like(episode['reward'][idx])
242 | discount = np.ones_like(episode['discount'][idx])
243 | for i in range(self._nstep):
244 | step_reward = episode['reward'][idx + i]
245 | reward += discount * step_reward
246 | discount *= episode['discount'][idx + i] * self._discount
247 | # noinspection PyRedundantParentheses
248 | data = dict(
249 | observation=obs,
250 | action=action,
251 | reward=reward,
252 | discount=discount,
253 | next_observation=next_obs,
254 | )
255 | data.update(meta)
256 | return data
257 | # return (obs, action, reward, discount, next_obs, *meta)
258 |
259 | def __iter__(self):
260 | while True:
261 | yield self._sample()
262 |
263 |
264 | class RepeatSampler(object):
265 | """ Sampler that repeats forever.
266 | Args:
267 | sampler (Sampler)
268 | """
269 |
270 | def __init__(self, sampler):
271 | self.sampler = sampler
272 |
273 | def __iter__(self):
274 | while True:
275 | yield from iter(self.sampler)
276 |
277 |
278 | def numpy_collate(batch: List[Dict], meta_specs):
279 | res = defaultdict(list)
280 | for b in batch:
281 | for k, v in b.items():
282 | res[k].append(v)
283 | extras = dict()
284 | for spec in meta_specs:
285 | extras[spec.name] = np.stack(res[spec.name])
286 | return Batch(
287 | observation=np.stack(res['observation']),
288 | action=np.stack(res['action']),
289 | reward=np.stack(res['reward']),
290 | discount=np.stack(res['discount']),
291 | next_observation=np.stack(res['next_observation']),
292 | extras=extras
293 | )
294 |
295 | def numpy_collate_mode(batch: List[Dict], meta_specs):
296 | res_mode0 = defaultdict(list)
297 | res_mode1 = defaultdict(list)
298 | for b in batch:
299 |
300 | if b['mode'] == 0:
301 | res_mode0['skill'].append(b['skill'])
302 | res_mode0['observation'].append(b['observation'])
303 | res_mode0['next_observation'].append(b['next_observation'])
304 | res_mode0['action'].append(b['action'])
305 | res_mode0['reward'].append(b['reward'])
306 | res_mode0['discount'].append(b['discount'])
307 | elif b['mode'] == 1:
308 | res_mode1['skill'].append(b['skill'])
309 | res_mode1['observation'].append(b['observation'])
310 | res_mode1['next_observation'].append(b['next_observation'])
311 | res_mode1['action'].append(b['action'])
312 | res_mode1['reward'].append(b['reward'])
313 | # res_mode1['discount'].append(b['discount'] * 0.25)
314 | res_mode1['discount'].append(b['discount'])
315 |
316 | extras = dict()
317 | # for spec in meta_specs:
318 | extras['skill'] = [np.stack(res_mode0['skill']), np.stack(res_mode1['skill'])]
319 | # extras['skill'] = [] #[np.stack(res_mode0[spec.name]), np.stack(res_mode1[spec.name])]
320 | # if len(res_mode0['skill']):
321 | # extras['skill'].append(np.stack(res_mode0['skill']))
322 | # if len(res_mode1['skill']):
323 | # extras['skill'].append(np.stack(res_mode1['skill']))
324 |
325 | return Batch(
326 | observation=[np.stack(res_mode0['observation']), np.stack(res_mode1['observation'])],
327 | action=[np.stack(res_mode0['action']), np.stack(res_mode1['action'])],
328 | reward=[np.stack(res_mode0['reward']), np.stack(res_mode1['reward'])],
329 | discount=[np.stack(res_mode0['discount']), np.stack(res_mode1['discount'])],
330 | next_observation=[np.stack(res_mode0['next_observation']), np.stack(res_mode1['next_observation'])],
331 | extras=extras
332 | )
333 |
334 |
335 | def _worker_init_fn(worker_id):
336 | seed = np.random.get_state()[1][0] + worker_id
337 | np.random.seed(seed)
338 | random.seed(seed)
339 |
340 |
341 | def make_replay_loader(storage,
342 | max_size,
343 | batch_size,
344 | num_workers,
345 | nstep,
346 | discount,
347 | meta_specs,
348 | save_snapshot: bool = False):
349 |
350 |
351 | if 'mode' in [spec.name for spec in meta_specs]:
352 | # collate_fct = functools.partial(numpy_collate, meta_specs=meta_specs)
353 | collate_fct = functools.partial(numpy_collate_mode, meta_specs=meta_specs)
354 | else:
355 | collate_fct = functools.partial(numpy_collate, meta_specs=meta_specs)
356 |
357 | max_size_per_worker = max_size // max(1, num_workers)
358 | iterable = ReplayBuffer(storage,
359 | max_size_per_worker,
360 | num_workers,
361 | nstep,
362 | discount,
363 | fetch_every=1000,
364 | save_snapshot=save_snapshot)
365 |
366 | loader = torch.utils.data.DataLoader(iterable,
367 | batch_size=batch_size,
368 | num_workers=num_workers,
369 | pin_memory=True,
370 | worker_init_fn=_worker_init_fn,
371 | collate_fn=collate_fct
372 | )
373 | return loader
374 |
--------------------------------------------------------------------------------
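
A sketch of the on-disk replay pipeline defined above (the spec shapes are illustrative; the helpers are imported through core.data as exported in its __init__):

    import pathlib
    import numpy as np
    from dm_env import specs
    from core.data import ReplayBufferStorage, make_replay_loader

    data_specs = (specs.Array((24,), np.float32, 'observation'),
                  specs.Array((6,), np.float32, 'action'),
                  specs.Array((1,), np.float32, 'reward'),
                  specs.Array((1,), np.float32, 'discount'))
    meta_specs = (specs.Array((64,), np.float32, 'skill'),)

    storage = ReplayBufferStorage(data_specs, meta_specs,
                                  replay_dir=pathlib.Path.cwd() / 'buffer')
    loader = make_replay_loader(storage, max_size=1_000_000, batch_size=256,
                                num_workers=4, nstep=3, discount=0.99,
                                meta_specs=meta_specs)
    # iterating the loader yields Batch namedtuples of stacked numpy arrays plus extras['skill']
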
/core/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import NamedTuple
2 |
3 | import dm_env
4 |
5 | from .dmc import make as make_dmc_env
6 |
7 |
8 | def make_env(action_type: str, cfg: NamedTuple, seed: int) -> dm_env.Environment:
9 | if action_type == 'continuous':
10 | return make_dmc_env(cfg.task, cfg.obs_type, cfg.frame_stack, cfg.action_repeat, seed)
11 | raise NotImplementedError
12 |
--------------------------------------------------------------------------------
/core/envs/dmc.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict, deque
2 | from typing import Any, NamedTuple
3 |
4 | import dm_env
5 | import numpy as np
6 | from dm_control import suite
7 | from dm_control.suite.wrappers import action_scale, pixels
8 | from dm_env import StepType, specs
9 |
10 | from core import custom_dmc_tasks as cdmc
11 | from core.envs.wrappers import InformativeTimestepWrapper, DMCTimeWrapper
12 |
13 |
14 | class ExtendedTimeStep(NamedTuple):
15 | step_type: Any
16 | reward: Any
17 | discount: Any
18 | observation: Any
19 | action: Any
20 |
21 | def first(self):
22 | return self.step_type == StepType.FIRST
23 |
24 | def mid(self):
25 | return self.step_type == StepType.MID
26 |
27 | def last(self):
28 | return self.step_type == StepType.LAST
29 |
30 | def __getitem__(self, attr):
31 | return getattr(self, attr)
32 |
33 |
34 | class FlattenJacoObservationWrapper(dm_env.Environment):
35 | def __init__(self, env):
36 | self._env = env
37 | self._obs_spec = OrderedDict()
38 | wrapped_obs_spec = env.observation_spec().copy()
39 | if 'front_close' in wrapped_obs_spec:
40 | spec = wrapped_obs_spec['front_close']
41 | # drop batch dim
42 | self._obs_spec['pixels'] = specs.BoundedArray(shape=spec.shape[1:],
43 | dtype=spec.dtype,
44 | minimum=spec.minimum,
45 | maximum=spec.maximum,
46 | name='pixels')
47 | wrapped_obs_spec.pop('front_close')
48 |
49 | for key, spec in wrapped_obs_spec.items():
50 | assert spec.dtype == np.float64
51 | assert type(spec) == specs.Array
52 | dim = np.sum(
53 |                 np.fromiter((int(np.prod(spec.shape))
54 | for spec in wrapped_obs_spec.values()), np.int32))
55 |
56 | self._obs_spec['observations'] = specs.Array(shape=(dim,),
57 | dtype=np.float32,
58 | name='observations')
59 |
60 | def _transform_observation(self, time_step):
61 | obs = OrderedDict()
62 |
63 | if 'front_close' in time_step.observation:
64 | pixels = time_step.observation['front_close']
65 | time_step.observation.pop('front_close')
66 | pixels = np.squeeze(pixels)
67 | obs['pixels'] = pixels
68 |
69 | features = []
70 | for feature in time_step.observation.values():
71 | features.append(feature.ravel())
72 | obs['observations'] = np.concatenate(features, axis=0)
73 | return time_step._replace(observation=obs)
74 |
75 | def reset(self):
76 | time_step = self._env.reset()
77 | return self._transform_observation(time_step)
78 |
79 | def step(self, action):
80 | time_step = self._env.step(action)
81 | return self._transform_observation(time_step)
82 |
83 | def observation_spec(self):
84 | return self._obs_spec
85 |
86 | def action_spec(self):
87 | return self._env.action_spec()
88 |
89 | def __getattr__(self, name):
90 | return getattr(self._env, name)
91 |
92 |
93 | class ActionRepeatWrapper(dm_env.Environment):
94 | def __init__(self, env, num_repeats):
95 | self._env = env
96 | self._num_repeats = num_repeats
97 |
98 | def step(self, action):
99 | reward = 0.0
100 | discount = 1.0
101 | for i in range(self._num_repeats):
102 | time_step = self._env.step(action)
103 | reward += (time_step.reward or 0.0) * discount
104 | discount *= time_step.discount
105 | if time_step.last():
106 | break
107 |
108 | return time_step._replace(reward=reward, discount=discount)
109 |
110 | def observation_spec(self):
111 | return self._env.observation_spec()
112 |
113 | def action_spec(self):
114 | return self._env.action_spec()
115 |
116 | def reset(self):
117 | return self._env.reset()
118 |
119 | def __getattr__(self, name):
120 | return getattr(self._env, name)
121 |
122 |
123 | class FrameStackWrapper(dm_env.Environment):
124 | def __init__(self, env, num_frames, pixels_key='pixels'):
125 | self._env = env
126 | self._num_frames = num_frames
127 | self._frames = deque([], maxlen=num_frames)
128 | self._pixels_key = pixels_key
129 |
130 | wrapped_obs_spec = env.observation_spec()
131 | assert pixels_key in wrapped_obs_spec
132 |
133 | pixels_shape = wrapped_obs_spec[pixels_key].shape
134 | # remove batch dim
135 | if len(pixels_shape) == 4:
136 | pixels_shape = pixels_shape[1:]
137 | self._obs_spec = specs.BoundedArray(shape=np.concatenate(
138 | [[pixels_shape[2] * num_frames], pixels_shape[:2]], axis=0),
139 | dtype=np.uint8,
140 | minimum=0,
141 | maximum=255,
142 | name='observation')
143 |
144 | def _transform_observation(self, time_step):
145 | assert len(self._frames) == self._num_frames
146 | obs = np.concatenate(list(self._frames), axis=0)
147 | return time_step._replace(observation=obs)
148 |
149 | def _extract_pixels(self, time_step):
150 | pixels = time_step.observation[self._pixels_key]
151 | # remove batch dim
152 | if len(pixels.shape) == 4:
153 | pixels = pixels[0]
154 | return pixels.transpose(2, 0, 1).copy()
155 |
156 | def reset(self):
157 | time_step = self._env.reset()
158 | pixels = self._extract_pixels(time_step)
159 | for _ in range(self._num_frames):
160 | self._frames.append(pixels)
161 | return self._transform_observation(time_step)
162 |
163 | def step(self, action):
164 | time_step = self._env.step(action)
165 | pixels = self._extract_pixels(time_step)
166 | self._frames.append(pixels)
167 | return self._transform_observation(time_step)
168 |
169 | def observation_spec(self):
170 | return self._obs_spec
171 |
172 | def action_spec(self):
173 | return self._env.action_spec()
174 |
175 | def __getattr__(self, name):
176 | return getattr(self._env, name)
177 |
178 |
179 | class ActionDTypeWrapper(dm_env.Environment):
180 | def __init__(self, env, dtype):
181 | self._env = env
182 | wrapped_action_spec = env.action_spec()
183 | self._action_spec = specs.BoundedArray(wrapped_action_spec.shape,
184 | dtype,
185 | wrapped_action_spec.minimum,
186 | wrapped_action_spec.maximum,
187 | 'action')
188 |
189 | def step(self, action):
190 | action = action.astype(self._env.action_spec().dtype)
191 | return self._env.step(action)
192 |
193 | def observation_spec(self):
194 | return self._env.observation_spec()
195 |
196 | def action_spec(self):
197 | return self._action_spec
198 |
199 | def reset(self):
200 | return self._env.reset()
201 |
202 | def __getattr__(self, name):
203 | return getattr(self._env, name)
204 |
205 |
206 | class ObservationDTypeWrapper(dm_env.Environment):
207 | def __init__(self, env, dtype):
208 | self._env = env
209 | self._dtype = dtype
210 | wrapped_obs_spec = env.observation_spec()['observations']
211 | self._obs_spec = specs.Array(wrapped_obs_spec.shape, dtype,
212 | 'observation')
213 |
214 | def _transform_observation(self, time_step):
215 | obs = time_step.observation['observations'].astype(self._dtype)
216 | return time_step._replace(observation=obs)
217 |
218 | def reset(self):
219 | time_step = self._env.reset()
220 | return self._transform_observation(time_step)
221 |
222 | def step(self, action):
223 | time_step = self._env.step(action)
224 | return self._transform_observation(time_step)
225 |
226 | def observation_spec(self):
227 | return self._obs_spec
228 |
229 | def action_spec(self):
230 | return self._env.action_spec()
231 |
232 | def __getattr__(self, name):
233 | return getattr(self._env, name)
234 |
235 |
236 | class ExtendedTimeStepWrapper(dm_env.Environment):
237 | def __init__(self, env):
238 | self._env = env
239 |
240 | def reset(self):
241 | time_step = self._env.reset()
242 | return self._augment_time_step(time_step)
243 |
244 | def step(self, action):
245 | time_step = self._env.step(action)
246 | return self._augment_time_step(time_step, action)
247 |
248 | def _augment_time_step(self, time_step, action=None):
249 | if action is None:
250 | action_spec = self.action_spec()
251 | action = np.zeros(action_spec.shape, dtype=action_spec.dtype)
252 | return ExtendedTimeStep(observation=time_step.observation,
253 | step_type=time_step.step_type,
254 | action=action,
255 | reward=time_step.reward or 0.0,
256 | discount=time_step.discount or 1.0)
257 |
258 | def observation_spec(self):
259 | return self._env.observation_spec()
260 |
261 | def action_spec(self):
262 | return self._env.action_spec()
263 |
264 | def __getattr__(self, name):
265 | return getattr(self._env, name)
266 |
267 |
268 | def _make_jaco(obs_type, domain, task, frame_stack, action_repeat, seed):
269 | env = cdmc.make_jaco(task, obs_type, seed)
270 | env = ActionDTypeWrapper(env, np.float32)
271 | env = ActionRepeatWrapper(env, action_repeat)
272 | env = FlattenJacoObservationWrapper(env)
273 | return env
274 |
275 |
276 | def _make_dmc(obs_type, domain, task, frame_stack, action_repeat, seed):
277 | visualize_reward = False
278 | if (domain, task) in suite.ALL_TASKS:
279 | env = suite.load(domain,
280 | task,
281 | task_kwargs=dict(random=seed),
282 | environment_kwargs=dict(flat_observation=True),
283 | visualize_reward=visualize_reward)
284 | else:
285 | env = cdmc.make(domain,
286 | task,
287 | task_kwargs=dict(random=seed),
288 | environment_kwargs=dict(flat_observation=True),
289 | visualize_reward=visualize_reward)
290 |
291 | env = ActionDTypeWrapper(env, np.float32)
292 | env = ActionRepeatWrapper(env, action_repeat)
293 | if obs_type == 'pixels':
294 | # zoom in camera for quadruped
295 | camera_id = dict(quadruped=2).get(domain, 0)
296 | render_kwargs = dict(height=84, width=84, camera_id=camera_id)
297 | env = pixels.Wrapper(env,
298 | pixels_only=True,
299 | render_kwargs=render_kwargs)
300 | return env
301 |
302 |
303 | def make(name, obs_type, frame_stack, action_repeat, seed):
304 | assert obs_type in ['states', 'pixels']
305 | domain, task = name.split('_', 1)
306 | domain = dict(cup='ball_in_cup').get(domain, domain)
307 |
308 | make_fn = _make_jaco if domain == 'jaco' else _make_dmc
309 | env = make_fn(obs_type, domain, task, frame_stack, action_repeat, seed)
310 |
311 | if obs_type == 'pixels':
312 | env = FrameStackWrapper(env, frame_stack)
313 | else:
314 | env = ObservationDTypeWrapper(env, np.float32)
315 |
316 | env = action_scale.Wrapper(env, minimum=-1.0, maximum=+1.0)
317 | env = ExtendedTimeStepWrapper(env)
318 | return InformativeTimestepWrapper(
319 | DMCTimeWrapper(
320 | env,
321 | )
322 | )
323 |
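324 | 
325 | # Example (minimal sketch): the pixel path through make(). Frames are stacked channel-wise,
326 | # so frame_stack=3 on 84x84 RGB renders yields observations of shape (9, 84, 84). The task
327 | # name and arguments are illustrative.
328 | def _example_pixel_env(seed: int = 0):
329 |     env = make('quadruped_run', obs_type='pixels', frame_stack=3, action_repeat=2, seed=seed)
330 |     time_step = env.reset()
331 |     assert time_step.observation.shape == (9, 84, 84)
332 |     return env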
--------------------------------------------------------------------------------
/core/envs/dmc_benchmark.py:
--------------------------------------------------------------------------------
1 | from core.custom_dmc_tasks import quadruped, jaco, cheetah, walker, hopper
2 |
3 |
4 | def make(domain, task,
5 | task_kwargs=None,
6 | environment_kwargs=None,
7 | visualize_reward=False):
8 | if domain == 'cheetah':
9 | return cheetah.make(task,
10 | task_kwargs=task_kwargs,
11 | environment_kwargs=environment_kwargs,
12 | visualize_reward=visualize_reward)
13 | elif domain == 'walker':
14 | return walker.make(task,
15 | task_kwargs=task_kwargs,
16 | environment_kwargs=environment_kwargs,
17 | visualize_reward=visualize_reward)
18 | elif domain == 'hopper':
19 | return hopper.make(task,
20 | task_kwargs=task_kwargs,
21 | environment_kwargs=environment_kwargs,
22 | visualize_reward=visualize_reward)
23 | elif domain == 'quadruped':
24 | return quadruped.make(task,
25 | task_kwargs=task_kwargs,
26 | environment_kwargs=environment_kwargs,
27 | visualize_reward=visualize_reward)
28 | else:
29 |         raise ValueError(f'{domain} not found')
32 |
33 |
34 | def make_jaco(task, obs_type, seed):
35 | return jaco.make(task, obs_type, seed)
--------------------------------------------------------------------------------
/core/envs/wrappers.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, NamedTuple
2 | from collections import deque
3 |
4 | import dm_env
5 | from dm_env import StepType, TimeStep
6 | from jax import numpy as jnp
7 | import numpy as np
8 |
9 |
10 | class InformativeTimeStep(NamedTuple):
11 | step_type: StepType
12 | reward: float
13 | discount: float
14 | observation: jnp.ndarray
15 | action: jnp.ndarray
16 | mode: int
17 | current_timestep: int
18 | max_timestep: int
19 |
20 | def first(self) -> bool:
21 | return self.step_type == StepType.FIRST
22 |
23 | def mid(self) -> bool:
24 | return self.step_type == StepType.MID
25 |
26 | def last(self) -> bool:
27 | return self.step_type == StepType.LAST
28 |
29 | def __getitem__(self, attr):
30 | return getattr(self, attr)
31 |
32 |
33 | def timestep2informative_timestep(
34 | timestep: TimeStep,
35 | action: Optional[jnp.ndarray] = None,
36 | mode: Optional[int] = None,
37 | current_timestep: Optional[int] = None,
38 | max_timestep: Optional[int] = None,) -> InformativeTimeStep:
39 | return InformativeTimeStep(
40 | step_type=timestep.step_type,
41 | reward=timestep.reward,
42 | discount=timestep.discount,
43 | observation=timestep.observation,
44 | action=action,
45 | mode=mode,
46 | current_timestep=current_timestep,
47 | max_timestep=max_timestep,
48 | )
49 |
50 |
51 | class Wrapper(dm_env.Environment):
52 | def __init__(self, env: dm_env.Environment):
53 | self._env = env
54 |         # inherit some attributes from env, like the time counter, etc.
55 | for attr, val in vars(self._env).items():
56 | if attr not in vars(self):
57 | setattr(self, attr, val)
58 |
59 | def action_spec(self):
60 | return self._env.action_spec()
61 |
62 | @property
63 | def timestep(self):
64 | return self._env._timestep
65 |
66 | @property
67 | def max_timestep(self):
68 | return self._env.max_timestep
69 |
70 | def reset(self):
71 | return self._env.reset()
72 |
73 | def observation_spec(self):
74 | return self._env.observation_spec()
75 |
76 | def step(self, action: np.ndarray) -> dm_env.TimeStep:
77 | return self._env.step(action)
78 |
79 | def __getattr__(self, name):
80 | return getattr(self._env, name)
81 |
82 |
83 | class FrameStacker(Wrapper):
84 | def __init__(self, env: dm_env.Environment, frame_stack: int = 3):
85 | super().__init__(env)
86 | self._observation = deque(maxlen=frame_stack)
87 | self.n_stacks = frame_stack
88 |
89 | def observation_spec(self):
90 | single_observation_spec = self._env.observation_spec()
91 | new_shape = list(single_observation_spec.shape)
92 | new_shape[self._env._channel_axis] = new_shape[self._env._channel_axis] * self.n_stacks
93 | return dm_env.specs.Array(
94 | shape=tuple(new_shape),
95 | dtype=single_observation_spec.dtype,
96 | name=single_observation_spec.name
97 | )
98 |
99 | def reset(self,) -> dm_env.TimeStep:
100 | timestep = self._env.reset()
101 | # stack n_stacks init frames for first observation
102 | for _ in range(self.n_stacks):
103 | self._observation.append(timestep.observation)
104 | return timestep._replace(
105 | observation=np.concatenate(self._observation, axis=self._env._channel_axis))
106 |
107 | def step(self, action: np.ndarray) -> dm_env.TimeStep:
108 | timestep = self._env.step(action)
109 | self._observation.append(timestep.observation)
110 | return timestep._replace(
111 | observation=np.concatenate(self._observation, axis=self._env._channel_axis))
112 |
113 |
114 | class ActionRepeater(Wrapper):
115 | def __init__(self, env: dm_env.Environment, nrepeats: int = 3):
116 | super().__init__(env)
117 | self._nrepeats = nrepeats
118 |
119 | def reset(self,) -> dm_env.TimeStep:
120 | return self._env.reset()
121 |
122 | def step(self, action: np.ndarray) -> dm_env.TimeStep:
123 |         for _ in range(self._nrepeats):  # note: intermediate rewards are not accumulated here
124 |             timestep = self._env.step(action)
125 |         return timestep
126 |
127 |
128 | class InformativeTimestepWrapper(Wrapper):
129 | def __init__(self, env: dm_env.Environment):
130 | super().__init__(env)
131 |
132 | def reset(self,) -> InformativeTimeStep:
133 | timestep = self._env.reset()
134 | action_spec = self.action_spec()
135 | action = np.zeros(action_spec.shape, dtype=action_spec.dtype)
136 | return timestep2informative_timestep(
137 | # this namedtuple contains obs, reward, etc.
138 | timestep,
139 | action=action,
140 | # this is the time spent in this episode
141 | current_timestep=self._env.timestep,
142 | max_timestep=self._env.max_timestep,
143 | )
144 |
145 | def step(self, action: np.ndarray) -> InformativeTimeStep:
146 | timestep = self._env.step(action)
147 | return timestep2informative_timestep(
148 | timestep,
149 | action=action,
150 | current_timestep=self._env.timestep,
151 | max_timestep=self._env.max_timestep,
152 | )
153 |
154 |
155 | class RewardScaler(Wrapper):
156 | def __init__(self, env: dm_env.Environment, reward_scale: float):
157 | super().__init__(env)
158 | self._reward_scale = reward_scale
159 |
160 | def step(self, action: np.ndarray) -> dm_env.TimeStep:
161 | timestep = self._env.step(action)
162 | return dm_env.TimeStep(
163 | step_type=timestep.step_type, reward=timestep.reward * self._reward_scale,
164 | discount=timestep.discount, observation=timestep.observation
165 | )
166 |
167 |
168 | class DMCTimeWrapper(Wrapper):
169 | def __init__(self, env: dm_env.Environment,):
170 | super().__init__(env)
171 | self._env = env
172 | self._timestep = 0
173 | self.action_shape = self._env.action_spec().shape
174 |
175 | @property
176 | def max_timestep(self,) -> int:
177 |         # total number of control steps in an episode (the index of the last step)
178 |         if hasattr(self._env, '_time_limit'):
179 |             return round(self._env._time_limit / self._env._task.control_timestep)
180 | if hasattr(self._env, '_step_limit'):
181 | return self._env._step_limit
182 |
183 | @property
184 | def timestep(self,) -> int:
185 |         # current step within the episode
186 | return self._timestep
187 |
188 | def step(self, action: np.ndarray) -> dm_env.TimeStep:
189 | self._timestep += 1
190 | return self._env.step(action)
191 |
192 | def reset(self,) -> dm_env.TimeStep:
193 | self._timestep = 0
194 | return self._env.reset()
195 |
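196 | 
197 | # Note: core/envs/dmc.py composes the wrappers above as
198 | # InformativeTimestepWrapper(DMCTimeWrapper(env)), so every reset()/step() returns an
199 | # InformativeTimeStep that additionally carries the action taken, the step count within
200 | # the episode (current_timestep) and the episode length (max_timestep). `mode` is not set
201 | # by these wrappers and stays None; the replay collate function groups transitions by it.
202 | def _example_wrap(env: dm_env.Environment) -> InformativeTimestepWrapper:
203 |     # minimal sketch mirroring the composition used in core/envs/dmc.py
204 |     return InformativeTimestepWrapper(DMCTimeWrapper(env))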
--------------------------------------------------------------------------------
/core/exp_utils/__init__.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | import numpy as np
4 |
5 | from .checkpointing import Checkpointer
6 | from .video import VideoRecorder, TrainVideoRecorder
7 | from .loggers import MetricLogger, log_params_to_wandb, LogParamsEvery, Timer, Until, Every, dict_to_header
8 |
9 |
10 | def set_seed(seed):
11 | torch.manual_seed(seed)
12 | # if torch.cuda.is_available():
13 | # torch.cuda.manual_seed_all(seed)
14 | np.random.seed(seed)
15 | random.seed(seed)
16 |
--------------------------------------------------------------------------------
/core/exp_utils/checkpointing.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from typing import Any, Mapping, Text
4 | from functools import partial
5 |
6 | import dill
7 | import jax
8 | import jax.numpy as jnp
9 | from iopath.common.file_io import PathManager
10 |
11 | # from core.misc.utils import broadcast
12 |
13 | logger = logging.getLogger(__name__)
14 | path_manager = PathManager()
15 |
16 | def tag_last_checkpoint(save_dir: str,
17 | last_filename_basename: str) -> None:
18 | """ save name of the last checkpoint in the file `last_checkpoint` """
19 | save_file = os.path.join(save_dir, "last_checkpoint")
20 | with path_manager.open(save_file, "w") as f:
21 | f.write(last_filename_basename)
22 |
23 | def save_state(save_dir: str,
24 | name: str,
25 | state: Mapping[Text, jnp.ndarray],
26 | step: int,
27 | rng,
28 | **kwargs: Any) -> None:
29 | n_devices = jax.local_device_count()
30 | if jax.process_index() != 0: # only checkpoint the first worker
31 | return
32 | checkpoint_data = dict(
33 | # state=state,
34 | state= jax.tree_map(
35 | lambda x: jax.device_get(x[0]) if n_devices > 1 else jax.device_get(x), state),
36 | step=step,
37 | rng=rng
38 | )
39 | checkpoint_data.update(kwargs)
40 | basename = "{}.pth".format(name)
41 | save_file = os.path.join(save_dir, basename)
42 | assert os.path.basename(save_file) == basename, basename
43 | logger.info("Saving checkpoint to {}".format(save_file))
44 | with path_manager.open(save_file, "wb") as f:
45 | dill.dump(checkpoint_data, f)
46 | # tag it for auto resuming
47 | tag_last_checkpoint(
48 | save_dir=save_dir,
49 | last_filename_basename=basename,
50 | )
51 |
52 | def has_last_checkpoint(save_dir:str) -> bool:
53 | save_dir = os.path.join(save_dir, "last_checkpoint")
54 | return path_manager.exists(save_dir)
55 |
56 | def get_last_checkpoint(save_dir: str) -> str:
57 | save_file = os.path.join(save_dir, "last_checkpoint")
58 | try:
59 | with path_manager.open(save_file, "r") as f:
60 | last_saved = f.read().strip()
61 | except IOError:
62 | # if file doesn't exist, maybe because it has just been
63 | # deleted by a separate process
64 | return ""
65 | return os.path.join(save_dir, last_saved)
66 |
67 | def resume_or_load(path: str, save_dir, *, resume: bool = False):
68 | if resume and has_last_checkpoint(save_dir):
69 | path = get_last_checkpoint(save_dir)
70 | return load_checkpoint(path)
71 | else:
72 | return load_checkpoint(path)
73 |
74 | def load_checkpoint(path: str) -> Mapping[str, Any]:
75 | """
76 | :param path:
77 | :return: empty dict if checkpoint doesn't exist
78 | """
79 | if not path:
80 | logger.info("No checkpoint given.")
81 | return dict()
82 |
83 | if not os.path.isfile(path):
84 | path = path_manager.get_local_path(path)
85 | assert os.path.isfile(path), "Checkpoint {} not found!".format(path)
86 |
87 | with path_manager.open(path, 'rb') as checkpoint_file:
88 | checkpoint = dill.load(checkpoint_file)
89 |         logger.info('Loading checkpoint from %s', path)
90 |
91 | return checkpoint
92 |
93 | class Checkpointer:
94 | def __init__(self,
95 | save_dir: str = "checkpoints",
96 | ):
97 | self.save_dir = save_dir
98 | os.makedirs(save_dir, exist_ok=True)
99 | self.save_state = partial(save_state, save_dir=save_dir)
100 | self.load_checkpoint = load_checkpoint
101 | self.resume_or_load = partial(resume_or_load, save_dir=save_dir)
102 |
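103 | 
104 | # Example (minimal sketch): save a state periodically and resume from the last tagged
105 | # checkpoint. `state` and `rng` are placeholders for the agent state pytree and PRNG key
106 | # managed by the training loop; the directory name is the class default.
107 | def _example_checkpointing(state, rng, step: int):
108 |     checkpointer = Checkpointer(save_dir='checkpoints')
109 |     checkpointer.save_state(name=str(step), state=state, step=step, rng=rng)
110 |     # resume_or_load returns an empty dict when no checkpoint exists yet
111 |     return checkpointer.resume_or_load('', resume=True)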
--------------------------------------------------------------------------------
/core/exp_utils/loggers.py:
--------------------------------------------------------------------------------
1 | import random
2 | import timeit
3 | import time
4 | import contextlib
5 | import logging
6 | from collections import defaultdict
7 |
8 | import haiku as hk
9 | import csv
10 | import torch
11 | import numpy as np
12 | import jax.numpy as jnp
13 | import wandb
14 | from pathlib import Path
15 |
16 | # TODO: remove these helpers (they are duplicated in helpers.py)
17 | def log_params_to_wandb(params: hk.Params, step: int):
18 | if params:
19 | for module in sorted(params):
20 | if 'w' in params[module]:
21 | wandb.log({
22 | f'{module}/w': wandb.Histogram(params[module]['w'])
23 | }, step=step)
24 | if 'b' in params[module]:
25 | wandb.log({
26 | f'{module}/b': wandb.Histogram(params[module]['b'])
27 | }, step=step)
28 |
29 | class LogParamsEvery:
30 | def __init__(self, every, action_repeat=1):
31 | self._every = every
32 | self._action_repeat = action_repeat
33 |
34 | def __call__(self, params: hk.Params, step):
35 | if self._every is None:
36 |             return
37 |         every = self._every // self._action_repeat
38 |         if step % every == 0:
39 |             log_params_to_wandb(params, step)
41 |
42 | class Until:
43 | def __init__(self, until, action_repeat=1):
44 | self._until = until
45 | self._action_repeat = action_repeat
46 |
47 | def __call__(self, step):
48 | if self._until is None:
49 | return True
50 | until = self._until // self._action_repeat
51 | return step < until
52 |
53 |
54 | class Every:
55 | def __init__(self, every, action_repeat=1):
56 | self._every = every
57 | self._action_repeat = action_repeat
58 |
59 | def __call__(self, step):
60 | if self._every is None:
61 | return False
62 | every = self._every // self._action_repeat
63 | if step % every == 0:
64 | return True
65 | return False
66 |
67 |
68 | class Timer:
69 | def __init__(self):
70 | self._start_time = time.time()
71 | self._last_time = time.time()
72 |
73 | def reset(self):
74 | elapsed_time = time.time() - self._last_time
75 | self._last_time = time.time()
76 | total_time = time.time() - self._start_time
77 | return elapsed_time, total_time
78 |
79 | def total_time(self):
80 | return time.time() - self._start_time
81 |
82 | @contextlib.contextmanager
83 | def time_activity(activity_name: str):
84 | logger = logging.getLogger(__name__)
85 | start = timeit.default_timer()
86 | yield
87 | duration = timeit.default_timer() - start
88 | logger.info('[Timing] %s finished (Took %.2fs).', activity_name, duration)
89 |
90 | class AverageMeter:
91 | def __init__(self):
92 | self._sum = 0.
93 | self._count = 0
94 | self.fmt = "{value:.4f}"
95 |
96 | def update(self, value, n=1):
97 | self._sum += value
98 | self._count += n
99 |
100 | @property
101 | def value(self):
102 | return self._sum / max(1, self._count)
103 |
104 | def __str__(self):
105 | return self.fmt.format(
106 | value=self.value
107 | )
108 |
109 | def dict_to_header(data: dict, header=None):
110 | if header is not None:
111 | header = [header]
112 | else:
113 | header = []
114 | delimiter = '\t'
115 | for name, value in data.items():
116 |         if isinstance(value, (float, np.ndarray)):  # reward may be a np.ndarray of shape ()
117 |             header.append('{}: {:.4f}'.format(name, value))
118 |         else:
119 |             header.append('{}: {}'.format(name, value))
128 | return delimiter.join(header)
129 |
130 | class MetricLogger:
131 | def __init__(self,
132 | csv_file_name: Path,
133 | use_wandb: bool,
134 | delimiter= "\t"
135 | ):
136 | self.logger = logging.getLogger(__name__)
137 | self._meters = defaultdict(AverageMeter) # factory
138 | self._csv_writer = None
139 | self._csv_file = None
140 | self._csv_file_name = csv_file_name
141 | self.delimiter = delimiter
142 | self.use_wandb = use_wandb
143 |
144 | def update_metrics(self,**kwargs):
145 | """Log the average of variables that are logged per episode"""
146 | for k, v in kwargs.items():
147 | if isinstance(v, jnp.DeviceArray):
148 | v = v.item()
149 | assert isinstance(v, (float, int))
150 | self._meters[k].update(v)
151 |
152 | def log_and_dump_metrics_to_wandb(self, step: int, header=''):
153 | """log and dump to wandb metrics"""
154 | if type(header) == dict:
155 | header = dict_to_header(data=header)
156 | self.logger.info(self._log_meters(header=header))
157 | if self.use_wandb:
158 | for name, meter in self._meters.items():
159 | wandb.log({name: np.mean(meter.value).item()}, step=step)
160 | self._clean_meters()
161 |
162 | def _clean_meters(self):
163 | self._meters.clear()
164 |
165 | def _remove_old_entries(self, data):
166 | rows = []
167 | with self._csv_file_name.open('r') as f:
168 | reader = csv.DictReader(f)
169 | for row in reader:
170 |                 if float(row['episode']) >= data['episode']:  # assumes an 'episode' column exists in the existing file
171 | break
172 | rows.append(row)
173 | with self._csv_file_name.open('w') as f:
174 | writer = csv.DictWriter(f,
175 | fieldnames=sorted(data.keys()),
176 | restval=0.0)
177 | writer.writeheader()
178 | for row in rows:
179 | writer.writerow(row)
180 |
181 | def dump_dict_to_csv(self, data: dict):
182 | """dump to wandb and csv the dict"""
183 | if self._csv_writer is None:
184 | should_write_header = True
185 | if self._csv_file_name.exists(): # if file already exists remove entries
186 | self._remove_old_entries(data)
187 | should_write_header = False
188 |
189 | self._csv_file = self._csv_file_name.open('a')
190 | self._csv_writer = csv.DictWriter(
191 | self._csv_file,
192 | fieldnames=sorted(data.keys()),
193 | restval=0.0
194 | )
195 | if should_write_header:
196 | self._csv_writer.writeheader()
197 | self._csv_writer.writerow(data)
198 | self._csv_file.flush()
199 |
200 | def dump_dict_to_wandb(self, step: int, data: dict):
201 | for name, value in data.items():
202 | if self.use_wandb:
203 | wandb.log({name: np.mean(value).item()}, step=step)
204 |
205 | def log_dict(self, header, data):
206 | self.logger.info(dict_to_header(data=data, header=header))
207 |
208 | def _log_meters(self, header: str):
209 | loss_str = [header]
210 | for name, meter in self._meters.items():
211 | loss_str.append(
212 | "{}: {}".format(name, str(meter))
213 | )
214 | return self.delimiter.join(loss_str)
215 |
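216 | 
217 | # Example (minimal sketch): the metric-logging flow used by the training loops. Scalar
218 | # metrics are accumulated per update step and averaged/flushed once per dump; the file
219 | # name and metric values are illustrative.
220 | def _example_metric_logging(step: int):
221 |     metric_logger = MetricLogger(csv_file_name=Path('train.csv'), use_wandb=False)
222 |     metric_logger.update_metrics(critic_loss=0.5, actor_loss=0.1)  # accumulated in AverageMeters
223 |     metric_logger.log_and_dump_metrics_to_wandb(step=step, header={'step': step})
224 |     metric_logger.dump_dict_to_csv(data={'episode': 1, 'step': step, 'episode_reward': 0.0})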
--------------------------------------------------------------------------------
/core/exp_utils/video.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import imageio
3 | import numpy as np
4 | import wandb
5 |
6 |
7 | class VideoRecorder:
8 | def __init__(self,
9 | root_dir,
10 | render_size=256,
11 | fps=20,
12 | camera_id=0,
13 | use_wandb=False):
14 | if root_dir is not None:
15 | self.save_dir = root_dir / 'eval_video'
16 | self.save_dir.mkdir(exist_ok=True)
17 | else:
18 | self.save_dir = None
19 |
20 | self.render_size = render_size
21 | self.fps = fps
22 | self.frames = []
23 | self.camera_id = camera_id
24 | self.use_wandb = use_wandb
25 |
26 | def init(self, env, enabled=True):
27 | self.frames = []
28 | self.enabled = self.save_dir is not None and enabled
29 | self.record(env)
30 |
31 | def record(self, env):
32 | if self.enabled:
33 | if hasattr(env, 'physics'):
34 | frame = env.physics.render(height=self.render_size,
35 | width=self.render_size,
36 | camera_id=self.camera_id)
37 | else:
38 | frame = env.render()
39 | self.frames.append(frame)
40 |
41 | def log_to_wandb(self, step):
42 | frames = np.transpose(np.array(self.frames), (0, 3, 1, 2))
43 | fps, skip = 6, 8
44 | wandb.log({
45 | 'eval/video':
46 | wandb.Video(frames[::skip, :, ::2, ::2], fps=fps, format="gif")
47 | }, step=step)
48 |
49 | def save(self, file_name, step):
50 | if self.enabled:
51 | if self.use_wandb:
52 | self.log_to_wandb(step)
53 | path = self.save_dir / file_name
54 | imageio.mimsave(str(path), self.frames, fps=self.fps)
55 |
56 |
57 | class TrainVideoRecorder:
58 | def __init__(self,
59 | root_dir,
60 | render_size=256,
61 | fps=20,
62 | camera_id=0,
63 | use_wandb=False):
64 | if root_dir is not None:
65 | self.save_dir = root_dir / 'train_video'
66 | self.save_dir.mkdir(exist_ok=True)
67 | else:
68 | self.save_dir = None
69 |
70 | self.render_size = render_size
71 | self.fps = fps
72 | self.frames = []
73 | self.camera_id = camera_id
74 | self.use_wandb = use_wandb
75 |
76 | def init(self, obs, enabled=True):
77 | self.frames = []
78 | self.enabled = self.save_dir is not None and enabled
79 | self.record(obs)
80 |
81 | def record(self, obs):
82 | if self.enabled:
83 | frame = cv2.resize(obs[-3:].transpose(1, 2, 0),
84 | dsize=(self.render_size, self.render_size),
85 | interpolation=cv2.INTER_CUBIC)
86 | self.frames.append(frame)
87 |
88 | def log_to_wandb(self, step):
89 | frames = np.transpose(np.array(self.frames), (0, 3, 1, 2))
90 | fps, skip = 6, 8
91 | wandb.log({
92 | 'train/video':
93 | wandb.Video(frames[::skip, :, ::2, ::2], fps=fps, format="gif")
94 | }, step=step)
95 |
96 | def save(self, file_name, step):
97 | if self.enabled:
98 | if self.use_wandb:
99 | self.log_to_wandb(step)
100 | path = self.save_dir / file_name
101 | imageio.mimsave(str(path), self.frames, fps=self.fps)
102 |
--------------------------------------------------------------------------------
/core/intrinsic/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import hydra
4 |
5 | from .intrinsic_reward_base import IntrinsicReward
6 | from .cic import CICReward
7 | from .multimodal_cic import MultimodalCICReward
8 |
9 | def make_intrinsic_reward(cfg):
10 | return hydra.utils.instantiate(cfg)
--------------------------------------------------------------------------------
/core/intrinsic/cic.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Mapping, Any, NamedTuple, Tuple, Dict, Union
2 | from functools import partial
3 | import logging
4 |
5 | import jax
6 | import jax.numpy as jnp
7 | import haiku as hk
8 | import optax
9 |
10 | from core import intrinsic
11 | from core import calculations
12 | from core import data
13 |
14 | class CICnetwork(hk.Module):
15 | def __init__(self,
16 | hidden_dim: int,
17 | skill_dim: int,
18 | project_skill: bool
19 | ):
20 | super(CICnetwork, self).__init__()
21 |
22 | self.state_net = calculations.mlp(hidden_dim, skill_dim, name='state_net')
23 | # self.state_net = calculations.mlp(hidden_dim, skill_dim//2, name='state_net')
24 | # self.next_state_net = calculations.mlp(hidden_dim, skill_dim, name='next_state_net')
25 | self.pred_net = calculations.mlp(hidden_dim, skill_dim, name='pred_net')
26 |
27 | if project_skill:
28 | self.skill_net = calculations.mlp(hidden_dim, skill_dim, name='skill_net')
29 | else:
30 | self.skill_net = calculations.Identity()
31 |
32 | def __call__(self, state, next_state, skill, is_training=True): # input is obs_dim - skill_dim
33 | state = self.state_net(state)
34 | next_state = self.state_net(next_state)
35 | # next_state = self.next_state_net(next_state)
36 | if is_training:
37 | query = self.skill_net(skill)
38 | key = self.pred_net(jnp.concatenate([state, next_state], axis=-1))
39 | return query, key
40 | else:
41 | return state, next_state
42 |
43 | def cic_foward(state: jnp.ndarray,
44 | next_state: jnp.ndarray,
45 | skill: jnp.ndarray,
46 | is_training: bool,
47 | network_cfg: Mapping[str, Any]
48 | ):
49 | model: Callable = CICnetwork(
50 | hidden_dim=network_cfg['hidden_dim'],
51 | skill_dim=network_cfg['skill_dim'],
52 | project_skill=network_cfg['project_skill']
53 | )
54 | return model(state, next_state, skill, is_training)
55 |
56 |
57 | class CICState(NamedTuple):
58 | cic_params: hk.Params
59 | cic_opt_params: optax.OptState
60 | running_mean: Union[float, None]
61 | running_std: Union[float, None]
62 | running_num: Union[float, None]
63 |
64 |
65 | class CICReward(intrinsic.IntrinsicReward):
66 | def __init__(self,
67 | to_jit: bool,
68 | network_cfg: Mapping[str, Any],
69 | lr: float,
70 | knn_entropy_config,
71 | temperature: float,
72 | name: str = 'cic',
73 | ):
74 | self.cic = hk.without_apply_rng(
75 | hk.transform(
76 | partial(
77 | cic_foward,
78 | network_cfg=network_cfg,
79 | )
80 | )
81 | )
82 | self._cpc_loss = partial(self._cpc_loss, temperature=temperature)
83 | self.init_params = partial(self.init_params, skill_dim=network_cfg['skill_dim'])
84 | self.cic_optimizer = optax.adam(learning_rate=lr)
85 | self.entropy_estimator = partial(calculations.particle_based_entropy,
86 | **knn_entropy_config)
87 | if to_jit:
88 | self.update_batch = jax.jit(self.update_batch)
89 |
90 | def init_params(self,
91 | init_key: jax.random.PRNGKey,
92 | dummy_obs: jnp.ndarray,
93 | skill_dim: int,
94 | summarize: bool = True
95 | ):
96 | # batch_size = dummy_obs.shape[0]
97 | dummy_skill = jax.random.uniform(key=init_key, shape=(skill_dim, ), minval=0, maxval=1)
98 | cic_init = self.cic.init(rng=init_key, state=dummy_obs, next_state=dummy_obs, skill=dummy_skill, is_training=True)
99 | cic_opt_init = self.cic_optimizer.init(cic_init)
100 | if summarize:
101 | logger = logging.getLogger(__name__)
102 | summarize_cic_forward = partial(self.cic.apply, is_training=True) # somehow only works this way
103 | logger.info(hk.experimental.tabulate(summarize_cic_forward)(cic_init, dummy_obs, dummy_obs, dummy_skill))
104 | return CICState(
105 | cic_params=cic_init,
106 | cic_opt_params=cic_opt_init,
107 | running_mean=jnp.zeros((1,)),
108 | running_std=jnp.ones((1,)),
109 | running_num=1e-4
110 | )
111 |
112 | def _cpc_loss(self,
113 | cic_params: hk.Params,
114 | obs: jnp.ndarray,
115 | next_obs: jnp.ndarray,
116 | skill: jnp.ndarray,
117 | temperature: float
118 | ):
119 | query, key = self.cic.apply(cic_params, obs, next_obs, skill, is_training=True) #(b, c)
120 | # loss = calculations.noise_contrastive_loss(query, key, temperature=temperature)
121 | loss = calculations.cpc_loss(query=query, key=key)
122 | logs = dict(
123 | cpc_loss=loss
124 | )
125 | return loss, logs
126 | # return noise_contrastive_loss(query, key)
127 |
128 | def _update_cic(self,
129 | cic_params: hk.Params,
130 | cic_opt_params: optax.OptState,
131 | obs: jnp.ndarray,
132 | next_obs: jnp.ndarray,
133 | skill: jnp.ndarray
134 | ):
135 | grad_fn = jax.grad(self._cpc_loss, has_aux=True)
136 | grads, logs = grad_fn(cic_params, obs, next_obs, skill)
137 | deltas, cic_opt_params = self.cic_optimizer.update(grads, cic_opt_params)
138 | cic_params = optax.apply_updates(cic_params, deltas)
139 | return (cic_params, cic_opt_params), logs
140 |
141 | def compute_reward(self, cic_params, obs, next_obs, skill, running_mean, running_std, running_num):
142 | source, target = self.cic.apply(cic_params, obs, next_obs, skill, is_training=False)
143 | reward, running_mean, running_std, running_num = self.entropy_estimator(
144 | source=source,
145 | target=target,
146 | num=running_num,
147 | mean=running_mean,
148 | std=running_std)
149 | return reward, running_mean, running_std, running_num
150 |
151 | def update_batch(self,
152 | state: CICState,
153 | batch: data.Batch,
154 | step: int,
155 | ) -> Tuple[CICState, NamedTuple, Dict]:
156 | obs = batch.observation
157 | extrinsic_reward = batch.reward
158 | next_obs = batch.next_observation
159 | meta = batch.extras
160 | skill = meta['skill']
161 | """ Updates CIC and batch"""
162 | logs = dict()
163 |         # TODO: add augmentation for pixel-based observations
164 | (cic_params, cic_opt_params), cic_logs = self._update_cic(
165 | cic_params=state.cic_params,
166 | cic_opt_params=state.cic_opt_params,
167 | obs=obs,
168 | next_obs=next_obs,
169 | skill=skill)
170 | logs.update(cic_logs)
171 |
172 | intrinsic_reward, running_mean, running_std, running_num = self.compute_reward(
173 | cic_params=state.cic_params,
174 | obs=obs,
175 | next_obs=next_obs,
176 | running_num=state.running_num,
177 | skill=skill,
178 | running_mean=state.running_mean,
179 | running_std=state.running_std)
180 |
181 | logs['intrinsic_reward'] = jnp.mean(intrinsic_reward)
182 | logs['extrinsic_reward'] = jnp.mean(extrinsic_reward)
183 |
184 | return CICState(
185 | cic_params=cic_params,
186 | cic_opt_params=cic_opt_params,
187 | running_mean=running_mean,
188 | running_std=running_std,
189 | running_num=running_num
190 | ), batch._replace(reward=intrinsic_reward), logs
191 |
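192 | 
193 | # Example (minimal sketch): one CIC update as driven by the pretraining loop: initialise
194 | # the contrastive network from a dummy observation, then update it on a replay batch and
195 | # relabel that batch with the particle-entropy intrinsic reward. `network_cfg` and
196 | # `knn_entropy_config` are the dicts normally supplied by the hydra config; the remaining
197 | # scalars are illustrative.
198 | def _example_cic_step(batch: data.Batch, network_cfg, knn_entropy_config):
199 |     reward_fn = CICReward(to_jit=True,
200 |                           network_cfg=network_cfg,
201 |                           lr=1e-4,
202 |                           knn_entropy_config=knn_entropy_config,
203 |                           temperature=0.5)
204 |     state = reward_fn.init_params(init_key=jax.random.PRNGKey(0),
205 |                                   dummy_obs=batch.observation[0])
206 |     # returns (new CIC state, batch with reward replaced by the intrinsic reward, logs)
207 |     return reward_fn.update_batch(state=state, batch=batch, step=0)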
--------------------------------------------------------------------------------
/core/intrinsic/intrinsic_reward_base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 |
4 |
5 | class IntrinsicReward(ABC):
6 |
7 | @abstractmethod
8 | def init_params(self, *args, **kwargs):
9 | raise NotImplementedError
10 |
11 | @abstractmethod
12 | def compute_reward(self, *args, **kwargs):
13 | raise NotImplementedError
14 |
15 | @abstractmethod
16 | def update_batch(self, *args, **kwargs):
17 | raise NotImplementedError
18 |
--------------------------------------------------------------------------------
/core/intrinsic/multimodal_cic.py:
--------------------------------------------------------------------------------
1 | from typing import NamedTuple, Tuple, Dict
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | import haiku as hk
6 | import optax
7 | import numpy as np
8 |
9 | from core import intrinsic
10 | from core import data
11 |
12 |
13 | class MultimodalCICState(NamedTuple):
14 | cic_params: hk.Params
15 | cic_opt_params: optax.OptState
16 | mode_0_running_mean: jnp.ndarray
17 | mode_0_running_std: jnp.ndarray
18 | mode_0_running_num: float
19 |
20 |
21 | class RunningStatistics(NamedTuple):
22 | mode_0_running_mean: jnp.ndarray
23 | mode_0_running_std: jnp.ndarray
24 | mode_0_running_num: float
25 |
26 |
27 | class MultimodalCICReward(intrinsic.CICReward):
28 | def __init__(self, to_jit, *args, **kwargs):
29 |         # pass to_jit=False to the parent so the whole update_batch is not jitted
30 |         super().__init__(False, *args, **kwargs)
31 |         # jit only these sub-functions; re-assign here or the inherited versions are used
32 |         if to_jit:
33 | self._update_cic = jax.jit(self._update_cic)
34 | self.entropy_estimator = jax.jit(self.entropy_estimator)
35 |
36 | def init_params(self,
37 | init_key: jax.random.PRNGKey,
38 | dummy_obs: jnp.ndarray,
39 | skill_dim: int,
40 | summarize: bool = True,
41 | ):
42 | cic_state = super().init_params(init_key, dummy_obs, skill_dim, summarize)
43 | return MultimodalCICState(
44 | cic_params=cic_state.cic_params,
45 | cic_opt_params=cic_state.cic_opt_params,
46 | mode_0_running_mean=jnp.zeros((1,)),
47 | mode_0_running_std=jnp.ones((1,)),
48 | mode_0_running_num=1e-4,
49 | )
50 |
51 | def compute_reward(self,
52 | cic_params,
53 | obs,
54 | next_obs,
55 | skill,
56 | statistics,
57 | **kwargs
58 | ):
59 | source_0, target_0 = self.cic.apply(cic_params,
60 | obs,
61 | next_obs,
62 | skill,
63 | is_training=False)
64 | reward, running_mean_0, running_std_0, running_num_0 = self.entropy_estimator(
65 | source=source_0,
66 | target=target_0,
67 | mean=statistics.mode_0_running_mean,
68 | std=statistics.mode_0_running_std,
69 | num=statistics.mode_0_running_num,
70 | )
71 |
72 | return reward, RunningStatistics(
73 | running_mean_0,
74 | running_std_0,
75 | running_num_0,
76 | )
77 |
78 | def update_batch(self,
79 | state: MultimodalCICState,
80 | batch: data.Batch,
81 | step: int,
82 | ) -> Tuple[MultimodalCICState, data.Batch, Dict]:
83 | """ Updates CIC and batch"""
84 | obs = batch.observation
85 | extrinsic_reward = batch.reward
86 | next_obs = batch.next_observation
87 | meta = batch.extras
88 | skill = meta['skill']
89 | logs = dict()
90 |         # TODO: add augmentation for pixel-based observations
91 | (cic_params, cic_opt_params), cic_logs = self._update_cic(
92 | cic_params=state.cic_params,
93 | cic_opt_params=state.cic_opt_params,
94 | obs=jnp.concatenate(obs),
95 | next_obs=jnp.concatenate(next_obs),
96 | skill=jnp.concatenate(skill))
97 | logs.update(cic_logs)
98 |
99 | intrinsic_reward, statistics = self.compute_reward(cic_params=state.cic_params,
100 | obs=jnp.concatenate(obs),
101 | next_obs=jnp.concatenate(next_obs),
102 | skill=jnp.concatenate(meta['skill']),
103 | statistics=state)
104 |         # TODO: is this logging needed? it is done here to avoid moving data off the GPU and back
105 |         logs['intrinsic_reward'] = jnp.mean(intrinsic_reward)
106 |         logs['extrinsic_reward'] = jnp.mean(jnp.concatenate(extrinsic_reward))  # cannot mean a list of arrays directly
107 |
108 |         # mode-0 transitions come first in the concatenated batch; the remaining
109 |         # mode-1 transitions get a negated intrinsic reward (the entropy-minimizing mode)
110 |         intrinsic_reward = np.array(intrinsic_reward)
111 |         intrinsic_reward[len(obs[0]):, :] *= -1
112 | return MultimodalCICState(
113 | cic_params=cic_params,
114 | cic_opt_params=cic_opt_params,
115 | mode_0_running_mean=statistics.mode_0_running_mean,
116 | mode_0_running_std=statistics.mode_0_running_std,
117 | mode_0_running_num=statistics.mode_0_running_num,
118 | ), data.Batch(
119 | observation=jnp.concatenate(obs),
120 | action=jnp.concatenate(batch.action),
121 | reward=intrinsic_reward,
122 | discount=jnp.concatenate(batch.discount),
123 | next_observation=jnp.concatenate(next_obs),
124 | extras=dict(skill=jnp.concatenate(skill))
125 | ), logs
126 |
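127 | 
128 | # Data flow note: numpy_collate_mode (core/data/replay_buffer_torch.py) delivers every
129 | # batch field as a [mode-0, mode-1] pair of arrays. update_batch concatenates the two
130 | # modes, runs a single CIC/entropy update over the combined batch, negates the intrinsic
131 | # reward for the mode-1 half, and returns a flat data.Batch so the downstream agent
132 | # update does not need to know about the mode split.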
--------------------------------------------------------------------------------
/figures/MOSS_robot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeapLabTHU/MOSS/534820b5cb5713389f416f60c867de505e791166/figures/MOSS_robot.png
--------------------------------------------------------------------------------
/figures/fraction_rliable.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeapLabTHU/MOSS/534820b5cb5713389f416f60c867de505e791166/figures/fraction_rliable.png
--------------------------------------------------------------------------------
/figures/rliable.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeapLabTHU/MOSS/534820b5cb5713389f416f60c867de505e791166/figures/rliable.png
--------------------------------------------------------------------------------
/helpers.py:
--------------------------------------------------------------------------------
1 | from typing import NamedTuple, Any, Dict
2 | import time
3 |
4 | import wandb
5 | import haiku as hk
6 |
7 | from core.calculations import skill_utils
8 |
9 |
10 | def log_params_to_wandb(params: hk.Params, step: int):
11 | if params:
12 | for module in sorted(params):
13 | if 'w' in params[module]:
14 | wandb.log({
15 | f'{module}/w': wandb.Histogram(params[module]['w'])
16 | }, step=step)
17 | if 'b' in params[module]:
18 | wandb.log({
19 | f'{module}/b': wandb.Histogram(params[module]['b'])
20 | }, step=step)
21 |
22 |
23 | class LogParamsEvery:
24 | def __init__(self, every, action_repeat=1):
25 | self._every = every
26 | self._action_repeat = action_repeat
27 |
28 | def __call__(self, params: hk.Params, step):
29 | if self._every is None:
30 | pass
31 | every = self._every // self._action_repeat
32 | if step % every == 0:
33 | log_params_to_wandb(params, step)
34 | pass
35 |
36 |
37 | class Until:
38 | def __init__(self, until, action_repeat=1):
39 | self._until = until
40 | self._action_repeat = action_repeat
41 |
42 | def __call__(self, step):
43 | if self._until is None:
44 | return True
45 | until = self._until // self._action_repeat
46 | return step < until
47 |
48 |
49 | class Every:
50 | def __init__(self, every, action_repeat=1):
51 | self._every = every
52 | self._action_repeat = action_repeat
53 |
54 | def __call__(self, step):
55 | if self._every is None:
56 | return False
57 | every = self._every // self._action_repeat
58 | if step % every == 0:
59 | return True
60 | return False
61 |
62 |
63 | class Timer:
64 | def __init__(self):
65 | self._start_time = time.time()
66 | self._last_time = time.time()
67 |
68 | def reset(self):
69 | elapsed_time = time.time() - self._last_time
70 | self._last_time = time.time()
71 | total_time = time.time() - self._start_time
72 | return elapsed_time, total_time
73 |
74 | def total_time(self):
75 | return time.time() - self._start_time
76 |
77 |
78 | class CsvData(NamedTuple):
79 | episode_reward: float
80 | episode_length: int
81 | episode: int
82 | step: int
83 | total_time: float
84 | fps: float
85 |
86 |
87 | class LoopVar(NamedTuple):
88 | global_step: int
89 | global_episode: int
90 | episode_step: int
91 | episode_reward: float
92 | total_reward: float
93 | pointer: int
94 |
95 |
96 | class LoopsLength(NamedTuple):
97 | eval_until_episode: Until
98 | train_until_step: Until
99 | seed_until_step: Until
100 | eval_every_step: Every
101 |
102 |
103 | def increment_step(x: LoopVar,
104 | reward: float,
105 | n: int = 1
106 | ) -> LoopVar:
107 | return LoopVar(
108 | global_step=x.global_step + n,
109 | global_episode=x.global_episode,
110 | episode_step=x.episode_step + n,
111 | episode_reward=x.episode_reward + reward,
112 | total_reward=x.episode_reward + reward,
113 | pointer=x.pointer,
114 | )
115 |
116 |
117 | def increment_episode(x: LoopVar,
118 | n: int = 1
119 | ) -> LoopVar:
120 | return LoopVar(
121 | global_step=x.global_step,
122 | global_episode=x.global_episode + n,
123 | episode_step=x.episode_step,
124 | episode_reward=x.episode_reward,
125 | total_reward=x.episode_reward,
126 | pointer=x.pointer,
127 | )
128 |
129 |
130 | def reset_episode(x: LoopVar,
131 | ) -> LoopVar:
132 | return LoopVar(
133 | global_step=x.global_step,
134 | global_episode=x.global_episode,
135 | episode_step=0,
136 | episode_reward=0.,
137 | total_reward=x.episode_reward,
138 | pointer=x.pointer,
139 | )
140 |
141 |
142 | def update_skilltracker(
143 | x: skill_utils.SkillRewardTracker,
144 | reward: float
145 | ) -> skill_utils.SkillRewardTracker:
146 |     # for pretraining, we don't need a skill tracker
147 | if x is None:
148 | return
149 | return x._replace(
150 | score_sum=x.score_sum + reward,
151 | score_step=x.score_step + 1,
152 | )
153 |
154 |
155 | def parse_skilltracker(
156 | x: skill_utils.SkillRewardTracker,
157 | meta: Dict[str, Any],
158 | ) -> skill_utils.SkillRewardTracker:
159 | if not meta or 'tracker' not in meta:
160 | return x
161 | return meta['tracker']
162 |
163 |
164 | def init_skilltracker(
165 | search_steps: int,
166 | change_interval: int,
167 | low: float,
168 | ) -> skill_utils.SkillRewardTracker:
169 | return skill_utils.SkillRewardTracker(
170 | best_skill=None,
171 | best_score=-float('inf'),
172 | score_sum=0.,
173 | score_step=0,
174 | current_skill=None,
175 | search_steps=search_steps,
176 | change_interval=change_interval,
177 | low=low,
178 | update=True,
179 | )
180 |
181 |
182 | def skilltracker_update_on(
183 | x: skill_utils.SkillRewardTracker,
184 | ) -> skill_utils.SkillRewardTracker:
185 | if x is None:
186 | return
187 | return x._replace(update=True)
188 |
189 |
190 | def skilltracker_update_off(
191 | x: skill_utils.SkillRewardTracker,
192 | ) -> skill_utils.SkillRewardTracker:
193 | if x is None:
194 | return
195 | return x._replace(update=False)
196 |
--------------------------------------------------------------------------------
/pretrain_multimodal.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
3 | os.environ['MUJOCO_GL'] = 'egl'
4 | os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
5 | import random
6 | import logging
7 |
8 | import wandb
9 | import jax
10 | import torch
11 | import hydra
12 | import numpy as np
13 | from omegaconf import OmegaConf, DictConfig
14 | import pathlib
15 | from dm_env import specs
16 |
17 | from core import envs
18 | from core import agents
19 | from core import exp_utils
20 | import helpers
21 |
22 |
23 | class PretrainLoop:
24 | def __init__(self, cfg):
25 | self.logger = logging.getLogger(__name__)
26 | self.logger.info(OmegaConf.to_yaml(cfg))
27 | self.init_rng = jax.random.PRNGKey(cfg.seed)
28 | torch.manual_seed(cfg.seed)
29 | np.random.seed(cfg.seed)
30 | random.seed(cfg.seed)
31 |
32 | self.work_dir = pathlib.Path.cwd()
33 | self.use_wandb = cfg.use_wandb
34 | if cfg.use_wandb:
35 | WANDB_NOTES = cfg.wandb_note
36 | os.environ["WANDB_MODE"] = "offline"
37 | wandb.init(project=cfg.wandb_project_name,
38 | name='pretrain_multimodal_'+cfg.benchmark.task+WANDB_NOTES+str(cfg.seed),
39 | config=OmegaConf.to_container(cfg, resolve=True))
40 |
41 | # init env
42 | self.train_env = envs.make_env(cfg.agent.action_type, cfg.benchmark, seed=cfg.seed)
43 | self.eval_env = envs.make_env(cfg.agent.action_type, cfg.benchmark, seed=cfg.seed)
44 | self.action_repeat = cfg.benchmark.action_repeat
45 |
46 | # init agent
47 | self.agent = agents.make_agent(
48 | cfg.benchmark.obs_type,
49 | self.train_env.action_spec().shape,
50 | cfg.agent
51 | )
52 | self.intrinsic_reward = hydra.utils.instantiate(cfg.intrinsic)
53 | self.state = None
54 | self.intrinsic_state = None
55 | data_specs = (
56 | self.train_env.observation_spec(),
57 | self.train_env.action_spec(),
58 | specs.Array((1,), np.float32, 'reward'),
59 | specs.Array((1,), np.float32, 'discount'),
60 | )
61 | self.agent.init_replay_buffer(
62 | replay_buffer_cfg=cfg.agent.replay_buffer_cfg,
63 | replay_dir=self.work_dir / 'buffer',
64 | environment_spec=data_specs,
65 | )
66 | self.update_agent_every = cfg.update_every_steps
67 | self.checkpointer = exp_utils.Checkpointer(save_dir=cfg.save_dir)
68 | self.snapshot_steps = cfg.snapshots
69 |
70 | # init exp_utils
71 | self.train_video_recorder = exp_utils.TrainVideoRecorder(
72 |             self.work_dir if cfg.save_train_video else None  # if state-based, there are no pixels to save
73 | )
74 | self.video_recorder = exp_utils.VideoRecorder(
75 | self.work_dir if cfg.save_video else None
76 | )
77 |
78 | # init loop
79 | eval_until_episode = helpers.Until(cfg.num_eval_episodes, cfg.benchmark.action_repeat)
80 | seed_until_step = helpers.Until(cfg.num_seed_frames, cfg.benchmark.action_repeat)
81 | eval_every_step = helpers.Every(cfg.eval_every_frames, cfg.benchmark.action_repeat)
82 | train_until_step = helpers.Until(cfg.num_pretrain_frames, cfg.benchmark.action_repeat)
83 | self.loops_length = helpers.LoopsLength(
84 | eval_until_episode=eval_until_episode,
85 | train_until_step=train_until_step,
86 | seed_until_step=seed_until_step,
87 | eval_every_step=eval_every_step,
88 | )
89 | self.global_loop_var = helpers.LoopVar(
90 | global_step=0,
91 | global_episode=0,
92 | episode_step=0,
93 | episode_reward=0.,
94 | total_reward=0.,
95 | pointer=0)
96 |
97 | @property
98 | def global_frame(self):
99 | return self.global_loop_var.global_step * self.action_repeat
100 |
101 | def _exploration_loop(self, rng):
102 | time_step = self.train_env.reset()
103 | meta_rng, rng = jax.random.split(key=rng, num=2)
104 | meta = self.agent.init_meta(key=meta_rng, time_step=time_step)[0]
105 |
106 | self.agent.store_timestep(time_step=time_step, meta=meta)
107 | while self.loops_length.seed_until_step(self.global_loop_var.global_step):
108 | if time_step.last():
109 | self.global_loop_var = helpers.increment_episode(self.global_loop_var)
110 | time_step = self.train_env.reset()
111 | meta_rng, rng = jax.random.split(key=rng, num=2)
112 | meta = self.agent.init_meta(key=meta_rng, time_step=time_step)[0]
113 | self.agent.store_timestep(time_step=time_step, meta=meta)
114 | self.global_loop_var = helpers.reset_episode(self.global_loop_var)
115 |
116 | meta_rng, action_rng, rng = tuple(jax.random.split(key=rng, num=3))
117 | meta = self.agent.update_meta(key=meta_rng,
118 | meta=meta,
119 | step=self.global_loop_var.global_step, time_step=time_step)[0]
120 | action = jax.random.uniform(key=action_rng, shape=self.train_env.action_spec().shape, minval=-1.0, maxval=1.0)
121 | action = np.array(action)
122 |
123 | # take env step
124 | time_step = self.train_env.step(action)
125 | self.agent.store_timestep(time_step=time_step, meta=meta)
126 | # increment loop_vars and skill tracker
127 | self.global_loop_var = helpers.increment_step(self.global_loop_var, reward=time_step.reward)
128 |
129 | return time_step, meta
130 |
131 | def train_loop(self):
132 |
133 | metric_logger = exp_utils.MetricLogger(csv_file_name=self.work_dir / 'train.csv', use_wandb=self.use_wandb)
134 | timer = exp_utils.Timer()
135 | time_step = self.train_env.reset()
136 |
137 | self.logger.info("Pretraining from scratch")
138 | self.state = self.agent.init_params(
139 | init_key=self.init_rng,
140 | dummy_obs=time_step.observation
141 | )
142 | self.intrinsic_state = self.intrinsic_reward.init_params(
143 | init_key=self.init_rng,
144 | dummy_obs=time_step.observation,
145 | )
146 |
147 | step_rng, eval_rng, rng = jax.random.split(self.init_rng, num=3)
148 | # self.evaluate(eval_rng)
149 | self.logger.info("Exploration loop")
150 | time_step, meta = self._exploration_loop(step_rng)
151 | self.logger.info("Starting training at episode: {}, step: {}".format(self.global_loop_var.global_episode,
152 | self.global_loop_var.global_step))
153 |
154 | metrics = None
155 | while self.loops_length.train_until_step(self.global_loop_var.global_step):
156 | if time_step.last():
157 | self.global_loop_var = helpers.increment_episode(self.global_loop_var)
158 |
159 | # log metrics
160 | if metrics is not None:
161 | elapsed_time, total_time = timer.reset()
162 | episode_frame = self.global_loop_var.episode_step * self.action_repeat
163 | data = helpers.CsvData(
164 | step=self.global_loop_var.global_step,
165 | episode=self.global_loop_var.global_episode,
166 | episode_length=episode_frame,
167 | episode_reward=self.global_loop_var.episode_reward, # not a float type
168 | total_time=total_time,
169 | fps=episode_frame / elapsed_time
170 | )
171 | data = data._asdict()
172 | metric_logger.dump_dict_to_csv(data=data)
173 | metric_logger.dump_dict_to_wandb(step=self.global_loop_var.global_step, data=data)
174 | data.update(buffer_size=len(self.agent))
175 | metric_logger.log_and_dump_metrics_to_wandb(step=self.global_loop_var.global_step, header=data)
176 |
177 | # reset env
178 | time_step = self.train_env.reset()
179 | step_rng, rng = tuple(jax.random.split(rng, num=2))
180 | meta = self.agent.init_meta(step_rng, time_step=time_step)[0]
181 |                 # no need to parse the tracker here, since it is not updated during this loop
182 | self.agent.store_timestep(time_step=time_step, meta=meta)
183 | # train_video_recorder.init(time_step.observation)
184 | self.global_loop_var = helpers.reset_episode(self.global_loop_var)
185 |
186 | chkpt_pointer = min(self.global_loop_var.pointer, len(self.snapshot_steps) - 1)
187 | if (self.global_loop_var.global_step + 1) >= self.snapshot_steps[chkpt_pointer]:
188 | self.checkpointer.save_state(
189 | name=str(self.global_loop_var.global_step + 1),
190 | state=self.state,
191 | step=self.snapshot_steps[chkpt_pointer],
192 | rng=rng,
193 | )
194 | self.checkpointer.save_state(
195 | name=str(self.global_loop_var.global_step + 1) + '_cic',
196 | state=self.intrinsic_state,
197 | step=self.snapshot_steps[chkpt_pointer],
198 | rng=rng,
199 | )
200 | self.global_loop_var = self.global_loop_var._replace(pointer=self.global_loop_var.pointer + 1)
201 |
202 | if self.loops_length.eval_every_step(self.global_loop_var.global_step):
203 | eval_rng, rng = tuple(jax.random.split(rng, num=2))  # no evaluation is run during pretraining; the split only advances the rng stream
204 |
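    | # Per-step flow: refresh the skill meta, select a non-greedy action, and every
    | # `update_agent_every` steps pass a sampled replay batch through the intrinsic-reward
    | # module before updating the agent on it.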
205 | # agent step
206 | meta_rng, step_rng, update_rng, rng = tuple(jax.random.split(rng, num=4))
207 | meta = self.agent.update_meta(
208 | key=meta_rng, meta=meta, step=self.global_loop_var.global_step, time_step=time_step)[0]
209 | action = self.agent.select_action(
210 | state=self.state,
211 | obs=time_step.observation,
212 | meta=meta,
213 | step=self.global_loop_var.global_step,
214 | key=step_rng,
215 | greedy=False
216 | )
217 | if self.global_loop_var.global_step % self.update_agent_every == 0:
218 | batch = self.agent.sample_timesteps()
219 | self.intrinsic_state, batch, intrinsic_metrics = self.intrinsic_reward.update_batch(
220 | state=self.intrinsic_state,
221 | batch=batch,
222 | step=self.global_loop_var.global_step,
223 | )
224 | metric_logger.update_metrics(**intrinsic_metrics)
225 | self.state, metrics = self.agent.update(
226 | state=self.state,
227 | key=update_rng,
228 | step=self.global_loop_var.global_step,
229 | batch=batch
230 | )
231 | metric_logger.update_metrics(**metrics)
232 |
233 | # step on env
234 | time_step = self.train_env.step(action)
235 | self.agent.store_timestep(time_step=time_step, meta=meta)
236 | self.global_loop_var = helpers.increment_step(self.global_loop_var, reward=time_step.reward)
237 |
238 | eval_rng, rng = jax.random.split(rng, num=2)
239 | self.evaluate(eval_rng=eval_rng)
240 |
241 | def evaluate(self, eval_rng):
242 | metric_logger = exp_utils.MetricLogger(csv_file_name=pathlib.Path.cwd() / 'eval.csv', use_wandb=self.use_wandb)
243 | timer = exp_utils.Timer()
244 | local_loop_var = helpers.LoopVar(
245 | global_step=0,
246 | global_episode=0,
247 | episode_step=0,
248 | episode_reward=0.,
249 | total_reward=0.,
250 | pointer=0,
251 | )
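    | # Evaluate with greedy actions until `eval_until_episode` says to stop; only the first
    | # episode is recorded to video, and the per-episode averages are logged below.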
252 | while self.loops_length.eval_until_episode(local_loop_var.global_episode):
253 | step_rng, eval_rng = jax.random.split(key=eval_rng, num=2)  # carry eval_rng forward so each episode uses a fresh step_rng
254 | time_step = self.eval_env.reset()
255 | meta = self.agent.init_meta(key=step_rng, time_step=time_step)[0]
256 | self.video_recorder.init(self.eval_env, enabled=(local_loop_var.global_episode == 0))
257 | while not time_step.last():
258 | action = self.agent.select_action(
259 | state=self.state,
260 | obs=time_step.observation,
261 | meta=meta,
262 | step=self.global_loop_var.global_step,
263 | key=step_rng,
264 | greedy=True
265 | )
266 | time_step = self.eval_env.step(action)
267 | self.video_recorder.record(self.eval_env)
268 | local_loop_var = helpers.increment_step(local_loop_var, reward=time_step.reward)
269 |
270 | # episode += 1
271 | local_loop_var = helpers.increment_episode(local_loop_var)
272 | self.video_recorder.save(f'{self.global_loop_var.global_step * self.action_repeat}.mp4',
273 | step=self.global_loop_var.global_step)
274 |
275 | n_frame = local_loop_var.global_step * self.action_repeat
276 | total_time = timer.total_time()
277 | data = helpers.CsvData(
278 | # episode_reward=total_reward / episode,
279 | episode_reward=local_loop_var.total_reward / local_loop_var.global_episode,
280 | episode_length=int(local_loop_var.global_step * self.action_repeat / local_loop_var.global_episode),
281 | episode=self.global_loop_var.global_episode,  # must be named 'episode', otherwise the csv logger cannot clean it
282 | step=self.global_loop_var.global_step,
283 | total_time=total_time,
284 | fps=n_frame / total_time
285 | )
286 | data = data._asdict()
287 | metric_logger.dump_dict_to_csv(data=data)
288 | metric_logger.dump_dict_to_wandb(step=self.global_loop_var.global_step, data=data)
289 | metric_logger.log_dict(data=data, header="Evaluation results: ")
290 | return data
291 |
292 |
293 | @hydra.main(config_path='conf/', config_name='config')
294 | def main(cfg: DictConfig):
295 | trainer = PretrainLoop(cfg)
296 | trainer.train_loop()
297 |
298 |
299 | if __name__ == '__main__':
300 | import warnings
301 | warnings.filterwarnings('ignore', category=DeprecationWarning)  # silence DeprecationWarnings from the installed dm_control version
302 | import tensorflow as tf
303 |
304 | tf.config.set_visible_devices([], "GPU")  # hide GPUs from TensorFlow so it does not compete with JAX for GPU memory
305 | main()
306 |
--------------------------------------------------------------------------------