├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── assets ├── CASCADE_motivation_with_description.jpg └── cascade_overview.jpg ├── dreamerv2 ├── agent.py ├── api.py ├── common │ ├── __init__.py │ ├── cdmc │ │ ├── __init__.py │ │ ├── cheetah.py │ │ ├── cheetah.xml │ │ ├── walker.py │ │ └── walker.xml │ ├── config.py │ ├── counter.py │ ├── dists.py │ ├── driver.py │ ├── envs.py │ ├── eval.py │ ├── flags.py │ ├── logger.py │ ├── nets.py │ ├── other.py │ ├── plot.py │ ├── ram_annotations.py │ ├── recorder.py │ ├── replay.py │ ├── tfutils.py │ └── when.py ├── configs.yaml ├── expl.py └── train.py └── main.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 
58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to cascade 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to cascade, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. 
Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | Section 1 -- Definitions. 71 | 72 | a. 
Adapted Material means material subject to Copyright and Similar 73 | Rights that is derived from or based upon the Licensed Material 74 | and in which the Licensed Material is translated, altered, 75 | arranged, transformed, or otherwise modified in a manner requiring 76 | permission under the Copyright and Similar Rights held by the 77 | Licensor. For purposes of this Public License, where the Licensed 78 | Material is a musical work, performance, or sound recording, 79 | Adapted Material is always produced where the Licensed Material is 80 | synched in timed relation with a moving image. 81 | 82 | b. Adapter's License means the license You apply to Your Copyright 83 | and Similar Rights in Your contributions to Adapted Material in 84 | accordance with the terms and conditions of this Public License. 85 | 86 | c. Copyright and Similar Rights means copyright and/or similar rights 87 | closely related to copyright including, without limitation, 88 | performance, broadcast, sound recording, and Sui Generis Database 89 | Rights, without regard to how the rights are labeled or 90 | categorized. For purposes of this Public License, the rights 91 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 92 | Rights. 93 | d. Effective Technological Measures means those measures that, in the 94 | absence of proper authority, may not be circumvented under laws 95 | fulfilling obligations under Article 11 of the WIPO Copyright 96 | Treaty adopted on December 20, 1996, and/or similar international 97 | agreements. 98 | 99 | e. Exceptions and Limitations means fair use, fair dealing, and/or 100 | any other exception or limitation to Copyright and Similar Rights 101 | that applies to Your use of the Licensed Material. 102 | 103 | f. Licensed Material means the artistic or literary work, database, 104 | or other material to which the Licensor applied this Public 105 | License. 106 | 107 | g. Licensed Rights means the rights granted to You subject to the 108 | terms and conditions of this Public License, which are limited to 109 | all Copyright and Similar Rights that apply to Your use of the 110 | Licensed Material and that the Licensor has authority to license. 111 | 112 | h. Licensor means the individual(s) or entity(ies) granting rights 113 | under this Public License. 114 | 115 | i. NonCommercial means not primarily intended for or directed towards 116 | commercial advantage or monetary compensation. For purposes of 117 | this Public License, the exchange of the Licensed Material for 118 | other material subject to Copyright and Similar Rights by digital 119 | file-sharing or similar means is NonCommercial provided there is 120 | no payment of monetary compensation in connection with the 121 | exchange. 122 | 123 | j. Share means to provide material to the public by any means or 124 | process that requires permission under the Licensed Rights, such 125 | as reproduction, public display, public performance, distribution, 126 | dissemination, communication, or importation, and to make material 127 | available to the public including in ways that members of the 128 | public may access the material from a place and at a time 129 | individually chosen by them. 130 | 131 | k. Sui Generis Database Rights means rights other than copyright 132 | resulting from Directive 96/9/EC of the European Parliament and of 133 | the Council of 11 March 1996 on the legal protection of databases, 134 | as amended and/or succeeded, as well as other essentially 135 | equivalent rights anywhere in the world. 
136 | 137 | l. You means the individual or entity exercising the Licensed Rights 138 | under this Public License. Your has a corresponding meaning. 139 | 140 | Section 2 -- Scope. 141 | 142 | a. License grant. 143 | 144 | 1. Subject to the terms and conditions of this Public License, 145 | the Licensor hereby grants You a worldwide, royalty-free, 146 | non-sublicensable, non-exclusive, irrevocable license to 147 | exercise the Licensed Rights in the Licensed Material to: 148 | 149 | a. reproduce and Share the Licensed Material, in whole or 150 | in part, for NonCommercial purposes only; and 151 | 152 | b. produce, reproduce, and Share Adapted Material for 153 | NonCommercial purposes only. 154 | 155 | 2. Exceptions and Limitations. For the avoidance of doubt, where 156 | Exceptions and Limitations apply to Your use, this Public 157 | License does not apply, and You do not need to comply with 158 | its terms and conditions. 159 | 160 | 3. Term. The term of this Public License is specified in Section 161 | 6(a). 162 | 163 | 4. Media and formats; technical modifications allowed. The 164 | Licensor authorizes You to exercise the Licensed Rights in 165 | all media and formats whether now known or hereafter created, 166 | and to make technical modifications necessary to do so. The 167 | Licensor waives and/or agrees not to assert any right or 168 | authority to forbid You from making technical modifications 169 | necessary to exercise the Licensed Rights, including 170 | technical modifications necessary to circumvent Effective 171 | Technological Measures. For purposes of this Public License, 172 | simply making modifications authorized by this Section 2(a) 173 | (4) never produces Adapted Material. 174 | 175 | 5. Downstream recipients. 176 | 177 | a. Offer from the Licensor -- Licensed Material. Every 178 | recipient of the Licensed Material automatically 179 | receives an offer from the Licensor to exercise the 180 | Licensed Rights under the terms and conditions of this 181 | Public License. 182 | 183 | b. No downstream restrictions. You may not offer or impose 184 | any additional or different terms or conditions on, or 185 | apply any Effective Technological Measures to, the 186 | Licensed Material if doing so restricts exercise of the 187 | Licensed Rights by any recipient of the Licensed 188 | Material. 189 | 190 | 6. No endorsement. Nothing in this Public License constitutes or 191 | may be construed as permission to assert or imply that You 192 | are, or that Your use of the Licensed Material is, connected 193 | with, or sponsored, endorsed, or granted official status by, 194 | the Licensor or others designated to receive attribution as 195 | provided in Section 3(a)(1)(A)(i). 196 | 197 | b. Other rights. 198 | 199 | 1. Moral rights, such as the right of integrity, are not 200 | licensed under this Public License, nor are publicity, 201 | privacy, and/or other similar personality rights; however, to 202 | the extent possible, the Licensor waives and/or agrees not to 203 | assert any such rights held by the Licensor to the limited 204 | extent necessary to allow You to exercise the Licensed 205 | Rights, but not otherwise. 206 | 207 | 2. Patent and trademark rights are not licensed under this 208 | Public License. 209 | 210 | 3. 
To the extent possible, the Licensor waives any right to 211 | collect royalties from You for the exercise of the Licensed 212 | Rights, whether directly or through a collecting society 213 | under any voluntary or waivable statutory or compulsory 214 | licensing scheme. In all other cases the Licensor expressly 215 | reserves any right to collect such royalties, including when 216 | the Licensed Material is used other than for NonCommercial 217 | purposes. 218 | 219 | Section 3 -- License Conditions. 220 | 221 | Your exercise of the Licensed Rights is expressly made subject to the 222 | following conditions. 223 | 224 | a. Attribution. 225 | 226 | 1. If You Share the Licensed Material (including in modified 227 | form), You must: 228 | 229 | a. retain the following if it is supplied by the Licensor 230 | with the Licensed Material: 231 | 232 | i. identification of the creator(s) of the Licensed 233 | Material and any others designated to receive 234 | attribution, in any reasonable manner requested by 235 | the Licensor (including by pseudonym if 236 | designated); 237 | 238 | ii. a copyright notice; 239 | 240 | iii. a notice that refers to this Public License; 241 | 242 | iv. a notice that refers to the disclaimer of 243 | warranties; 244 | 245 | v. a URI or hyperlink to the Licensed Material to the 246 | extent reasonably practicable; 247 | 248 | b. indicate if You modified the Licensed Material and 249 | retain an indication of any previous modifications; and 250 | 251 | c. indicate the Licensed Material is licensed under this 252 | Public License, and include the text of, or the URI or 253 | hyperlink to, this Public License. 254 | 255 | 2. You may satisfy the conditions in Section 3(a)(1) in any 256 | reasonable manner based on the medium, means, and context in 257 | which You Share the Licensed Material. For example, it may be 258 | reasonable to satisfy the conditions by providing a URI or 259 | hyperlink to a resource that includes the required 260 | information. 261 | 262 | 3. If requested by the Licensor, You must remove any of the 263 | information required by Section 3(a)(1)(A) to the extent 264 | reasonably practicable. 265 | 266 | 4. If You Share Adapted Material You produce, the Adapter's 267 | License You apply must not prevent recipients of the Adapted 268 | Material from complying with this Public License. 269 | 270 | Section 4 -- Sui Generis Database Rights. 271 | 272 | Where the Licensed Rights include Sui Generis Database Rights that 273 | apply to Your use of the Licensed Material: 274 | 275 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 276 | to extract, reuse, reproduce, and Share all or a substantial 277 | portion of the contents of the database for NonCommercial purposes 278 | only; 279 | 280 | b. if You include all or a substantial portion of the database 281 | contents in a database in which You have Sui Generis Database 282 | Rights, then the database in which You have Sui Generis Database 283 | Rights (but not its individual contents) is Adapted Material; and 284 | 285 | c. You must comply with the conditions in Section 3(a) if You Share 286 | all or a substantial portion of the contents of the database. 287 | 288 | For the avoidance of doubt, this Section 4 supplements and does not 289 | replace Your obligations under this Public License where the Licensed 290 | Rights include other Copyright and Similar Rights. 291 | 292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 293 | 294 | a. 
UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 298 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 304 | 305 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 314 | 315 | c. The disclaimer of warranties and limitation of liability provided 316 | above shall be interpreted in a manner that, to the extent 317 | possible, most closely approximates an absolute disclaimer and 318 | waiver of all liability. 319 | 320 | Section 6 -- Term and Termination. 321 | 322 | a. This Public License applies for the term of the Copyright and 323 | Similar Rights licensed here. However, if You fail to comply with 324 | this Public License, then Your rights under this Public License 325 | terminate automatically. 326 | 327 | b. Where Your right to use the Licensed Material has terminated under 328 | Section 6(a), it reinstates: 329 | 330 | 1. automatically as of the date the violation is cured, provided 331 | it is cured within 30 days of Your discovery of the 332 | violation; or 333 | 334 | 2. upon express reinstatement by the Licensor. 335 | 336 | For the avoidance of doubt, this Section 6(b) does not affect any 337 | right the Licensor may have to seek remedies for Your violations 338 | of this Public License. 339 | 340 | c. For the avoidance of doubt, the Licensor may also offer the 341 | Licensed Material under separate terms or conditions or stop 342 | distributing the Licensed Material at any time; however, doing so 343 | will not terminate this Public License. 344 | 345 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 346 | License. 347 | 348 | Section 7 -- Other Terms and Conditions. 349 | 350 | a. The Licensor shall not be bound by any additional or different 351 | terms or conditions communicated by You unless expressly agreed. 352 | 353 | b. Any arrangements, understandings, or agreements regarding the 354 | Licensed Material not stated herein are separate from and 355 | independent of the terms and conditions of this Public License. 356 | 357 | Section 8 -- Interpretation. 358 | 359 | a. For the avoidance of doubt, this Public License does not, and 360 | shall not be interpreted to, reduce, limit, restrict, or impose 361 | conditions on any use of the Licensed Material that could lawfully 362 | be made without permission under this Public License. 363 | 364 | b. 
To the extent possible, if any provision of this Public License is 365 | deemed unenforceable, it shall be automatically reformed to the 366 | minimum extent necessary to make it enforceable. If the provision 367 | cannot be reformed, it shall be severed from this Public License 368 | without affecting the enforceability of the remaining terms and 369 | conditions. 370 | 371 | c. No term or condition of this Public License will be waived and no 372 | failure to comply consented to unless expressly agreed to by the 373 | Licensor. 374 | 375 | d. Nothing in this Public License constitutes or may be interpreted 376 | as a limitation upon, or waiver of, any privileges and immunities 377 | that apply to the Licensor or You, including from the legal 378 | processes of any jurisdiction or authority. 379 | 380 | ======================================================================= 381 | 382 | Creative Commons is not a party to its public 383 | licenses. Notwithstanding, Creative Commons may elect to apply one of 384 | its public licenses to material it publishes and in those instances 385 | will be considered the “Licensor.” The text of the Creative Commons 386 | public licenses is dedicated to the public domain under the CC0 Public 387 | Domain Dedication. Except for the limited purpose of indicating that 388 | material is shared under a Creative Commons public license or as 389 | otherwise permitted by the Creative Commons policies published at 390 | creativecommons.org/policies, Creative Commons does not authorize the 391 | use of the trademark "Creative Commons" or any other trademark or logo 392 | of Creative Commons without its prior written consent including, 393 | without limitation, in connection with any unauthorized modifications 394 | to any of its public licenses or any other arrangements, 395 | understandings, or agreements concerning use of licensed material. For 396 | the avoidance of doubt, this paragraph does not form part of the 397 | public licenses. 398 | 399 | Creative Commons may be contacted at creativecommons.org. 400 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning General World Models in a Handful of Reward-Free Deployments 2 | 3 | [[Project Website](https://ycxuyingchen.github.io/cascade/)] 4 | 5 | Implementation of CASCADE in [Learning General World Models in a Handful of Reward-Free Deployments](https://arxiv.org/abs/2210.12719). 6 | 7 | CASCADE is a novel approach to self-supervised exploration in the reward-free, deployment-efficient setting. It seeks to learn a world model by collecting data with a population of agents, using an information-theoretic objective inspired by Bayesian Active Learning. CASCADE achieves this by explicitly maximizing the diversity of the trajectories sampled by the population through a novel cascading objective. 8 | 9 | ![image](assets/cascade_overview.jpg) 10 | 11 | 12 | ## Install Dependencies 13 | ``` 14 | pip3 install tensorflow==2.6.0 keras==2.6 tensorflow_probability ruamel.yaml 'gym[atari]' dm_control pycparser scikit-learn scipy gym_minigrid 15 | ``` 16 | 17 | ## Run 18 | Example: train a population of 10 CASCADE agents on Crafter, collecting 50k steps per deployment.
19 | ``` 20 | python main.py --task=crafter_noreward --xpid=test_cascade_walker --num_agents=10 --cascade_alpha=0.1 --train_every=50000 --envs=10 --offline_model_train_steps=5001 21 | ``` 22 | 23 | ## Reference 24 | If you find this work useful, please cite: 25 | ``` 26 | @article{xu2022cascade, 27 | title = {Learning General World Models in a Handful of Reward-Free Deployments}, 28 | doi = {10.48550/ARXIV.2210.12719}, 29 | author = {Xu, Yingchen and Parker-Holder, Jack and Pacchiano, Aldo and Ball, Philip J. and Rybkin, Oleh and Roberts, Stephen J. and Rocktäschel, Tim and Grefenstette, Edward}, 30 | publisher = {arXiv}, 31 | url = {https://arxiv.org/abs/2210.12719}, 32 | year = {2022}, 33 | } 34 | ``` 35 | 36 | ## License 37 | The majority of CASCADE is licensed under CC-BY-NC, however portions of the project are available under separate license terms: https://github.com/danijar/dreamerv2 is licensed under the MIT license. 38 | -------------------------------------------------------------------------------- /assets/CASCADE_motivation_with_description.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/cascade/96d870817590e3ec8b8df3ed6d82359509b64a13/assets/CASCADE_motivation_with_description.jpg -------------------------------------------------------------------------------- /assets/cascade_overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/cascade/96d870817590e3ec8b8df3ed6d82359509b64a13/assets/cascade_overview.jpg -------------------------------------------------------------------------------- /dreamerv2/api.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pathlib 4 | import sys 5 | import warnings 6 | 7 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 8 | logging.getLogger().setLevel('ERROR') 9 | warnings.filterwarnings('ignore', '.*box bound precision lowered.*') 10 | 11 | sys.path.append(str(pathlib.Path(__file__).parent)) 12 | sys.path.append(str(pathlib.Path(__file__).parent.parent)) 13 | 14 | import ruamel.yaml as yaml 15 | 16 | import common 17 | 18 | configs = yaml.safe_load( 19 | (pathlib.Path(__file__).parent / 'configs.yaml').read_text()) 20 | defaults = common.Config(configs.pop('defaults')) 21 | -------------------------------------------------------------------------------- /dreamerv2/common/__init__.py: -------------------------------------------------------------------------------- 1 | # General tools. 2 | from .config import * 3 | from .counter import * 4 | from .flags import * 5 | from .logger import * 6 | from .when import * 7 | from .eval import * 8 | from .cdmc import * 9 | 10 | # RL tools. 11 | from .other import * 12 | from .driver import * 13 | from .envs import * 14 | from .replay import * 15 | 16 | # TensorFlow tools. 17 | from .tfutils import * 18 | from .dists import * 19 | from .nets import * 20 | -------------------------------------------------------------------------------- /dreamerv2/common/cdmc/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
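# This package provides custom multi-task variants of the DeepMind Control
# walker and cheetah domains. The 'all' task in each domain returns a dict of
# per-task rewards rather than a single scalar; the task names are listed in
# DMC_TASK_IDS below.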
6 | 7 | from .walker import make_walker 8 | from .cheetah import make_cheetah 9 | 10 | def make_dmc_all(domain, task, 11 | task_kwargs=None, 12 | environment_kwargs=None, 13 | visualize_reward=False): 14 | 15 | if domain == 'walker': 16 | return make_walker(task, 17 | task_kwargs=task_kwargs, 18 | environment_kwargs=environment_kwargs, 19 | visualize_reward=visualize_reward) 20 | elif domain == 'cheetah': 21 | return make_cheetah(task, 22 | task_kwargs=task_kwargs, 23 | environment_kwargs=environment_kwargs, 24 | visualize_reward=visualize_reward) 25 | 26 | 27 | DMC_TASK_IDS = { 28 | 'dmc_walker_all': ['stand', 'walk', 'run', 'flip'], 29 | 'dmc_cheetah_all': ['run-fwd', 'run-bwd', 'flip-fwd', 'flip-bwd'], 30 | } -------------------------------------------------------------------------------- /dreamerv2/common/cdmc/cheetah.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import collections 8 | 9 | from dm_control import mujoco 10 | from dm_control.rl import control 11 | from dm_control.suite import base 12 | from dm_control.suite import common 13 | from dm_control.utils import containers 14 | from dm_control.utils import rewards 15 | 16 | 17 | # How long the simulation will run, in seconds. 18 | _DEFAULT_TIME_LIMIT = 10 19 | 20 | # Running speed above which reward is 1. 21 | _RUN_SPEED = 10 22 | _SPIN_SPEED = 5 23 | 24 | SUITE = containers.TaggedTasks() 25 | 26 | def make_cheetah(task, 27 | task_kwargs=None, 28 | environment_kwargs=None, 29 | visualize_reward=False): 30 | task_kwargs = task_kwargs or {} 31 | if environment_kwargs is not None: 32 | task_kwargs = task_kwargs.copy() 33 | task_kwargs['environment_kwargs'] = environment_kwargs 34 | env = SUITE[task](**task_kwargs) 35 | env.task.visualize_reward = visualize_reward 36 | return env 37 | 38 | def get_model_and_assets(): 39 | """Returns a tuple containing the model XML string and a dict of assets.""" 40 | return common.read_model('cheetah.xml'), common.ASSETS 41 | 42 | 43 | @SUITE.add('benchmarking') 44 | def run(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): 45 | """Returns the run task.""" 46 | physics = Physics.from_xml_string(*get_model_and_assets()) 47 | task = Cheetah(forward=True,random=random) 48 | environment_kwargs = environment_kwargs or {} 49 | return control.Environment(physics, task, time_limit=time_limit, 50 | **environment_kwargs) 51 | 52 | @SUITE.add('benchmarking') 53 | def run_back(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): 54 | """Returns the run task.""" 55 | physics = Physics.from_xml_string(*get_model_and_assets()) 56 | task = Cheetah(forward=False,random=random) 57 | environment_kwargs = environment_kwargs or {} 58 | return control.Environment(physics, task, time_limit=time_limit, 59 | **environment_kwargs) 60 | 61 | @SUITE.add('benchmarking') 62 | def flip_forward(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): 63 | """Returns the run task.""" 64 | physics = Physics.from_xml_string(*get_model_and_assets()) 65 | task = Cheetah(forward=False,flip=True,random=random) 66 | environment_kwargs = environment_kwargs or {} 67 | return control.Environment(physics, task, time_limit=time_limit, 68 | **environment_kwargs) 69 | 70 | @SUITE.add('benchmarking') 71 | def 
flip_backward(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): 72 | """Returns the run task.""" 73 | physics = Physics.from_xml_string(*get_model_and_assets()) 74 | task = Cheetah(forward=True,flip=True,random=random) 75 | environment_kwargs = environment_kwargs or {} 76 | return control.Environment(physics, task, time_limit=time_limit, 77 | **environment_kwargs) 78 | 79 | @SUITE.add('benchmarking') 80 | def all(time_limit=_DEFAULT_TIME_LIMIT, 81 | random=None, 82 | environment_kwargs=None): 83 | """Returns the Run task.""" 84 | physics = Physics.from_xml_string(*get_model_and_assets()) 85 | task = Cheetah(forward=True,flip=True,random=random,all=True) 86 | environment_kwargs = environment_kwargs or {} 87 | return control.Environment(physics, 88 | task, 89 | time_limit=time_limit, 90 | **environment_kwargs) 91 | 92 | class Physics(mujoco.Physics): 93 | """Physics simulation with additional features for the Cheetah domain.""" 94 | 95 | def speed(self): 96 | """Returns the horizontal speed of the Cheetah.""" 97 | return self.named.data.sensordata['torso_subtreelinvel'][0] 98 | 99 | def angmomentum(self): 100 | """Returns the angular momentum of torso of the Cheetah about Y axis.""" 101 | return self.named.data.subtree_angmom['torso'][1] 102 | 103 | 104 | class Cheetah(base.Task): 105 | """A `Task` to train a running Cheetah.""" 106 | 107 | def __init__(self, forward=True, flip=False, random=None, all=False): 108 | 109 | self._forward = 1 if forward else -1 110 | self._flip = flip 111 | self._all = all 112 | super(Cheetah, self).__init__(random=random) 113 | 114 | 115 | def initialize_episode(self, physics): 116 | """Sets the state of the environment at the start of each episode.""" 117 | # The indexing below assumes that all joints have a single DOF. 118 | assert physics.model.nq == physics.model.njnt 119 | is_limited = physics.model.jnt_limited == 1 120 | lower, upper = physics.model.jnt_range[is_limited].T 121 | physics.data.qpos[is_limited] = self.random.uniform(lower, upper) 122 | 123 | # Stabilize the model before the actual simulation. 124 | for _ in range(200): 125 | physics.step() 126 | 127 | physics.data.time = 0 128 | self._timeout_progress = 0 129 | super(Cheetah, self).initialize_episode(physics) 130 | 131 | def get_observation(self, physics): 132 | """Returns an observation of the state, ignoring horizontal position.""" 133 | obs = collections.OrderedDict() 134 | # Ignores horizontal position to maintain translational invariance. 
135 | obs['position'] = physics.data.qpos[1:].copy() 136 | obs['velocity'] = physics.velocity() 137 | return obs 138 | 139 | def get_reward(self, physics): 140 | """Returns a reward to the agent.""" 141 | if self._flip: 142 | reward = rewards.tolerance(self._forward*physics.angmomentum(), 143 | bounds=(_SPIN_SPEED, float('inf')), 144 | margin=_SPIN_SPEED, 145 | value_at_margin=0, 146 | sigmoid='linear') 147 | 148 | else: 149 | reward = rewards.tolerance(self._forward*physics.speed(), 150 | bounds=(_RUN_SPEED, float('inf')), 151 | margin=_RUN_SPEED, 152 | value_at_margin=0, 153 | sigmoid='linear') 154 | 155 | if self._all: 156 | flip_fwd = rewards.tolerance(1*physics.angmomentum(), 157 | bounds=(_SPIN_SPEED, float('inf')), 158 | margin=_SPIN_SPEED, 159 | value_at_margin=0, 160 | sigmoid='linear') 161 | 162 | flip_bwd = rewards.tolerance(-1*physics.angmomentum(), 163 | bounds=(_SPIN_SPEED, float('inf')), 164 | margin=_SPIN_SPEED, 165 | value_at_margin=0, 166 | sigmoid='linear') 167 | 168 | run_fwd = rewards.tolerance(1*physics.speed(), 169 | bounds=(_RUN_SPEED, float('inf')), 170 | margin=_RUN_SPEED, 171 | value_at_margin=0, 172 | sigmoid='linear') 173 | 174 | run_bwd = rewards.tolerance(-1*physics.speed(), 175 | bounds=(_RUN_SPEED, float('inf')), 176 | margin=_RUN_SPEED, 177 | value_at_margin=0, 178 | sigmoid='linear') 179 | 180 | reward = { 181 | 'run-fwd': run_fwd, 182 | 'run-bwd': run_bwd, 183 | 'flip-fwd': flip_fwd, 184 | 'flip-bwd': flip_bwd 185 | } 186 | 187 | return reward -------------------------------------------------------------------------------- /dreamerv2/common/cdmc/cheetah.xml: -------------------------------------------------------------------------------- 1 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /dreamerv2/common/cdmc/walker.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import collections 8 | import os 9 | 10 | from dm_control import mujoco 11 | from dm_control.rl import control 12 | from dm_control.suite import base 13 | from dm_control.suite import common 14 | from dm_control.suite.utils import randomizers 15 | from dm_control.utils import containers 16 | from dm_control.utils import rewards 17 | from dm_control.utils import io as resources 18 | from dm_control import suite 19 | 20 | _DEFAULT_TIME_LIMIT = 25 21 | _CONTROL_TIMESTEP = .025 22 | 23 | # Minimal height of torso over foot above which stand reward is 1. 24 | _STAND_HEIGHT = 1.2 25 | 26 | # Horizontal speeds (meters/second) above which move reward is 1. 
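# Note: _SPIN_SPEED below is a threshold on the torso's angular momentum used
# by the flip tasks, not a horizontal speed.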
27 | _WALK_SPEED = 1 28 | _RUN_SPEED = 8 29 | _SPIN_SPEED = 5 30 | 31 | SUITE = containers.TaggedTasks() 32 | 33 | def make_walker(task, 34 | task_kwargs=None, 35 | environment_kwargs=None, 36 | visualize_reward=False): 37 | task_kwargs = task_kwargs or {} 38 | if environment_kwargs is not None: 39 | task_kwargs = task_kwargs.copy() 40 | task_kwargs['environment_kwargs'] = environment_kwargs 41 | env = SUITE[task](**task_kwargs) 42 | env.task.visualize_reward = visualize_reward 43 | return env 44 | 45 | def get_model_and_assets(): 46 | """Returns a tuple containing the model XML string and a dict of assets.""" 47 | root_dir = os.path.dirname(os.path.dirname(__file__)) 48 | xml = resources.GetResource(os.path.join(root_dir, 'cdmc', 49 | 'walker.xml')) 50 | return xml, common.ASSETS 51 | 52 | @SUITE.add('benchmarking') 53 | def flip(time_limit=_DEFAULT_TIME_LIMIT, 54 | random=None, 55 | environment_kwargs=None): 56 | """Returns the Run task.""" 57 | physics = Physics.from_xml_string(*get_model_and_assets()) 58 | task = PlanarWalker(move_speed=_RUN_SPEED, 59 | forward=True, 60 | flip=True, 61 | random=random) 62 | environment_kwargs = environment_kwargs or {} 63 | return control.Environment(physics, 64 | task, 65 | time_limit=time_limit, 66 | control_timestep=_CONTROL_TIMESTEP, 67 | **environment_kwargs) 68 | 69 | @SUITE.add('benchmarking') 70 | def all(time_limit=_DEFAULT_TIME_LIMIT, 71 | random=None, 72 | environment_kwargs=None): 73 | """Returns the Run task.""" 74 | physics = Physics.from_xml_string(*get_model_and_assets()) 75 | task = PlanarWalker(move_speed=_RUN_SPEED, 76 | forward=True, 77 | flip=True, 78 | all=True, 79 | random=random) 80 | environment_kwargs = environment_kwargs or {} 81 | return control.Environment(physics, 82 | task, 83 | time_limit=time_limit, 84 | control_timestep=_CONTROL_TIMESTEP, 85 | **environment_kwargs) 86 | 87 | class Physics(mujoco.Physics): 88 | """Physics simulation with additional features for the Walker domain.""" 89 | def torso_upright(self): 90 | """Returns projection from z-axes of torso to the z-axes of world.""" 91 | return self.named.data.xmat['torso', 'zz'] 92 | 93 | def torso_height(self): 94 | """Returns the height of the torso.""" 95 | return self.named.data.xpos['torso', 'z'] 96 | 97 | def horizontal_velocity(self): 98 | """Returns the horizontal velocity of the center-of-mass.""" 99 | return self.named.data.sensordata['torso_subtreelinvel'][0] 100 | 101 | def orientations(self): 102 | """Returns planar orientations of all bodies.""" 103 | return self.named.data.xmat[1:, ['xx', 'xz']].ravel() 104 | 105 | def angmomentum(self): 106 | """Returns the angular momentum of torso of the Cheetah about Y axis.""" 107 | return self.named.data.subtree_angmom['torso'][1] 108 | 109 | 110 | class PlanarWalker(base.Task): 111 | """A planar walker task.""" 112 | def __init__(self, move_speed, forward=True, flip=False, random=None, all=False): 113 | """Initializes an instance of `PlanarWalker`. 114 | Args: 115 | move_speed: A float. If this value is zero, reward is given simply for 116 | standing up. Otherwise this specifies a target horizontal velocity for 117 | the walking task. 118 | random: Optional, either a `numpy.random.RandomState` instance, an 119 | integer seed for creating a new `RandomState`, or None to select a seed 120 | automatically (default). 
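forward: A bool. If True the target direction is forwards, otherwise backwards.
flip: A bool. If True the movement reward is based on the torso's angular
momentum (flipping) rather than on horizontal velocity.
all: A bool. If True `get_reward` returns a dict of 'stand', 'walk', 'run'
and 'flip' rewards instead of a single scalar.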
121 | """ 122 | self._move_speed = move_speed 123 | self._forward = 1 if forward else -1 124 | self._flip = flip 125 | self._all = all 126 | super(PlanarWalker, self).__init__(random=random) 127 | 128 | def initialize_episode(self, physics): 129 | """Sets the state of the environment at the start of each episode. 130 | In 'standing' mode, use initial orientation and small velocities. 131 | In 'random' mode, randomize joint angles and let fall to the floor. 132 | Args: 133 | physics: An instance of `Physics`. 134 | """ 135 | randomizers.randomize_limited_and_rotational_joints( 136 | physics, self.random) 137 | super(PlanarWalker, self).initialize_episode(physics) 138 | 139 | def get_observation(self, physics): 140 | """Returns an observation of body orientations, height and velocites.""" 141 | obs = collections.OrderedDict() 142 | obs['orientations'] = physics.orientations() 143 | obs['height'] = physics.torso_height() 144 | obs['velocity'] = physics.velocity() 145 | return obs 146 | 147 | def get_reward(self, physics): 148 | """Returns a reward to the agent.""" 149 | standing = rewards.tolerance(physics.torso_height(), 150 | bounds=(_STAND_HEIGHT, float('inf')), 151 | margin=_STAND_HEIGHT / 2) 152 | upright = (1 + physics.torso_upright()) / 2 153 | stand_reward = (3 * standing + upright) / 4 154 | 155 | if self._flip: 156 | move_reward = rewards.tolerance(self._forward * 157 | physics.angmomentum(), 158 | bounds=(_SPIN_SPEED, float('inf')), 159 | margin=_SPIN_SPEED, 160 | value_at_margin=0, 161 | sigmoid='linear') 162 | else: 163 | move_reward = rewards.tolerance( 164 | self._forward * physics.horizontal_velocity(), 165 | bounds=(self._move_speed, float('inf')), 166 | margin=self._move_speed / 2, 167 | value_at_margin=0.5, 168 | sigmoid='linear') 169 | 170 | if self._all: 171 | 172 | walk_reward = rewards.tolerance( 173 | self._forward * physics.horizontal_velocity(), 174 | bounds=(_WALK_SPEED, float('inf')), 175 | margin=_WALK_SPEED / 2, 176 | value_at_margin=0.5, 177 | sigmoid='linear') 178 | 179 | run_reward = rewards.tolerance( 180 | self._forward * physics.horizontal_velocity(), 181 | bounds=(_RUN_SPEED, float('inf')), 182 | margin=_RUN_SPEED / 2, 183 | value_at_margin=0.5, 184 | sigmoid='linear') 185 | 186 | flip_reward = rewards.tolerance(self._forward * 187 | physics.angmomentum(), 188 | bounds=(_SPIN_SPEED, float('inf')), 189 | margin=_SPIN_SPEED, 190 | value_at_margin=0, 191 | sigmoid='linear') 192 | 193 | reward_dict = { 194 | 'stand': stand_reward, 195 | 'walk': stand_reward * (5*walk_reward + 1) / 6, 196 | 'run': stand_reward * (5*run_reward + 1) / 6, 197 | 'flip': flip_reward 198 | } 199 | return reward_dict 200 | else: 201 | return stand_reward * (5 * move_reward + 1) / 6 -------------------------------------------------------------------------------- /dreamerv2/common/cdmc/walker.xml: -------------------------------------------------------------------------------- 1 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /dreamerv2/common/config.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pathlib 3 | import re 4 | 5 | 6 | class Config(dict): 7 | 8 | SEP = '.' 
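# Keys containing any character outside [A-Za-z0-9_.-] are treated as regex
# patterns by update(), so a single override can match several flat keys.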
9 | IS_PATTERN = re.compile(r'.*[^A-Za-z0-9_.-].*') 10 | 11 | def __init__(self, *args, **kwargs): 12 | mapping = dict(*args, **kwargs) 13 | mapping = self._flatten(mapping) 14 | mapping = self._ensure_keys(mapping) 15 | mapping = self._ensure_values(mapping) 16 | self._flat = mapping 17 | self._nested = self._nest(mapping) 18 | # Need to assign the values to the base class dictionary so that 19 | # conversion to dict does not lose the content. 20 | super().__init__(self._nested) 21 | 22 | @property 23 | def flat(self): 24 | return self._flat.copy() 25 | 26 | def save(self, filename): 27 | filename = pathlib.Path(filename) 28 | if filename.suffix == '.json': 29 | filename.write_text(json.dumps(dict(self))) 30 | elif filename.suffix in ('.yml', '.yaml'): 31 | import ruamel.yaml as yaml 32 | with filename.open('w') as f: 33 | yaml.safe_dump(dict(self), f) 34 | else: 35 | raise NotImplementedError(filename.suffix) 36 | 37 | @classmethod 38 | def load(cls, filename): 39 | filename = pathlib.Path(filename) 40 | if filename.suffix == '.json': 41 | return cls(json.loads(filename.read_text())) 42 | elif filename.suffix in ('.yml', '.yaml'): 43 | import ruamel.yaml as yaml 44 | return cls(yaml.safe_load(filename.read_text())) 45 | else: 46 | raise NotImplementedError(filename.suffix) 47 | 48 | def parse_flags(self, argv=None, known_only=False, help_exists=None): 49 | from . import flags 50 | return flags.Flags(self).parse(argv, known_only, help_exists) 51 | 52 | def __contains__(self, name): 53 | try: 54 | self[name] 55 | return True 56 | except KeyError: 57 | return False 58 | 59 | def __getattr__(self, name): 60 | if name.startswith('_'): 61 | return super().__getattr__(name) 62 | try: 63 | return self[name] 64 | except KeyError: 65 | raise AttributeError(name) 66 | 67 | def __getitem__(self, name): 68 | result = self._nested 69 | for part in name.split(self.SEP): 70 | result = result[part] 71 | if isinstance(result, dict): 72 | result = type(self)(result) 73 | return result 74 | 75 | def __setattr__(self, key, value): 76 | if key.startswith('_'): 77 | return super().__setattr__(key, value) 78 | message = f"Tried to set key '{key}' on immutable config. Use update()." 79 | raise AttributeError(message) 80 | 81 | def __setitem__(self, key, value): 82 | if key.startswith('_'): 83 | return super().__setitem__(key, value) 84 | message = f"Tried to set key '{key}' on immutable config. Use update()." 
85 | raise AttributeError(message) 86 | 87 | def __reduce__(self): 88 | return (type(self), (dict(self),)) 89 | 90 | def __str__(self): 91 | lines = ['\nConfig:'] 92 | keys, vals, typs = [], [], [] 93 | for key, val in self.flat.items(): 94 | keys.append(key + ':') 95 | vals.append(self._format_value(val)) 96 | typs.append(self._format_type(val)) 97 | max_key = max(len(k) for k in keys) if keys else 0 98 | max_val = max(len(v) for v in vals) if vals else 0 99 | for key, val, typ in zip(keys, vals, typs): 100 | key = key.ljust(max_key) 101 | val = val.ljust(max_val) 102 | lines.append(f'{key} {val} ({typ})') 103 | return '\n'.join(lines) 104 | 105 | def update(self, *args, **kwargs): 106 | result = self._flat.copy() 107 | inputs = self._flatten(dict(*args, **kwargs)) 108 | for key, new in inputs.items(): 109 | if self.IS_PATTERN.match(key): 110 | pattern = re.compile(key) 111 | keys = {k for k in result if pattern.match(k)} 112 | else: 113 | keys = [key] 114 | if not keys: 115 | raise KeyError(f'Unknown key or pattern {key}.') 116 | for key in keys: 117 | old = result[key] 118 | try: 119 | if isinstance(old, int) and isinstance(new, float): 120 | if float(int(new)) != new: 121 | message = f"Cannot convert fractional float {new} to int." 122 | raise ValueError(message) 123 | result[key] = type(old)(new) 124 | except (ValueError, TypeError): 125 | raise TypeError( 126 | f"Cannot convert '{new}' to type '{type(old).__name__}' " + 127 | f"of value '{old}' for key '{key}'.") 128 | return type(self)(result) 129 | 130 | def _flatten(self, mapping): 131 | result = {} 132 | for key, value in mapping.items(): 133 | if isinstance(value, dict): 134 | for k, v in self._flatten(value).items(): 135 | if self.IS_PATTERN.match(key) or self.IS_PATTERN.match(k): 136 | combined = f'{key}\\{self.SEP}{k}' 137 | else: 138 | combined = f'{key}{self.SEP}{k}' 139 | result[combined] = v 140 | else: 141 | result[key] = value 142 | return result 143 | 144 | def _nest(self, mapping): 145 | result = {} 146 | for key, value in mapping.items(): 147 | parts = key.split(self.SEP) 148 | node = result 149 | for part in parts[:-1]: 150 | if part not in node: 151 | node[part] = {} 152 | node = node[part] 153 | node[parts[-1]] = value 154 | return result 155 | 156 | def _ensure_keys(self, mapping): 157 | for key in mapping: 158 | assert not self.IS_PATTERN.match(key), key 159 | return mapping 160 | 161 | def _ensure_values(self, mapping): 162 | result = json.loads(json.dumps(mapping)) 163 | for key, value in result.items(): 164 | if isinstance(value, list): 165 | value = tuple(value) 166 | if isinstance(value, tuple): 167 | if len(value) == 0: 168 | message = 'Empty lists are disallowed because their type is unclear.' 169 | raise TypeError(message) 170 | if not isinstance(value[0], (str, float, int, bool)): 171 | message = 'Lists can only contain strings, floats, ints, bools' 172 | message += f' but not {type(value[0])}' 173 | raise TypeError(message) 174 | if not all(isinstance(x, type(value[0])) for x in value[1:]): 175 | message = 'Elements of a list must all be of the same type.' 
176 | raise TypeError(message) 177 | result[key] = value 178 | return result 179 | 180 | def _format_value(self, value): 181 | if isinstance(value, (list, tuple)): 182 | return '[' + ', '.join(self._format_value(x) for x in value) + ']' 183 | return str(value) 184 | 185 | def _format_type(self, value): 186 | if isinstance(value, (list, tuple)): 187 | assert len(value) > 0, value 188 | return self._format_type(value[0]) + 's' 189 | return str(type(value).__name__) 190 | -------------------------------------------------------------------------------- /dreamerv2/common/counter.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | 4 | @functools.total_ordering 5 | class Counter: 6 | 7 | def __init__(self, initial=0): 8 | self.value = initial 9 | 10 | def __int__(self): 11 | return int(self.value) 12 | 13 | def __eq__(self, other): 14 | return int(self) == other 15 | 16 | def __ne__(self, other): 17 | return int(self) != other 18 | 19 | def __lt__(self, other): 20 | return int(self) < other 21 | 22 | def __add__(self, other): 23 | return int(self) + other 24 | 25 | def increment(self, amount=1): 26 | self.value += amount 27 | -------------------------------------------------------------------------------- /dreamerv2/common/dists.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow_probability as tfp 3 | from tensorflow_probability import distributions as tfd 4 | 5 | 6 | # Patch to ignore seed to avoid synchronization across GPUs. 7 | _orig_random_categorical = tf.random.categorical 8 | def random_categorical(*args, **kwargs): 9 | kwargs['seed'] = None 10 | return _orig_random_categorical(*args, **kwargs) 11 | tf.random.categorical = random_categorical 12 | 13 | # Patch to ignore seed to avoid synchronization across GPUs. 14 | _orig_random_normal = tf.random.normal 15 | def random_normal(*args, **kwargs): 16 | kwargs['seed'] = None 17 | return _orig_random_normal(*args, **kwargs) 18 | tf.random.normal = random_normal 19 | 20 | 21 | class SampleDist: 22 | 23 | def __init__(self, dist, samples=100): 24 | self._dist = dist 25 | self._samples = samples 26 | 27 | @property 28 | def name(self): 29 | return 'SampleDist' 30 | 31 | def __getattr__(self, name): 32 | return getattr(self._dist, name) 33 | 34 | def mean(self): 35 | samples = self._dist.sample(self._samples) 36 | return samples.mean(0) 37 | 38 | def mode(self): 39 | sample = self._dist.sample(self._samples) 40 | logprob = self._dist.log_prob(sample) 41 | return tf.gather(sample, tf.argmax(logprob))[0] 42 | 43 | def entropy(self): 44 | sample = self._dist.sample(self._samples) 45 | logprob = self.log_prob(sample) 46 | return -logprob.mean(0) 47 | 48 | 49 | class OneHotDist(tfd.OneHotCategorical): 50 | 51 | def __init__(self, logits=None, probs=None, dtype=None): 52 | self._sample_dtype = dtype or tf.float32 53 | super().__init__(logits=logits, probs=probs) 54 | 55 | def mode(self): 56 | return tf.cast(super().mode(), self._sample_dtype) 57 | 58 | def sample(self, sample_shape=(), seed=None): 59 | # Straight through biased gradient estimator. 
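# The one-hot sample itself is not differentiable; adding
# probs - stop_gradient(probs) keeps the sampled value unchanged in the
# forward pass while letting gradients flow to the categorical probabilities.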
60 | sample = tf.cast(super().sample(sample_shape, seed), self._sample_dtype) 61 | probs = self._pad(super().probs_parameter(), sample.shape) 62 | sample += tf.cast(probs - tf.stop_gradient(probs), self._sample_dtype) 63 | return sample 64 | 65 | def _pad(self, tensor, shape): 66 | tensor = super().probs_parameter() 67 | while len(tensor.shape) < len(shape): 68 | tensor = tensor[None] 69 | return tensor 70 | 71 | 72 | class TruncNormalDist(tfd.TruncatedNormal): 73 | 74 | def __init__(self, loc, scale, low, high, clip=1e-6, mult=1): 75 | super().__init__(loc, scale, low, high) 76 | self._clip = clip 77 | self._mult = mult 78 | 79 | def sample(self, *args, **kwargs): 80 | event = super().sample(*args, **kwargs) 81 | if self._clip: 82 | clipped = tf.clip_by_value( 83 | event, self.low + self._clip, self.high - self._clip) 84 | event = event - tf.stop_gradient(event) + tf.stop_gradient(clipped) 85 | if self._mult: 86 | event *= self._mult 87 | return event 88 | 89 | 90 | class TanhBijector(tfp.bijectors.Bijector): 91 | 92 | def __init__(self, validate_args=False, name='tanh'): 93 | super().__init__( 94 | forward_min_event_ndims=0, 95 | validate_args=validate_args, 96 | name=name) 97 | 98 | def _forward(self, x): 99 | return tf.nn.tanh(x) 100 | 101 | def _inverse(self, y): 102 | dtype = y.dtype 103 | y = tf.cast(y, tf.float32) 104 | y = tf.where( 105 | tf.less_equal(tf.abs(y), 1.), 106 | tf.clip_by_value(y, -0.99999997, 0.99999997), y) 107 | y = tf.atanh(y) 108 | y = tf.cast(y, dtype) 109 | return y 110 | 111 | def _forward_log_det_jacobian(self, x): 112 | log2 = tf.math.log(tf.constant(2.0, dtype=x.dtype)) 113 | return 2.0 * (log2 - x - tf.nn.softplus(-2.0 * x)) 114 | -------------------------------------------------------------------------------- /dreamerv2/common/driver.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Driver: 5 | 6 | def __init__(self, envs, **kwargs): 7 | self._envs = envs 8 | self._kwargs = kwargs 9 | self._on_steps = [] 10 | self._on_resets = [] 11 | self._on_episodes = [] 12 | self._act_spaces = [env.act_space for env in envs] 13 | self.reset() 14 | 15 | def on_step(self, callback): 16 | self._on_steps.append(callback) 17 | 18 | def on_reset(self, callback): 19 | self._on_resets.append(callback) 20 | 21 | def on_episode(self, callback): 22 | self._on_episodes.append(callback) 23 | 24 | def reset(self): 25 | self._obs = [None] * len(self._envs) 26 | self._eps = [None] * len(self._envs) 27 | self._state = None 28 | 29 | def __call__(self, policy, steps=0, episodes=0, policy_idx=0, save_img=False): 30 | step, episode = 0, 0 31 | while step < steps or episode < episodes: 32 | obs = { 33 | i: self._envs[i].reset() 34 | for i, ob in enumerate(self._obs) if ob is None or ob['is_last']} 35 | for i, ob in obs.items(): 36 | self._obs[i] = ob() if callable(ob) else ob 37 | act = {k: np.zeros(v.shape) for k, v in self._act_spaces[i].items()} 38 | tran = {k: self._convert(v) for k, v in {**ob, **act}.items()} 39 | [fn(tran, worker=i, **self._kwargs) for fn in self._on_resets] 40 | self._eps[i] = [tran] 41 | obs = {k: np.stack([o[k] for o in self._obs]) for k in self._obs[0]} 42 | actions, self._state = policy(obs, self._state, **self._kwargs) 43 | actions = [ 44 | {k: np.array(actions[k][i]) for k in actions} 45 | for i in range(len(self._envs))] 46 | assert len(actions) == len(self._envs) 47 | # if episode == 0: 48 | should_save_img = save_img 49 | # else: 50 | # should_save_img = False 51 | obs = 
[e.step(a) for e, a in zip(self._envs, actions)] 52 | obs = [ob() if callable(ob) else ob for ob in obs] 53 | for i, (act, ob) in enumerate(zip(actions, obs)): 54 | tran = {k: self._convert(v) for k, v in {**ob, **act}.items()} 55 | [fn(tran, worker=i, **self._kwargs) for fn in self._on_steps] 56 | self._eps[i].append(tran) 57 | step += 1 58 | if ob['is_last']: 59 | ep = self._eps[i] 60 | ep = {k: self._convert([t[k] for t in ep]) for k in ep[0]} 61 | [fn(ep, **self._kwargs) for fn in self._on_episodes] 62 | episode += 1 63 | self._obs = obs 64 | 65 | def _convert(self, value): 66 | value = np.array(value) 67 | if np.issubdtype(value.dtype, np.floating): 68 | return value.astype(np.float32) 69 | elif np.issubdtype(value.dtype, np.signedinteger): 70 | return value.astype(np.int32) 71 | elif np.issubdtype(value.dtype, np.uint8): 72 | return value.astype(np.uint8) 73 | return value 74 | -------------------------------------------------------------------------------- /dreamerv2/common/envs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import atexit 8 | import os 9 | import sys 10 | import threading 11 | import traceback 12 | import cloudpickle 13 | import gym 14 | import numpy as np 15 | 16 | 17 | from .cdmc import make_dmc_all 18 | from .recorder import Recorder 19 | 20 | class GymWrapper: 21 | 22 | def __init__(self, env, obs_key='image', act_key='action'): 23 | self._env = env 24 | self._obs_is_dict = hasattr(self._env.observation_space, 'spaces') 25 | self._act_is_dict = hasattr(self._env.action_space, 'spaces') 26 | self._obs_key = obs_key 27 | self._act_key = act_key 28 | 29 | def __getattr__(self, name): 30 | if name.startswith('__'): 31 | raise AttributeError(name) 32 | try: 33 | return getattr(self._env, name) 34 | except AttributeError: 35 | raise ValueError(name) 36 | 37 | @property 38 | def obs_space(self): 39 | if self._obs_is_dict: 40 | spaces = self._env.observation_space.spaces.copy() 41 | else: 42 | spaces = {self._obs_key: self._env.observation_space} 43 | return { 44 | **spaces, 45 | 'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), 46 | 'is_first': gym.spaces.Box(0, 1, (), dtype=bool), 47 | 'is_last': gym.spaces.Box(0, 1, (), dtype=bool), 48 | 'is_terminal': gym.spaces.Box(0, 1, (), dtype=bool), 49 | } 50 | 51 | @property 52 | def act_space(self): 53 | if self._act_is_dict: 54 | return self._env.action_space.spaces.copy() 55 | else: 56 | return {self._act_key: self._env.action_space} 57 | 58 | def step(self, action): 59 | if not self._act_is_dict: 60 | action = action[self._act_key] 61 | obs, reward, done, _, info = self._env.step(action) 62 | if not self._obs_is_dict: 63 | obs = {self._obs_key: obs} 64 | obs['reward'] = float(reward) 65 | obs['is_first'] = False 66 | obs['is_last'] = done 67 | obs['is_terminal'] = info.get('is_terminal', done) 68 | return obs 69 | 70 | def reset(self): 71 | obs = self._env.reset()[0] 72 | if not self._obs_is_dict: 73 | obs = {self._obs_key: obs} 74 | # print("obs 1:", obs) 75 | obs['reward'] = 0.0 76 | obs['is_first'] = True 77 | obs['is_last'] = False 78 | obs['is_terminal'] = False 79 | return obs 80 | 81 | def make_minigrid_env(task, fix_seed, seed): 82 | import gym_minigrid 83 | env = gym.make("MiniGrid-"+task) 84 | env = 
gym_minigrid.wrappers.RGBImgPartialObsWrapper(env) 85 | 86 | if fix_seed: 87 | env = gym_minigrid.wrappers.ReseedWrapper(env, seeds=[seed]) 88 | 89 | env = GymWrapper(env) 90 | env = ResizeImage(env) 91 | if hasattr(env.act_space['action'], 'n'): 92 | env = OneHotAction(env) 93 | else: 94 | env = NormalizeAction(env) 95 | return env 96 | 97 | 98 | class DMC: 99 | 100 | def __init__(self, name, action_repeat=1, size=(64, 64), camera=None, save_path=None): 101 | os.environ['MUJOCO_GL'] = 'egl' 102 | domain, task = name.split('_', 1) 103 | if task == 'all': 104 | self._dict_reward = True 105 | else: 106 | self._dict_reward = False 107 | if domain == 'cup': # Only domain with multiple words. 108 | domain = 'ball_in_cup' 109 | if domain == 'manip': 110 | from dm_control import manipulation 111 | self._env = manipulation.load(task + '_vision') 112 | elif domain == 'locom': 113 | from dm_control.locomotion.examples import basic_rodent_2020 114 | self._env = getattr(basic_rodent_2020, task)() 115 | elif task == 'all': 116 | import time 117 | seed = int(str(int((time.time()*10000)))[-6:]) # random seed generator 118 | self._env = make_dmc_all(domain, 119 | task, 120 | task_kwargs=dict(random=seed), 121 | environment_kwargs=dict(flat_observation=True), 122 | visualize_reward=False) 123 | else: 124 | from dm_control import suite 125 | self._env = suite.load(domain, task) 126 | self._action_repeat = action_repeat 127 | self._size = size 128 | if camera in (-1, None): 129 | camera = dict( 130 | quadruped_walk=2, quadruped_run=2, quadruped_escape=2, 131 | quadruped_fetch=2, locom_rodent_maze_forage=1, 132 | locom_rodent_two_touch=1, 133 | ).get(name, 0) 134 | self._camera = camera 135 | self._ignored_keys = [] 136 | save_path.mkdir(parents=True, exist_ok=True) 137 | self.save_path = save_path 138 | for key, value in self._env.observation_spec().items(): 139 | if value.shape == (0,): 140 | print(f"Ignoring empty observation key '{key}'.") 141 | self._ignored_keys.append(key) 142 | 143 | @property 144 | def obs_space(self): 145 | spaces = { 146 | 'image': gym.spaces.Box(0, 255, self._size + (3,), dtype=np.uint8), 147 | 'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), 148 | 'is_first': gym.spaces.Box(0, 1, (), dtype=bool), 149 | 'is_last': gym.spaces.Box(0, 1, (), dtype=bool), 150 | 'is_terminal': gym.spaces.Box(0, 1, (), dtype=bool), 151 | } 152 | for key, value in self._env.observation_spec().items(): 153 | if key in self._ignored_keys: 154 | continue 155 | if value.dtype == np.float64: 156 | spaces[key] = gym.spaces.Box(-np.inf, np.inf, value.shape, np.float32) 157 | elif value.dtype == np.uint8: 158 | spaces[key] = gym.spaces.Box(0, 255, value.shape, np.uint8) 159 | else: 160 | raise NotImplementedError(value.dtype) 161 | return spaces 162 | 163 | @property 164 | def act_space(self): 165 | spec = self._env.action_spec() 166 | action = gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32) 167 | return {'action': action} 168 | 169 | def step(self, action): 170 | assert np.isfinite(action['action']).all(), action['action'] 171 | if self._dict_reward: 172 | reward = [] 173 | else: 174 | reward = 0.0 175 | for _ in range(self._action_repeat): 176 | time_step = self._env.step(action['action']) 177 | if self._dict_reward: 178 | curr_reward = [] 179 | for key, val in time_step.reward.items(): 180 | curr_reward.append(val) 181 | if len(reward) == 0: 182 | reward = curr_reward 183 | else: 184 | reward = [sum(x) for x in zip(reward, curr_reward)] 185 | else: 186 | reward += time_step.reward 
or 0.0 187 | if time_step.last(): 188 | break 189 | assert time_step.discount in (0, 1) 190 | image = self._env.physics.render(*self._size, camera_id=self._camera) 191 | obs = { 192 | 'reward': reward, 193 | 'is_first': False, 194 | 'is_last': time_step.last(), 195 | 'is_terminal': time_step.discount == 0, 196 | 'image': image, 197 | } 198 | obs.update({ 199 | k: v for k, v in dict(time_step.observation).items() 200 | if k not in self._ignored_keys}) 201 | return obs 202 | 203 | def reset(self): 204 | time_step = self._env.reset() 205 | obs = { 206 | 'reward': 0.0, 207 | 'is_first': True, 208 | 'is_last': False, 209 | 'is_terminal': False, 210 | 'image': self._env.physics.render(*self._size, camera_id=self._camera), 211 | } 212 | obs.update({ 213 | k: v for k, v in dict(time_step.observation).items() 214 | if k not in self._ignored_keys}) 215 | return obs 216 | 217 | 218 | class Atari: 219 | 220 | LOCK = threading.Lock() 221 | 222 | def __init__( 223 | self, name, action_repeat=4, size=(84, 84), grayscale=True, noops=30, 224 | life_done=False, sticky=True, all_actions=False, save_path=None): 225 | assert size[0] == size[1] 226 | import gym.wrappers 227 | import gym.envs.atari 228 | if name == 'james_bond': 229 | name = 'jamesbond' 230 | with self.LOCK: 231 | env = gym.envs.atari.AtariEnv( 232 | game=name, obs_type='rgb', frameskip=1, 233 | repeat_action_probability=0.25 if sticky else 0.0, 234 | full_action_space=all_actions) 235 | # Avoid unnecessary rendering in inner env. 236 | env._get_obs = lambda: None 237 | # Tell wrapper that the inner env has no action repeat. 238 | env.spec = gym.envs.registration.EnvSpec('NoFrameskip-v0') 239 | self._env = gym.wrappers.AtariPreprocessing( 240 | env, noops, action_repeat, size[0], life_done, grayscale) 241 | save_path.mkdir(parents=True, exist_ok=True) 242 | self.save_path = save_path 243 | self._size = size 244 | self._grayscale = grayscale 245 | 246 | @property 247 | def obs_space(self): 248 | shape = self._size + (1 if self._grayscale else 3,) 249 | return { 250 | 'image': gym.spaces.Box(0, 255, shape, np.uint8), 251 | 'ram': gym.spaces.Box(0, 255, (128,), np.uint8), 252 | 'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), 253 | 'is_first': gym.spaces.Box(0, 1, (), dtype=bool), 254 | 'is_last': gym.spaces.Box(0, 1, (), dtype=bool), 255 | 'is_terminal': gym.spaces.Box(0, 1, (), dtype=bool), 256 | } 257 | 258 | @property 259 | def act_space(self): 260 | return {'action': self._env.action_space} 261 | 262 | def step(self, action): 263 | image, reward, done, info = self._env.step(action['action']) 264 | if self._grayscale: 265 | image = image[..., None] 266 | return { 267 | 'image': image, 268 | 'ram': self._env.env.ale.getRAM(), #if not self.record_video else self._env._env.env.ale.getRAM(), 269 | 'reward': reward, 270 | 'is_first': False, 271 | 'is_last': done, 272 | 'is_terminal': done, 273 | } 274 | 275 | def reset(self): 276 | with self.LOCK: 277 | image = self._env.reset() 278 | if self._grayscale: 279 | image = image[..., None] 280 | return { 281 | 'image': image, 282 | 'ram': self._env.env.ale.getRAM(), #if not self.record_video else self._env._env.env.ale.getRAM(), 283 | # 'ram': self._env.env._get_ram() if not self.record_video else self._env._env.env._get_ram(), 284 | 'reward': 0.0, 285 | 'is_first': True, 286 | 'is_last': False, 287 | 'is_terminal': False, 288 | } 289 | 290 | def close(self): 291 | return self._env.close() 292 | 293 | 294 | class Crafter: 295 | 296 | def __init__(self, outdir=None, reward=True, seed=None, 
save_path=None): 297 | import crafter 298 | self._env = crafter.Env(reward=reward, seed=seed) 299 | self._env = Recorder( 300 | self._env, outdir, 301 | save_stats=True, 302 | save_video=False, 303 | save_episode=False, 304 | ) 305 | if save_path: 306 | save_path.mkdir(parents=True, exist_ok=True) 307 | self.save_path = save_path 308 | self._achievements = crafter.constants.achievements.copy() 309 | 310 | @property 311 | def obs_space(self): 312 | spaces = { 313 | 'image': self._env.observation_space, 314 | 'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), 315 | 'is_first': gym.spaces.Box(0, 1, (), dtype=bool), 316 | 'is_last': gym.spaces.Box(0, 1, (), dtype=bool), 317 | 'is_terminal': gym.spaces.Box(0, 1, (), dtype=bool), 318 | 'log_reward': gym.spaces.Box(-np.inf, np.inf, (), np.float32), 319 | } 320 | spaces.update({ 321 | f'log_achievement_{k}': gym.spaces.Box(0, 2 ** 31 - 1, (), np.int32) 322 | for k in self._achievements}) 323 | return spaces 324 | 325 | @property 326 | def act_space(self): 327 | return {'action': self._env.action_space} 328 | 329 | def step(self, action): 330 | image, reward, done, info = self._env.step(action['action']) 331 | obs = { 332 | 'image': image, 333 | 'reward': reward, 334 | 'is_first': False, 335 | 'is_last': done, 336 | 'is_terminal': info['discount'] == 0, 337 | 'log_reward': info['reward'], 338 | } 339 | obs.update({ 340 | f'log_achievement_{k}': v 341 | for k, v in info['achievements'].items()}) 342 | return obs 343 | 344 | def reset(self): 345 | obs = { 346 | 'image': self._env.reset(), 347 | 'reward': 0.0, 348 | 'is_first': True, 349 | 'is_last': False, 350 | 'is_terminal': False, 351 | 'log_reward': 0.0, 352 | } 353 | obs.update({ 354 | f'log_achievement_{k}': 0 355 | for k in self._achievements}) 356 | return obs 357 | 358 | 359 | class Dummy: 360 | 361 | def __init__(self): 362 | pass 363 | 364 | @property 365 | def obs_space(self): 366 | return { 367 | 'image': gym.spaces.Box(0, 255, (64, 64, 3), dtype=np.uint8), 368 | 'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), 369 | 'is_first': gym.spaces.Box(0, 1, (), dtype=bool), 370 | 'is_last': gym.spaces.Box(0, 1, (), dtype=bool), 371 | 'is_terminal': gym.spaces.Box(0, 1, (), dtype=bool), 372 | } 373 | 374 | @property 375 | def act_space(self): 376 | return {'action': gym.spaces.Box(-1, 1, (6,), dtype=np.float32)} 377 | 378 | def step(self, action): 379 | return { 380 | 'image': np.zeros((64, 64, 3)), 381 | 'reward': 0.0, 382 | 'is_first': False, 383 | 'is_last': False, 384 | 'is_terminal': False, 385 | } 386 | 387 | def reset(self): 388 | return { 389 | 'image': np.zeros((64, 64, 3)), 390 | 'reward': 0.0, 391 | 'is_first': True, 392 | 'is_last': False, 393 | 'is_terminal': False, 394 | } 395 | 396 | 397 | class TimeLimit: 398 | 399 | def __init__(self, env, duration): 400 | self._env = env 401 | self._duration = duration 402 | self._step = None 403 | 404 | def __getattr__(self, name): 405 | if name.startswith('__'): 406 | raise AttributeError(name) 407 | try: 408 | return getattr(self._env, name) 409 | except AttributeError: 410 | raise ValueError(name) 411 | 412 | def step(self, action): 413 | assert self._step is not None, 'Must reset environment.' 
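    # `_step` is set back to None once the duration is reached, so calling
    # step() on a finished episode without an intervening reset() fails here.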
414 | obs = self._env.step(action) 415 | self._step += 1 416 | if self._duration and self._step >= self._duration: 417 | obs['is_last'] = True 418 | self._step = None 419 | return obs 420 | 421 | def reset(self): 422 | self._step = 0 423 | return self._env.reset() 424 | 425 | 426 | class NormalizeAction: 427 | 428 | def __init__(self, env, key='action'): 429 | self._env = env 430 | self._key = key 431 | space = env.act_space[key] 432 | self._mask = np.isfinite(space.low) & np.isfinite(space.high) 433 | self._low = np.where(self._mask, space.low, -1) 434 | self._high = np.where(self._mask, space.high, 1) 435 | 436 | def __getattr__(self, name): 437 | if name.startswith('__'): 438 | raise AttributeError(name) 439 | try: 440 | return getattr(self._env, name) 441 | except AttributeError: 442 | raise ValueError(name) 443 | 444 | @property 445 | def act_space(self): 446 | low = np.where(self._mask, -np.ones_like(self._low), self._low) 447 | high = np.where(self._mask, np.ones_like(self._low), self._high) 448 | space = gym.spaces.Box(low, high, dtype=np.float32) 449 | return {**self._env.act_space, self._key: space} 450 | 451 | def step(self, action): 452 | orig = (action[self._key] + 1) / 2 * (self._high - self._low) + self._low 453 | orig = np.where(self._mask, orig, action[self._key]) 454 | return self._env.step({**action, self._key: orig}) 455 | 456 | 457 | class OneHotAction: 458 | 459 | def __init__(self, env, key='action'): 460 | 461 | assert hasattr(env.act_space[key], 'n') 462 | self._env = env 463 | self._key = key 464 | self._random = np.random.RandomState() 465 | 466 | def __getattr__(self, name): 467 | if name.startswith('__'): 468 | raise AttributeError(name) 469 | try: 470 | return getattr(self._env, name) 471 | except AttributeError: 472 | raise ValueError(name) 473 | 474 | @property 475 | def act_space(self): 476 | shape = (self._env.act_space[self._key].n,) 477 | space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32) 478 | space.sample = self._sample_action 479 | space.n = shape[0] 480 | return {**self._env.act_space, self._key: space} 481 | 482 | def step(self, action): 483 | index = np.argmax(action[self._key]).astype(int) 484 | reference = np.zeros_like(action[self._key]) 485 | reference[index] = 1 486 | if not np.allclose(reference, action[self._key]): 487 | raise ValueError(f'Invalid one-hot action:\n{action}') 488 | return self._env.step({**action, self._key: index}) 489 | 490 | def reset(self): 491 | return self._env.reset() 492 | 493 | def _sample_action(self): 494 | actions = self._env.act_space.n 495 | index = self._random.randint(0, actions) 496 | reference = np.zeros(actions, dtype=np.float32) 497 | reference[index] = 1.0 498 | return reference 499 | 500 | 501 | class ResizeImage: 502 | 503 | def __init__(self, env, size=(64, 64)): 504 | self._env = env 505 | self._size = size 506 | self._keys = [ 507 | k for k, v in env.obs_space.items() 508 | if v.shape and len(v.shape) > 1 and v.shape[:2] != size] 509 | print(f'Resizing keys {",".join(self._keys)} to {self._size}.') 510 | if self._keys: 511 | from PIL import Image 512 | self._Image = Image 513 | 514 | def __getattr__(self, name): 515 | if name.startswith('__'): 516 | raise AttributeError(name) 517 | try: 518 | return getattr(self._env, name) 519 | except AttributeError: 520 | raise ValueError(name) 521 | 522 | @property 523 | def obs_space(self): 524 | spaces = self._env.obs_space 525 | new_space = {} 526 | for key in self._keys: 527 | shape = self._size + spaces[key].shape[2:] 528 | new_space[key] = 
gym.spaces.Box(0, 255, shape, np.uint8) 529 | return new_space 530 | 531 | def step(self, action): 532 | obs = self._env.step(action) 533 | for key in self._keys: 534 | obs[key] = self._resize(obs[key]) 535 | return obs 536 | 537 | def reset(self): 538 | obs = self._env.reset() 539 | for key in self._keys: 540 | obs[key] = self._resize(obs[key]) 541 | return obs 542 | 543 | def _resize(self, image): 544 | image = self._Image.fromarray(image) 545 | image = image.resize(self._size, self._Image.NEAREST) 546 | image = np.array(image) 547 | return image 548 | 549 | 550 | class RenderImage: 551 | 552 | def __init__(self, env, key='image'): 553 | self._env = env 554 | self._key = key 555 | self._shape = self._env.render().shape 556 | 557 | def __getattr__(self, name): 558 | if name.startswith('__'): 559 | raise AttributeError(name) 560 | try: 561 | return getattr(self._env, name) 562 | except AttributeError: 563 | raise ValueError(name) 564 | 565 | @property 566 | def obs_space(self): 567 | spaces = self._env.obs_space 568 | spaces[self._key] = gym.spaces.Box(0, 255, self._shape, np.uint8) 569 | return spaces 570 | 571 | def step(self, action): 572 | obs = self._env.step(action) 573 | obs[self._key] = self._env.render('rgb_array') 574 | return obs 575 | 576 | def reset(self): 577 | obs = self._env.reset() 578 | obs[self._key] = self._env.render('rgb_array') 579 | return obs 580 | 581 | 582 | class Async: 583 | 584 | # Message types for communication via the pipe. 585 | _ACCESS = 1 586 | _CALL = 2 587 | _RESULT = 3 588 | _CLOSE = 4 589 | _EXCEPTION = 5 590 | 591 | def __init__(self, constructor, strategy='thread'): 592 | self._pickled_ctor = cloudpickle.dumps(constructor) 593 | if strategy == 'process': 594 | import multiprocessing as mp 595 | context = mp.get_context('spawn') 596 | elif strategy == 'thread': 597 | import multiprocessing.dummy as context 598 | else: 599 | raise NotImplementedError(strategy) 600 | self._strategy = strategy 601 | self._conn, conn = context.Pipe() 602 | self._process = context.Process(target=self._worker, args=(conn,)) 603 | atexit.register(self.close) 604 | self._process.start() 605 | self._receive() # Ready. 606 | self._obs_space = None 607 | self._act_space = None 608 | 609 | def access(self, name): 610 | self._conn.send((self._ACCESS, name)) 611 | return self._receive 612 | 613 | def call(self, name, *args, **kwargs): 614 | payload = name, args, kwargs 615 | self._conn.send((self._CALL, payload)) 616 | return self._receive 617 | 618 | def close(self): 619 | try: 620 | self._conn.send((self._CLOSE, None)) 621 | self._conn.close() 622 | except IOError: 623 | pass # The connection was already closed. 
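    # Give the worker process/thread a moment to exit; join(5) waits at most
    # five seconds before giving up.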
624 | self._process.join(5) 625 | 626 | @property 627 | def obs_space(self): 628 | if not self._obs_space: 629 | self._obs_space = self.access('obs_space')() 630 | return self._obs_space 631 | 632 | @property 633 | def act_space(self): 634 | if not self._act_space: 635 | self._act_space = self.access('act_space')() 636 | return self._act_space 637 | 638 | def step(self, action, blocking=False): 639 | promise = self.call('step', action) 640 | if blocking: 641 | return promise() 642 | else: 643 | return promise 644 | 645 | def reset(self, blocking=False): 646 | promise = self.call('reset') 647 | if blocking: 648 | return promise() 649 | else: 650 | return promise 651 | 652 | def _receive(self): 653 | try: 654 | message, payload = self._conn.recv() 655 | except (OSError, EOFError): 656 | raise RuntimeError('Lost connection to environment worker.') 657 | # Re-raise exceptions in the main process. 658 | if message == self._EXCEPTION: 659 | stacktrace = payload 660 | raise Exception(stacktrace) 661 | if message == self._RESULT: 662 | return payload 663 | raise KeyError('Received message of unexpected type {}'.format(message)) 664 | 665 | def _worker(self, conn): 666 | try: 667 | ctor = cloudpickle.loads(self._pickled_ctor) 668 | env = ctor() 669 | conn.send((self._RESULT, None)) # Ready. 670 | while True: 671 | try: 672 | # Only block for short times to have keyboard exceptions be raised. 673 | if not conn.poll(0.1): 674 | continue 675 | message, payload = conn.recv() 676 | except (EOFError, KeyboardInterrupt): 677 | break 678 | if message == self._ACCESS: 679 | name = payload 680 | result = getattr(env, name) 681 | conn.send((self._RESULT, result)) 682 | continue 683 | if message == self._CALL: 684 | name, args, kwargs = payload 685 | result = getattr(env, name)(*args, **kwargs) 686 | conn.send((self._RESULT, result)) 687 | continue 688 | if message == self._CLOSE: 689 | break 690 | raise KeyError('Received message of unknown type {}'.format(message)) 691 | except Exception: 692 | stacktrace = ''.join(traceback.format_exception(*sys.exc_info())) 693 | print('Error in environment process: {}'.format(stacktrace)) 694 | conn.send((self._EXCEPTION, stacktrace)) 695 | finally: 696 | try: 697 | conn.close() 698 | except IOError: 699 | pass # The connection was already closed. 700 | -------------------------------------------------------------------------------- /dreamerv2/common/eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from collections import defaultdict 8 | from .cdmc import DMC_TASK_IDS 9 | import numpy as np 10 | from scipy.stats import gmean 11 | 12 | def get_stats_at_idx(driver, task, idx): 13 | """ 14 | Get the train / eval stats from driver from the idx env. 
15 | """ 16 | prefix = "eval_" 17 | eps = driver._eps[idx] 18 | eval_data = defaultdict(list) 19 | if task == 'crafter_noreward': 20 | for ep in eps: 21 | for key, val in ep.items(): 22 | if 'log_achievement_' in key: 23 | eval_data[prefix + 'rew_'+key.split('log_achievement_')[1]].append(val.item()) 24 | eval_data[prefix + 'sr_'+key.split('log_achievement_')[1]].append(1 if val.item() > 0 else 0) 25 | eval_data['reward'].append(ep['log_reward'].item()) 26 | eval_data = {key: np.mean(val) for key, val in eval_data.items()} 27 | eval_data[prefix + 'crafter_score'] = gmean([val for key, val in eval_data.items() if 'eval_sr' in key]) 28 | elif task in DMC_TASK_IDS: 29 | rewards = [ep['reward'] for ep in eps[1:]] 30 | for idx, goal in enumerate(DMC_TASK_IDS[task]): 31 | eval_data[prefix + 'reward_' + goal] = np.sum([r[idx] for r in rewards]) 32 | else: 33 | eval_data[prefix + 'reward'] = np.sum([ep['reward'] for ep in eps]) 34 | return eval_data 35 | 36 | def get_stats(driver, task): 37 | per_env_data = defaultdict(list) 38 | num_envs = len(driver._envs) 39 | for i in range(num_envs): 40 | stat = get_stats_at_idx(driver, task, i) 41 | for k, v in stat.items(): 42 | per_env_data[k].append(v) 43 | data = {} 44 | for k, v in per_env_data.items(): 45 | data[k] = np.mean(v) 46 | return data 47 | 48 | def eval(driver, config, expl_policies, logdir): 49 | ## reward for the exploration agents 50 | mets = {} 51 | mean_pop = {} 52 | for idx in range(config.num_agents): 53 | policy = expl_policies[idx] 54 | driver(policy, episodes=config.eval_eps, policy_idx=idx) 55 | data = get_stats(driver, task=config.task) 56 | if idx == 0: 57 | for key, val in data.items(): 58 | mean_pop[key] = np.mean(val) 59 | else: 60 | for key,val in data.items(): 61 | mean_pop[key] += np.mean(val) 62 | mets.update({key: np.mean(val) for key, val in mean_pop.items()}) 63 | return mets -------------------------------------------------------------------------------- /dreamerv2/common/flags.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | 4 | 5 | class Flags: 6 | 7 | def __init__(self, *args, **kwargs): 8 | from .config import Config 9 | self._config = Config(*args, **kwargs) 10 | 11 | def parse(self, argv=None, known_only=False, help_exists=None): 12 | if help_exists is None: 13 | help_exists = not known_only 14 | if argv is None: 15 | argv = sys.argv[1:] 16 | if '--help' in argv: 17 | print('\nHelp:') 18 | lines = str(self._config).split('\n')[2:] 19 | print('\n'.join('--' + re.sub(r'[:,\[\]]', '', x) for x in lines)) 20 | help_exists and sys.exit() 21 | parsed = {} 22 | remaining = [] 23 | key = None 24 | vals = None 25 | for arg in argv: 26 | if arg.startswith('--'): 27 | if key: 28 | self._submit_entry(key, vals, parsed, remaining) 29 | if '=' in arg: 30 | key, val = arg.split('=', 1) 31 | vals = [val] 32 | else: 33 | key, vals = arg, [] 34 | else: 35 | if key: 36 | vals.append(arg) 37 | else: 38 | remaining.append(arg) 39 | self._submit_entry(key, vals, parsed, remaining) 40 | parsed = self._config.update(parsed) 41 | if known_only: 42 | return parsed, remaining 43 | else: 44 | for flag in remaining: 45 | if flag.startswith('--'): 46 | raise ValueError(f"Flag '{flag}' did not match any config keys.") 47 | assert not remaining, remaining 48 | return parsed 49 | 50 | def _submit_entry(self, key, vals, parsed, remaining): 51 | if not key and not vals: 52 | return 53 | if not key: 54 | vals = ', '.join(f"'{x}'" for x in vals) 55 | raise ValueError(f"Values 
{vals} were not preceeded by any flag.") 56 | name = key[len('--'):] 57 | if '=' in name: 58 | remaining.extend([key] + vals) 59 | return 60 | if self._config.IS_PATTERN.match(name): 61 | pattern = re.compile(name) 62 | keys = {k for k in self._config.flat if pattern.match(k)} 63 | elif name in self._config: 64 | keys = [name] 65 | else: 66 | keys = [] 67 | if not keys: 68 | remaining.extend([key] + vals) 69 | return 70 | if not vals: 71 | raise ValueError(f"Flag '{key}' was not followed by any values.") 72 | for key in keys: 73 | parsed[key] = self._parse_flag_value(self._config[key], vals, key) 74 | 75 | def _parse_flag_value(self, default, value, key): 76 | value = value if isinstance(value, (tuple, list)) else (value,) 77 | if isinstance(default, (tuple, list)): 78 | if len(value) == 1 and ',' in value[0]: 79 | value = value[0].split(',') 80 | return tuple(self._parse_flag_value(default[0], [x], key) for x in value) 81 | assert len(value) == 1, value 82 | value = str(value[0]) 83 | if default is None: 84 | return value 85 | if isinstance(default, bool): 86 | try: 87 | return bool(['False', 'True'].index(value)) 88 | except ValueError: 89 | message = f"Expected bool but got '{value}' for key '{key}'." 90 | raise TypeError(message) 91 | if isinstance(default, int): 92 | value = float(value) # Allow scientific notation for integers. 93 | if float(int(value)) != value: 94 | message = f"Expected int but got float '{value}' for key '{key}'." 95 | raise TypeError(message) 96 | return int(value) 97 | return type(default)(value) 98 | -------------------------------------------------------------------------------- /dreamerv2/common/logger.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pathlib 4 | import time 5 | 6 | import numpy as np 7 | 8 | 9 | class Logger: 10 | 11 | def __init__(self, step, outputs, multiplier=1): 12 | self._step = step 13 | self._outputs = outputs 14 | self._multiplier = multiplier 15 | self._last_step = None 16 | self._last_time = None 17 | self._metrics = [] 18 | 19 | def add(self, mapping, prefix=None): 20 | step = int(self._step) * self._multiplier 21 | for name, value in dict(mapping).items(): 22 | name = f'{prefix}_{name}' if prefix else name 23 | value = np.array(value) 24 | if len(value.shape) not in (0, 2, 3, 4): 25 | raise ValueError( 26 | f"Shape {value.shape} for name '{name}' cannot be " 27 | "interpreted as scalar, image, or video.") 28 | self._metrics.append((step, name, value)) 29 | 30 | def scalar(self, name, value): 31 | self.add({name: value}) 32 | 33 | def image(self, name, value): 34 | self.add({name: value}) 35 | 36 | def video(self, name, value): 37 | self.add({name: value}) 38 | 39 | def write(self, fps=False): 40 | fps and self.scalar('fps', self._compute_fps()) 41 | if not self._metrics: 42 | return 43 | for output in self._outputs: 44 | output(self._metrics) 45 | self._metrics.clear() 46 | 47 | def _compute_fps(self): 48 | step = int(self._step) * self._multiplier 49 | if self._last_step is None: 50 | self._last_time = time.time() 51 | self._last_step = step 52 | return 0 53 | steps = step - self._last_step 54 | duration = time.time() - self._last_time 55 | self._last_time += duration 56 | self._last_step = step 57 | return steps / duration 58 | 59 | 60 | class TerminalOutput: 61 | 62 | def __call__(self, summaries): 63 | step = max(s for s, _, _, in summaries) 64 | scalars = {k: float(v) for _, k, v in summaries if len(v.shape) == 0} 65 | formatted = {k: 
self._format_value(v) for k, v in scalars.items()} 66 | print(f'[{step}]', ' / '.join(f'{k} {v}' for k, v in formatted.items())) 67 | 68 | def _format_value(self, value): 69 | if value == 0: 70 | return '0' 71 | elif 0.01 < abs(value) < 10000: 72 | value = f'{value:.2f}' 73 | value = value.rstrip('0') 74 | value = value.rstrip('0') 75 | value = value.rstrip('.') 76 | return value 77 | else: 78 | value = f'{value:.1e}' 79 | value = value.replace('.0e', 'e') 80 | value = value.replace('+0', '') 81 | value = value.replace('+', '') 82 | value = value.replace('-0', '-') 83 | return value 84 | 85 | 86 | class JSONLOutput: 87 | 88 | def __init__(self, logdir): 89 | self._logdir = pathlib.Path(logdir).expanduser() 90 | 91 | def __call__(self, summaries): 92 | scalars = {k: float(v) for _, k, v in summaries if len(v.shape) == 0} 93 | step = max(s for s, _, _, in summaries) 94 | with (self._logdir / 'metrics.jsonl').open('a') as f: 95 | f.write(json.dumps({'step': step, **scalars}) + '\n') 96 | 97 | 98 | class TensorBoardOutput: 99 | 100 | def __init__(self, logdir, fps=20): 101 | # The TensorFlow summary writer supports file protocols like gs://. We use 102 | # os.path over pathlib here to preserve those prefixes. 103 | self._logdir = os.path.expanduser(logdir) 104 | self._writer = None 105 | self._fps = fps 106 | 107 | def __call__(self, summaries): 108 | import tensorflow as tf 109 | self._ensure_writer() 110 | self._writer.set_as_default() 111 | for step, name, value in summaries: 112 | if len(value.shape) == 0: 113 | tf.summary.scalar('scalars/' + name, value, step) 114 | elif len(value.shape) == 2: 115 | tf.summary.image(name, value, step) 116 | elif len(value.shape) == 3: 117 | tf.summary.image(name, value, step) 118 | elif len(value.shape) == 4: 119 | self._video_summary(name, value, step) 120 | self._writer.flush() 121 | 122 | def _ensure_writer(self): 123 | if not self._writer: 124 | import tensorflow as tf 125 | self._writer = tf.summary.create_file_writer( 126 | self._logdir, max_queue=1000) 127 | 128 | def _video_summary(self, name, video, step): 129 | import tensorflow as tf 130 | import tensorflow.compat.v1 as tf1 131 | name = name if isinstance(name, str) else name.decode('utf-8') 132 | if np.issubdtype(video.dtype, np.floating): 133 | video = np.clip(255 * video, 0, 255).astype(np.uint8) 134 | try: 135 | T, H, W, C = video.shape 136 | summary = tf1.Summary() 137 | image = tf1.Summary.Image(height=H, width=W, colorspace=C) 138 | image.encoded_image_string = encode_gif(video, self._fps) 139 | summary.value.add(tag=name, image=image) 140 | tf.summary.experimental.write_raw_pb(summary.SerializeToString(), step) 141 | except (IOError, OSError) as e: 142 | print('GIF summaries require ffmpeg in $PATH.', e) 143 | tf.summary.image(name, video, step) 144 | 145 | 146 | def encode_gif(frames, fps): 147 | from subprocess import Popen, PIPE 148 | h, w, c = frames[0].shape 149 | pxfmt = {1: 'gray', 3: 'rgb24'}[c] 150 | cmd = ' '.join([ 151 | 'ffmpeg -y -f rawvideo -vcodec rawvideo', 152 | f'-r {fps:.02f} -s {w}x{h} -pix_fmt {pxfmt} -i - -filter_complex', 153 | '[0:v]split[x][z];[z]palettegen[y];[x]fifo[x];[x][y]paletteuse', 154 | f'-r {fps:.02f} -f gif -']) 155 | proc = Popen(cmd.split(' '), stdin=PIPE, stdout=PIPE, stderr=PIPE) 156 | for image in frames: 157 | proc.stdin.write(image.tobytes()) 158 | out, err = proc.communicate() 159 | if proc.returncode: 160 | raise IOError('\n'.join([' '.join(cmd), err.decode('utf8')])) 161 | del proc 162 | return out 163 | 
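# Usage sketch (illustrative only, not part of the original module): the
# Logger simply fans accumulated metrics out to its output callables. The
# log directory and metric name below are placeholder assumptions.
if __name__ == '__main__':
  logdir = pathlib.Path('~/logdir').expanduser()
  logdir.mkdir(parents=True, exist_ok=True)
  logger = Logger(step=0, outputs=[TerminalOutput(), JSONLOutput(logdir)])
  logger.scalar('train_return', 123.4)
  logger.write(fps=True)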
-------------------------------------------------------------------------------- /dreamerv2/common/nets.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | from tensorflow.keras import layers as tfkl 6 | from tensorflow_probability import distributions as tfd 7 | from tensorflow.keras.mixed_precision import experimental as prec 8 | 9 | import common 10 | 11 | 12 | class EnsembleRSSM(common.Module): 13 | 14 | def __init__( 15 | self, ensemble=5, stoch=30, deter=200, hidden=200, discrete=False, 16 | act='elu', norm='none', std_act='softplus', min_std=0.1): 17 | super().__init__() 18 | self._ensemble = ensemble 19 | self._stoch = stoch 20 | self._deter = deter 21 | self._hidden = hidden 22 | self._discrete = discrete 23 | self._act = get_act(act) 24 | self._norm = norm 25 | self._std_act = std_act 26 | self._min_std = min_std 27 | self._cell = GRUCell(self._deter, norm=True) 28 | self._cast = lambda x: tf.cast(x, prec.global_policy().compute_dtype) 29 | 30 | def initial(self, batch_size): 31 | dtype = prec.global_policy().compute_dtype 32 | if self._discrete: 33 | state = dict( 34 | logit=tf.zeros([batch_size, self._stoch, self._discrete], dtype), 35 | stoch=tf.zeros([batch_size, self._stoch, self._discrete], dtype), 36 | deter=self._cell.get_initial_state(None, batch_size, dtype)) 37 | else: 38 | state = dict( 39 | mean=tf.zeros([batch_size, self._stoch], dtype), 40 | std=tf.zeros([batch_size, self._stoch], dtype), 41 | stoch=tf.zeros([batch_size, self._stoch], dtype), 42 | deter=self._cell.get_initial_state(None, batch_size, dtype)) 43 | return state 44 | 45 | @tf.function 46 | def observe(self, embed, action, is_first, state=None): 47 | swap = lambda x: tf.transpose(x, [1, 0] + list(range(2, len(x.shape)))) 48 | if state is None: 49 | state = self.initial(tf.shape(action)[0]) 50 | post, prior = common.static_scan( 51 | lambda prev, inputs: self.obs_step(prev[0], *inputs), 52 | (swap(action), swap(embed), swap(is_first)), (state, state)) 53 | post = {k: swap(v) for k, v in post.items()} 54 | prior = {k: swap(v) for k, v in prior.items()} 55 | return post, prior 56 | 57 | @tf.function 58 | def imagine(self, action, state=None): 59 | swap = lambda x: tf.transpose(x, [1, 0] + list(range(2, len(x.shape)))) 60 | if state is None: 61 | state = self.initial(tf.shape(action)[0]) 62 | assert isinstance(state, dict), state 63 | action = swap(action) 64 | prior = common.static_scan(self.img_step, action, state) 65 | prior = {k: swap(v) for k, v in prior.items()} 66 | return prior 67 | 68 | def get_feat(self, state): 69 | stoch = self._cast(state['stoch']) 70 | if self._discrete: 71 | shape = stoch.shape[:-2] + [self._stoch * self._discrete] 72 | stoch = tf.reshape(stoch, shape) 73 | return tf.concat([stoch, state['deter']], -1) 74 | 75 | def get_dist(self, state, ensemble=False): 76 | if ensemble: 77 | state = self._suff_stats_ensemble(state['deter']) 78 | if self._discrete: 79 | logit = state['logit'] 80 | logit = tf.cast(logit, tf.float32) 81 | dist = tfd.Independent(common.OneHotDist(logit), 1) 82 | else: 83 | mean, std = state['mean'], state['std'] 84 | mean = tf.cast(mean, tf.float32) 85 | std = tf.cast(std, tf.float32) 86 | dist = tfd.MultivariateNormalDiag(mean, std) 87 | return dist 88 | 89 | @tf.function 90 | def obs_step(self, prev_state, prev_action, embed, is_first, sample=True): 91 | # if is_first.any(): 92 | prev_state, prev_action = tf.nest.map_structure( 93 | lambda x: tf.einsum( 
94 | 'b,b...->b...', 1.0 - is_first.astype(x.dtype), x), 95 | (prev_state, prev_action)) 96 | prior = self.img_step(prev_state, prev_action, sample) 97 | x = tf.concat([prior['deter'], embed], -1) 98 | x = self.get('obs_out', tfkl.Dense, self._hidden)(x) 99 | x = self.get('obs_out_norm', NormLayer, self._norm)(x) 100 | x = self._act(x) 101 | stats = self._suff_stats_layer('obs_dist', x) 102 | dist = self.get_dist(stats) 103 | stoch = dist.sample() if sample else dist.mode() 104 | post = {'stoch': stoch, 'deter': prior['deter'], **stats} 105 | return post, prior 106 | 107 | @tf.function 108 | def img_step(self, prev_state, prev_action, sample=True): 109 | prev_stoch = self._cast(prev_state['stoch']) 110 | prev_action = self._cast(prev_action) 111 | if self._discrete: 112 | shape = prev_stoch.shape[:-2] + [self._stoch * self._discrete] 113 | prev_stoch = tf.reshape(prev_stoch, shape) 114 | x = tf.concat([prev_stoch, prev_action], -1) 115 | x = self.get('img_in', tfkl.Dense, self._hidden)(x) 116 | x = self.get('img_in_norm', NormLayer, self._norm)(x) 117 | x = self._act(x) 118 | deter = prev_state['deter'] 119 | x, deter = self._cell(x, [deter]) 120 | deter = deter[0] # Keras wraps the state in a list. 121 | stats = self._suff_stats_ensemble(x) 122 | index = tf.random.uniform((), 0, self._ensemble, tf.int32) 123 | stats = {k: v[index] for k, v in stats.items()} 124 | dist = self.get_dist(stats) 125 | stoch = dist.sample() if sample else dist.mode() 126 | prior = {'stoch': stoch, 'deter': deter, **stats} 127 | return prior 128 | 129 | def _suff_stats_ensemble(self, inp): 130 | bs = list(inp.shape[:-1]) 131 | inp = inp.reshape([-1, inp.shape[-1]]) 132 | stats = [] 133 | for k in range(self._ensemble): 134 | x = self.get(f'img_out_{k}', tfkl.Dense, self._hidden)(inp) 135 | x = self.get(f'img_out_norm_{k}', NormLayer, self._norm)(x) 136 | x = self._act(x) 137 | stats.append(self._suff_stats_layer(f'img_dist_{k}', x)) 138 | stats = { 139 | k: tf.stack([x[k] for x in stats], 0) 140 | for k, v in stats[0].items()} 141 | stats = { 142 | k: v.reshape([v.shape[0]] + bs + list(v.shape[2:])) 143 | for k, v in stats.items()} 144 | return stats 145 | 146 | def _suff_stats_layer(self, name, x): 147 | if self._discrete: 148 | x = self.get(name, tfkl.Dense, self._stoch * self._discrete, None)(x) 149 | logit = tf.reshape(x, x.shape[:-1] + [self._stoch, self._discrete]) 150 | return {'logit': logit} 151 | else: 152 | x = self.get(name, tfkl.Dense, 2 * self._stoch, None)(x) 153 | mean, std = tf.split(x, 2, -1) 154 | std = { 155 | 'softplus': lambda: tf.nn.softplus(std), 156 | 'sigmoid': lambda: tf.nn.sigmoid(std), 157 | 'sigmoid2': lambda: 2 * tf.nn.sigmoid(std / 2), 158 | }[self._std_act]() 159 | std = std + self._min_std 160 | return {'mean': mean, 'std': std} 161 | 162 | def kl_loss(self, post, prior, forward, balance, free, free_avg): 163 | kld = tfd.kl_divergence 164 | sg = lambda x: tf.nest.map_structure(tf.stop_gradient, x) 165 | lhs, rhs = (prior, post) if forward else (post, prior) 166 | mix = balance if forward else (1 - balance) 167 | if balance == 0.5: 168 | value = kld(self.get_dist(lhs), self.get_dist(rhs)) 169 | loss = tf.maximum(value, free).mean() 170 | else: 171 | value_lhs = value = kld(self.get_dist(lhs), self.get_dist(sg(rhs))) 172 | value_rhs = kld(self.get_dist(sg(lhs)), self.get_dist(rhs)) 173 | if free_avg: 174 | loss_lhs = tf.maximum(value_lhs.mean(), free) 175 | loss_rhs = tf.maximum(value_rhs.mean(), free) 176 | else: 177 | loss_lhs = tf.maximum(value_lhs, free).mean() 178 | loss_rhs 
= tf.maximum(value_rhs, free).mean() 179 | loss = mix * loss_lhs + (1 - mix) * loss_rhs 180 | return loss, value 181 | 182 | 183 | class Encoder(common.Module): 184 | 185 | def __init__( 186 | self, shapes, cnn_keys=r'.*', mlp_keys=r'.*', act='elu', norm='none', 187 | cnn_depth=48, cnn_kernels=(4, 4, 4, 4), mlp_layers=[400, 400, 400, 400]): 188 | self.shapes = shapes 189 | self.cnn_keys = [ 190 | k for k, v in shapes.items() if re.match(cnn_keys, k) and len(v) == 3] 191 | self.mlp_keys = [ 192 | k for k, v in shapes.items() if re.match(mlp_keys, k) and len(v) == 1] 193 | print('Encoder CNN inputs:', list(self.cnn_keys)) 194 | print('Encoder MLP inputs:', list(self.mlp_keys)) 195 | self._act = get_act(act) 196 | self._norm = norm 197 | self._cnn_depth = cnn_depth 198 | self._cnn_kernels = cnn_kernels 199 | self._mlp_layers = mlp_layers 200 | 201 | @tf.function 202 | def __call__(self, data): 203 | key, shape = list(self.shapes.items())[0] 204 | batch_dims = data[key].shape[:-len(shape)] 205 | data = { 206 | k: tf.reshape(v, (-1,) + tuple(v.shape)[len(batch_dims):]) 207 | for k, v in data.items()} 208 | outputs = [] 209 | if self.cnn_keys: 210 | outputs.append(self._cnn({k: data[k] for k in self.cnn_keys})) 211 | if self.mlp_keys: 212 | outputs.append(self._mlp({k: data[k] for k in self.mlp_keys})) 213 | output = tf.concat(outputs, -1) 214 | return output.reshape(batch_dims + output.shape[1:]) 215 | 216 | def _cnn(self, data): 217 | x = tf.concat(list(data.values()), -1) 218 | x = x.astype(prec.global_policy().compute_dtype) 219 | for i, kernel in enumerate(self._cnn_kernels): 220 | depth = 2 ** i * self._cnn_depth 221 | x = self.get(f'conv{i}', tfkl.Conv2D, depth, kernel, 2)(x) 222 | x = self.get(f'convnorm{i}', NormLayer, self._norm)(x) 223 | x = self._act(x) 224 | return x.reshape(tuple(x.shape[:-3]) + (-1,)) 225 | 226 | def _mlp(self, data): 227 | x = tf.concat(list(data.values()), -1) 228 | x = x.astype(prec.global_policy().compute_dtype) 229 | for i, width in enumerate(self._mlp_layers): 230 | x = self.get(f'dense{i}', tfkl.Dense, width)(x) 231 | x = self.get(f'densenorm{i}', NormLayer, self._norm)(x) 232 | x = self._act(x) 233 | return x 234 | 235 | 236 | class Decoder(common.Module): 237 | 238 | def __init__( 239 | self, shapes, cnn_keys=r'.*', mlp_keys=r'.*', act='elu', norm='none', 240 | cnn_depth=48, cnn_kernels=(4, 4, 4, 4), mlp_layers=[400, 400, 400, 400]): 241 | self._shapes = shapes 242 | self.cnn_keys = [ 243 | k for k, v in shapes.items() if re.match(cnn_keys, k) and len(v) == 3] 244 | self.mlp_keys = [ 245 | k for k, v in shapes.items() if re.match(mlp_keys, k) and len(v) == 1] 246 | print('Decoder CNN outputs:', list(self.cnn_keys)) 247 | print('Decoder MLP outputs:', list(self.mlp_keys)) 248 | self._act = get_act(act) 249 | self._norm = norm 250 | self._cnn_depth = cnn_depth 251 | self._cnn_kernels = cnn_kernels 252 | self._mlp_layers = mlp_layers 253 | 254 | def __call__(self, features): 255 | features = tf.cast(features, prec.global_policy().compute_dtype) 256 | outputs = {} 257 | if self.cnn_keys: 258 | outputs.update(self._cnn(features)) 259 | if self.mlp_keys: 260 | outputs.update(self._mlp(features)) 261 | return outputs 262 | 263 | def _cnn(self, features): 264 | channels = {k: self._shapes[k][-1] for k in self.cnn_keys} 265 | ConvT = tfkl.Conv2DTranspose 266 | x = self.get('convin', tfkl.Dense, 32 * self._cnn_depth)(features) 267 | x = tf.reshape(x, [-1, 1, 1, 32 * self._cnn_depth]) 268 | for i, kernel in enumerate(self._cnn_kernels): 269 | depth = 2 ** 
(len(self._cnn_kernels) - i - 2) * self._cnn_depth 270 | act, norm = self._act, self._norm 271 | if i == len(self._cnn_kernels) - 1: 272 | depth, act, norm = sum(channels.values()), tf.identity, 'none' 273 | x = self.get(f'conv{i}', ConvT, depth, kernel, 2)(x) 274 | x = self.get(f'convnorm{i}', NormLayer, norm)(x) 275 | x = act(x) 276 | x = x.reshape(features.shape[:-1] + x.shape[1:]) 277 | means = tf.split(x, list(channels.values()), -1) 278 | dists = { 279 | key: tfd.Independent(tfd.Normal(mean, 1), 3) 280 | for (key, shape), mean in zip(channels.items(), means)} 281 | return dists 282 | 283 | def _mlp(self, features): 284 | shapes = {k: self._shapes[k] for k in self.mlp_keys} 285 | x = features 286 | for i, width in enumerate(self._mlp_layers): 287 | x = self.get(f'dense{i}', tfkl.Dense, width)(x) 288 | x = self.get(f'densenorm{i}', NormLayer, self._norm)(x) 289 | x = self._act(x) 290 | dists = {} 291 | for key, shape in shapes.items(): 292 | dists[key] = self.get(f'dense_{key}', DistLayer, shape)(x) 293 | return dists 294 | 295 | 296 | class MLP(common.Module): 297 | 298 | def __init__(self, shape, layers, units, act='elu', norm='none', **out): 299 | self._shape = (shape,) if isinstance(shape, int) else shape 300 | self._layers = layers 301 | self._units = units 302 | self._norm = norm 303 | self._act = get_act(act) 304 | self._out = out 305 | 306 | def __call__(self, features): 307 | x = tf.cast(features, prec.global_policy().compute_dtype) 308 | x = x.reshape([-1, x.shape[-1]]) 309 | for index in range(self._layers): 310 | x = self.get(f'dense{index}', tfkl.Dense, self._units)(x) 311 | x = self.get(f'norm{index}', NormLayer, self._norm)(x) 312 | x = self._act(x) 313 | x = x.reshape(features.shape[:-1] + [x.shape[-1]]) 314 | return self.get('out', DistLayer, self._shape, **self._out)(x) 315 | 316 | class MultiMLP(common.Module): 317 | # initial feature extraction layers 318 | def __init__(self, shape, layers, units, act='elu', norm='none', **out): 319 | self._shape = (shape,) if isinstance(shape, int) else shape 320 | self._layers = layers 321 | self._units = units 322 | self._norm = norm 323 | self._act = get_act(act) 324 | self._out = out 325 | 326 | def __call__(self, features, idx=0): 327 | x = tf.cast(features, prec.global_policy().compute_dtype) 328 | x = x.reshape([-1, x.shape[-1]]) 329 | for index in range(self._layers): 330 | x = self.get(f'dense{index}', tfkl.Dense, self._units)(x) 331 | x = self.get(f'norm{index}', NormLayer, self._norm)(x) 332 | x = self._act(x) 333 | x = x.reshape(features.shape[:-1] + [x.shape[-1]]) 334 | ## pass in idx for the MultiDistLayer! 335 | return self.get('out', MultiDistLayer, self._shape, **self._out)(x, idx) 336 | 337 | class GRUCell(tf.keras.layers.AbstractRNNCell): 338 | 339 | def __init__(self, size, norm=False, act='tanh', update_bias=-1, **kwargs): 340 | super().__init__() 341 | self._size = size 342 | self._act = get_act(act) 343 | self._norm = norm 344 | self._update_bias = update_bias 345 | self._layer = tfkl.Dense(3 * size, use_bias=norm is not None, **kwargs) 346 | if norm: 347 | self._norm = tfkl.LayerNormalization(dtype=tf.float32) 348 | 349 | @property 350 | def state_size(self): 351 | return self._size 352 | 353 | @tf.function 354 | def call(self, inputs, state): 355 | state = state[0] # Keras wraps the state in a list. 
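    # A single dense layer emits the reset, candidate, and update
    # pre-activations stacked along the last axis; tf.split below separates
    # them into the three GRU gates.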
356 | parts = self._layer(tf.concat([inputs, state], -1)) 357 | if self._norm: 358 | dtype = parts.dtype 359 | parts = tf.cast(parts, tf.float32) 360 | parts = self._norm(parts) 361 | parts = tf.cast(parts, dtype) 362 | reset, cand, update = tf.split(parts, 3, -1) 363 | reset = tf.nn.sigmoid(reset) 364 | cand = self._act(reset * cand) 365 | update = tf.nn.sigmoid(update + self._update_bias) 366 | output = update * cand + (1 - update) * state 367 | return output, [output] 368 | 369 | 370 | class DistLayer(common.Module): 371 | 372 | def __init__( 373 | self, shape, dist='mse', min_std=0.1, init_std=0.0): 374 | self._shape = shape 375 | self._dist = dist 376 | self._min_std = min_std 377 | self._init_std = init_std 378 | 379 | def __call__(self, inputs): 380 | out = self.get('out', tfkl.Dense, np.prod(self._shape))(inputs) 381 | out = tf.reshape(out, tf.concat([tf.shape(inputs)[:-1], self._shape], 0)) 382 | out = tf.cast(out, tf.float32) 383 | if self._dist in ('normal', 'tanh_normal', 'trunc_normal'): 384 | std = self.get('std', tfkl.Dense, np.prod(self._shape))(inputs) 385 | std = tf.reshape(std, tf.concat([tf.shape(inputs)[:-1], self._shape], 0)) 386 | std = tf.cast(std, tf.float32) 387 | if self._dist == 'mse': 388 | dist = tfd.Normal(out, 1.0) 389 | return tfd.Independent(dist, len(self._shape)) 390 | if self._dist == 'normal': 391 | dist = tfd.Normal(out, std) 392 | return tfd.Independent(dist, len(self._shape)) 393 | if self._dist == 'binary': 394 | dist = tfd.Bernoulli(out) 395 | return tfd.Independent(dist, len(self._shape)) 396 | if self._dist == 'tanh_normal': 397 | mean = 5 * tf.tanh(out / 5) 398 | std = tf.nn.softplus(std + self._init_std) + self._min_std 399 | dist = tfd.Normal(mean, std) 400 | dist = tfd.TransformedDistribution(dist, common.TanhBijector()) 401 | dist = tfd.Independent(dist, len(self._shape)) 402 | return common.SampleDist(dist) 403 | if self._dist == 'trunc_normal': 404 | std = 2 * tf.nn.sigmoid((std + self._init_std) / 2) + self._min_std 405 | dist = common.TruncNormalDist(tf.tanh(out), std, -1, 1) 406 | return tfd.Independent(dist, 1) 407 | if self._dist == 'onehot': 408 | return common.OneHotDist(out) 409 | raise NotImplementedError(self._dist) 410 | 411 | class MultiDistLayer(common.Module): 412 | 413 | def __init__( 414 | self, shape, dist='mse', min_std=0.1, init_std=0.0): 415 | self._shape = shape 416 | self._dist = dist 417 | self._min_std = min_std 418 | self._init_std = init_std 419 | 420 | def __call__(self, inputs, idx=0): 421 | out = self.get(f'out{idx}', tfkl.Dense, np.prod(self._shape))(inputs) 422 | out = tf.reshape(out, tf.concat([tf.shape(inputs)[:-1], self._shape], 0)) 423 | out = tf.cast(out, tf.float32) 424 | if self._dist in ('normal', 'tanh_normal', 'trunc_normal'): 425 | std = self.get(f'std{idx}', tfkl.Dense, np.prod(self._shape))(inputs) 426 | std = tf.reshape(std, tf.concat([tf.shape(inputs)[:-1], self._shape], 0)) 427 | std = tf.cast(std, tf.float32) 428 | if self._dist == 'mse': 429 | dist = tfd.Normal(out, 1.0) 430 | return tfd.Independent(dist, len(self._shape)) 431 | if self._dist == 'normal': 432 | dist = tfd.Normal(out, std) 433 | return tfd.Independent(dist, len(self._shape)) 434 | if self._dist == 'binary': 435 | dist = tfd.Bernoulli(out) 436 | return tfd.Independent(dist, len(self._shape)) 437 | if self._dist == 'tanh_normal': 438 | mean = 5 * tf.tanh(out / 5) 439 | std = tf.nn.softplus(std + self._init_std) + self._min_std 440 | dist = tfd.Normal(mean, std) 441 | dist = tfd.TransformedDistribution(dist, 
common.TanhBijector()) 442 | dist = tfd.Independent(dist, len(self._shape)) 443 | return common.SampleDist(dist) 444 | if self._dist == 'trunc_normal': 445 | std = 2 * tf.nn.sigmoid((std + self._init_std) / 2) + self._min_std 446 | dist = common.TruncNormalDist(tf.tanh(out), std, -1, 1) 447 | return tfd.Independent(dist, 1) 448 | if self._dist == 'onehot': 449 | return common.OneHotDist(out) 450 | raise NotImplementedError(self._dist) 451 | 452 | class NormLayer(common.Module): 453 | 454 | def __init__(self, name): 455 | if name == 'none': 456 | self._layer = None 457 | elif name == 'layer': 458 | self._layer = tfkl.LayerNormalization() 459 | else: 460 | raise NotImplementedError(name) 461 | 462 | def __call__(self, features): 463 | if not self._layer: 464 | return features 465 | return self._layer(features) 466 | 467 | 468 | def get_act(name): 469 | if name == 'none': 470 | return tf.identity 471 | if name == 'mish': 472 | return lambda x: x * tf.math.tanh(tf.nn.softplus(x)) 473 | elif hasattr(tf.nn, name): 474 | return getattr(tf.nn, name) 475 | elif hasattr(tf, name): 476 | return getattr(tf, name) 477 | else: 478 | raise NotImplementedError(name) 479 | -------------------------------------------------------------------------------- /dreamerv2/common/other.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import contextlib 3 | import re 4 | import time 5 | 6 | import numpy as np 7 | import tensorflow as tf 8 | from tensorflow_probability import distributions as tfd 9 | 10 | from . import dists 11 | from . import tfutils 12 | 13 | 14 | class RandomAgent: 15 | 16 | def __init__(self, act_space, logprob=False): 17 | self.act_space = act_space['action'] 18 | self.logprob = logprob 19 | if hasattr(self.act_space, 'n'): 20 | self._dist = dists.OneHotDist(tf.zeros(self.act_space.n)) 21 | else: 22 | dist = tfd.Uniform(self.act_space.low, self.act_space.high) 23 | self._dist = tfd.Independent(dist, 1) 24 | 25 | def __call__(self, obs, state=None, mode=None): 26 | action = self._dist.sample(len(obs['is_first'])) 27 | output = {'action': action} 28 | if self.logprob: 29 | output['logprob'] = self._dist.log_prob(action) 30 | return output, None 31 | 32 | 33 | def static_scan(fn, inputs, start, reverse=False): 34 | last = start 35 | outputs = [[] for _ in tf.nest.flatten(start)] 36 | indices = range(tf.nest.flatten(inputs)[0].shape[0]) 37 | if reverse: 38 | indices = reversed(indices) 39 | for index in indices: 40 | inp = tf.nest.map_structure(lambda x: x[index], inputs) 41 | last = fn(last, inp) 42 | [o.append(l) for o, l in zip(outputs, tf.nest.flatten(last))] 43 | if reverse: 44 | outputs = [list(reversed(x)) for x in outputs] 45 | outputs = [tf.stack(x, 0) for x in outputs] 46 | return tf.nest.pack_sequence_as(start, outputs) 47 | 48 | 49 | def schedule(string, step): 50 | try: 51 | return float(string) 52 | except ValueError: 53 | step = tf.cast(step, tf.float32) 54 | match = re.match(r'linear\((.+),(.+),(.+)\)', string) 55 | if match: 56 | initial, final, duration = [float(group) for group in match.groups()] 57 | mix = tf.clip_by_value(step / duration, 0, 1) 58 | return (1 - mix) * initial + mix * final 59 | match = re.match(r'warmup\((.+),(.+)\)', string) 60 | if match: 61 | warmup, value = [float(group) for group in match.groups()] 62 | scale = tf.clip_by_value(step / warmup, 0, 1) 63 | return scale * value 64 | match = re.match(r'exp\((.+),(.+),(.+)\)', string) 65 | if match: 66 | initial, final, halflife = [float(group) for 
group in match.groups()] 67 | return (initial - final) * 0.5 ** (step / halflife) + final 68 | match = re.match(r'horizon\((.+),(.+),(.+)\)', string) 69 | if match: 70 | initial, final, duration = [float(group) for group in match.groups()] 71 | mix = tf.clip_by_value(step / duration, 0, 1) 72 | horizon = (1 - mix) * initial + mix * final 73 | return 1 - 1 / horizon 74 | raise NotImplementedError(string) 75 | 76 | 77 | def lambda_return( 78 | reward, value, pcont, bootstrap, lambda_, axis): 79 | # Setting lambda=1 gives a discounted Monte Carlo return. 80 | # Setting lambda=0 gives a fixed 1-step return. 81 | assert reward.shape.ndims == value.shape.ndims, (reward.shape, value.shape) 82 | if isinstance(pcont, (int, float)): 83 | pcont = pcont * tf.ones_like(reward) 84 | dims = list(range(reward.shape.ndims)) 85 | dims = [axis] + dims[1:axis] + [0] + dims[axis + 1:] 86 | if axis != 0: 87 | reward = tf.transpose(reward, dims) 88 | value = tf.transpose(value, dims) 89 | pcont = tf.transpose(pcont, dims) 90 | if bootstrap is None: 91 | bootstrap = tf.zeros_like(value[-1]) 92 | next_values = tf.concat([value[1:], bootstrap[None]], 0) 93 | inputs = reward + pcont * next_values * (1 - lambda_) 94 | returns = static_scan( 95 | lambda agg, cur: cur[0] + cur[1] * lambda_ * agg, 96 | (inputs, pcont), bootstrap, reverse=True) 97 | if axis != 0: 98 | returns = tf.transpose(returns, dims) 99 | return returns 100 | 101 | 102 | def action_noise(action, amount, act_space): 103 | if amount == 0: 104 | return action 105 | amount = tf.cast(amount, action.dtype) 106 | if hasattr(act_space, 'n'): 107 | probs = amount / action.shape[-1] + (1 - amount) * action 108 | return dists.OneHotDist(probs=probs).sample() 109 | else: 110 | return tf.clip_by_value(tfd.Normal(action, amount).sample(), -1, 1) 111 | 112 | 113 | class StreamNorm(tfutils.Module): 114 | 115 | def __init__(self, shape=(), momentum=0.99, scale=1.0, eps=1e-8): 116 | # Momentum of 0 normalizes only based on the current batch. 117 | # Momentum of 1 disables normalization. 
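    # `mag` tracks an exponential moving average of the mean absolute value
    # per entry; transform() divides inputs by this magnitude (plus eps) and
    # multiplies by `scale`.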
118 | self._shape = tuple(shape) 119 | self._momentum = momentum 120 | self._scale = scale 121 | self._eps = eps 122 | self.mag = tf.Variable(tf.ones(shape, tf.float64), False) 123 | 124 | def __call__(self, inputs): 125 | metrics = {} 126 | self.update(inputs) 127 | metrics['mean'] = inputs.mean() 128 | metrics['std'] = inputs.std() 129 | outputs = self.transform(inputs) 130 | metrics['normed_mean'] = outputs.mean() 131 | metrics['normed_std'] = outputs.std() 132 | return outputs, metrics 133 | 134 | def reset(self): 135 | self.mag.assign(tf.ones_like(self.mag)) 136 | 137 | def update(self, inputs): 138 | batch = inputs.reshape((-1,) + self._shape) 139 | mag = tf.abs(batch).mean(0).astype(tf.float64) 140 | self.mag.assign(self._momentum * self.mag + (1 - self._momentum) * mag) 141 | 142 | def transform(self, inputs): 143 | values = inputs.reshape((-1,) + self._shape) 144 | values /= self.mag.astype(inputs.dtype)[None] + self._eps 145 | values *= self._scale 146 | return values.reshape(inputs.shape) 147 | 148 | 149 | class Timer: 150 | 151 | def __init__(self): 152 | self._indurs = collections.defaultdict(list) 153 | self._outdurs = collections.defaultdict(list) 154 | self._start_times = {} 155 | self._end_times = {} 156 | 157 | @contextlib.contextmanager 158 | def section(self, name): 159 | self.start(name) 160 | yield 161 | self.end(name) 162 | 163 | def wrap(self, function, name): 164 | def wrapped(*args, **kwargs): 165 | with self.section(name): 166 | return function(*args, **kwargs) 167 | return wrapped 168 | 169 | def start(self, name): 170 | now = time.time() 171 | self._start_times[name] = now 172 | if name in self._end_times: 173 | last = self._end_times[name] 174 | self._outdurs[name].append(now - last) 175 | 176 | def end(self, name): 177 | now = time.time() 178 | self._end_times[name] = now 179 | self._indurs[name].append(now - self._start_times[name]) 180 | 181 | def result(self): 182 | metrics = {} 183 | for key in self._indurs: 184 | indurs = self._indurs[key] 185 | outdurs = self._outdurs[key] 186 | metrics[f'timer_count_{key}'] = len(indurs) 187 | metrics[f'timer_inside_{key}'] = np.sum(indurs) 188 | metrics[f'timer_outside_{key}'] = np.sum(outdurs) 189 | indurs.clear() 190 | outdurs.clear() 191 | return metrics 192 | 193 | 194 | class CarryOverState: 195 | 196 | def __init__(self, fn): 197 | self._fn = fn 198 | self._state = None 199 | 200 | def __call__(self, *args): 201 | self._state, out = self._fn(*args, self._state) 202 | return out 203 | -------------------------------------------------------------------------------- /dreamerv2/common/plot.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | import functools 4 | import itertools 5 | import json 6 | import multiprocessing as mp 7 | import os 8 | import pathlib 9 | import re 10 | import subprocess 11 | import warnings 12 | 13 | os.environ['NO_AT_BRIDGE'] = '1' # Hide X org false warning. 
14 | 15 | import matplotlib 16 | matplotlib.use('Agg') 17 | import matplotlib.pyplot as plt 18 | import matplotlib.ticker as ticker 19 | import numpy as np 20 | import pandas as pd 21 | 22 | np.set_string_function(lambda x: f'') 23 | 24 | Run = collections.namedtuple('Run', 'task method seed xs ys') 25 | 26 | PALETTES = dict( 27 | discrete=( 28 | '#377eb8', '#4daf4a', '#984ea3', '#e41a1c', '#ff7f00', '#a65628', 29 | '#f781bf', '#888888', '#a6cee3', '#b2df8a', '#cab2d6', '#fb9a99', 30 | ), 31 | contrast=( 32 | '#0022ff', '#33aa00', '#ff0011', '#ddaa00', '#cc44dd', '#0088aa', 33 | '#001177', '#117700', '#990022', '#885500', '#553366', '#006666', 34 | ), 35 | gradient=( 36 | '#fde725', '#a0da39', '#4ac16d', '#1fa187', '#277f8e', '#365c8d', 37 | '#46327e', '#440154', 38 | ), 39 | baselines=( 40 | '#222222', '#666666', '#aaaaaa', '#cccccc', 41 | ), 42 | ) 43 | 44 | LEGEND = dict( 45 | fontsize='medium', numpoints=1, labelspacing=0, columnspacing=1.2, 46 | handlelength=1.5, handletextpad=0.5, loc='lower center') 47 | 48 | DEFAULT_BASELINES = ['d4pg', 'rainbow_sticky', 'human_gamer', 'impala'] 49 | 50 | 51 | def find_keys(args): 52 | filenames = [] 53 | for indir in args.indir: 54 | task = next(indir.iterdir()) # First only. 55 | for method in task.iterdir(): 56 | seed = next(indir.iterdir()) # First only. 57 | filenames += list(seed.glob('**/*.jsonl')) 58 | keys = set() 59 | for filename in filenames: 60 | keys |= set(load_jsonl(filename).columns) 61 | print(f'Keys ({len(keys)}):', ', '.join(keys), flush=True) 62 | 63 | 64 | def load_runs(args): 65 | total, toload = [], [] 66 | for indir in args.indir: 67 | filenames = list(indir.glob('**/*.jsonl')) 68 | total += filenames 69 | for filename in filenames: 70 | task, method, seed = filename.relative_to(indir).parts[:-1] 71 | if not any(p.search(task) for p in args.tasks): 72 | continue 73 | if not any(p.search(method) for p in args.methods): 74 | continue 75 | toload.append((filename, indir)) 76 | print(f'Loading {len(toload)} of {len(total)} runs...') 77 | jobs = [functools.partial(load_run, f, i, args) for f, i in toload] 78 | # Disable async data loading: 79 | # runs = [j() for j in jobs] 80 | with mp.Pool(10) as pool: 81 | promises = [pool.apply_async(j) for j in jobs] 82 | runs = [p.get() for p in promises] 83 | runs = [r for r in runs if r is not None] 84 | return runs 85 | 86 | 87 | def load_run(filename, indir, args): 88 | task, method, seed = filename.relative_to(indir).parts[:-1] 89 | prefix = f'indir{args.indir.index(indir)+1}_' 90 | if task == 'atari_jamesbond': 91 | task = 'atari_james_bond' 92 | seed = prefix + seed 93 | if args.prefix: 94 | method = prefix + method 95 | df = load_jsonl(filename) 96 | if df is None: 97 | print('Skipping empty run') 98 | return 99 | try: 100 | df = df[[args.xaxis, args.yaxis]].dropna() 101 | if args.maxval: 102 | df = df.replace([+np.inf], +args.maxval) 103 | df = df.replace([-np.inf], -args.maxval) 104 | df[args.yaxis] = df[args.yaxis].clip(-args.maxval, +args.maxval) 105 | except KeyError: 106 | return 107 | xs = df[args.xaxis].to_numpy() 108 | if args.xmult != 1: 109 | xs = xs.astype(np.float32) * args.xmult 110 | ys = df[args.yaxis].to_numpy() 111 | bins = { 112 | 'atari': 1e6, 113 | 'dmc': 1e4, 114 | 'crafter': 1e4, 115 | }.get(task.split('_')[0], 1e5) if args.bins == -1 else args.bins 116 | if bins: 117 | borders = np.arange(0, xs.max() + 1e-8, bins) 118 | xs, ys = bin_scores(xs, ys, borders) 119 | if not len(xs): 120 | print('Skipping empty run', task, method, seed) 121 | return 122 | return 
Run(task, method, seed, xs, ys) 123 | 124 | 125 | def load_baselines(patterns, prefix=False): 126 | runs = [] 127 | directory = pathlib.Path(__file__).parent.parent / 'scores' 128 | for filename in directory.glob('**/*_baselines.json'): 129 | for task, methods in json.loads(filename.read_text()).items(): 130 | for method, score in methods.items(): 131 | if prefix: 132 | method = f'baseline_{method}' 133 | if not any(p.search(method) for p in patterns): 134 | continue 135 | runs.append(Run(task, method, None, None, score)) 136 | return runs 137 | 138 | 139 | def stats(runs, baselines): 140 | tasks = sorted(set(r.task for r in runs)) 141 | methods = sorted(set(r.method for r in runs)) 142 | seeds = sorted(set(r.seed for r in runs)) 143 | baseline = sorted(set(r.method for r in baselines)) 144 | print('Loaded', len(runs), 'runs.') 145 | print(f'Tasks ({len(tasks)}):', ', '.join(tasks)) 146 | print(f'Methods ({len(methods)}):', ', '.join(methods)) 147 | print(f'Seeds ({len(seeds)}):', ', '.join(seeds)) 148 | print(f'Baselines ({len(baseline)}):', ', '.join(baseline)) 149 | 150 | 151 | def order_methods(runs, baselines, args): 152 | methods = [] 153 | for pattern in args.methods: 154 | for method in sorted(set(r.method for r in runs)): 155 | if pattern.search(method): 156 | if method not in methods: 157 | methods.append(method) 158 | if method not in args.colors: 159 | index = len(args.colors) % len(args.palette) 160 | args.colors[method] = args.palette[index] 161 | non_baseline_colors = len(args.colors) 162 | for pattern in args.baselines: 163 | for method in sorted(set(r.method for r in baselines)): 164 | if pattern.search(method): 165 | if method not in methods: 166 | methods.append(method) 167 | if method not in args.colors: 168 | index = len(args.colors) - non_baseline_colors 169 | index = index % len(PALETTES['baselines']) 170 | args.colors[method] = PALETTES['baselines'][index] 171 | return methods 172 | 173 | 174 | def figure(runs, methods, args): 175 | tasks = sorted(set(r.task for r in runs if r.xs is not None)) 176 | rows = int(np.ceil((len(tasks) + len(args.add)) / args.cols)) 177 | figsize = args.size[0] * args.cols, args.size[1] * rows 178 | fig, axes = plt.subplots(rows, args.cols, figsize=figsize, squeeze=False) 179 | for task, ax in zip(tasks, axes.flatten()): 180 | relevant = [r for r in runs if r.task == task] 181 | plot(task, ax, relevant, methods, args) 182 | for name, ax in zip(args.add, axes.flatten()[len(tasks):]): 183 | ax.set_facecolor((0.9, 0.9, 0.9)) 184 | if name == 'median': 185 | plot_combined( 186 | 'combined_median', ax, runs, methods, args, 187 | agg=lambda x: np.nanmedian(x, -1)) 188 | elif name == 'mean': 189 | plot_combined( 190 | 'combined_mean', ax, runs, methods, args, 191 | agg=lambda x: np.nanmean(x, -1)) 192 | elif name == 'gamer_median': 193 | plot_combined( 194 | 'combined_gamer_median', ax, runs, methods, args, 195 | lo='random', hi='human_gamer', 196 | agg=lambda x: np.nanmedian(x, -1)) 197 | elif name == 'gamer_mean': 198 | plot_combined( 199 | 'combined_gamer_mean', ax, runs, methods, args, 200 | lo='random', hi='human_gamer', 201 | agg=lambda x: np.nanmean(x, -1)) 202 | elif name == 'record_mean': 203 | plot_combined( 204 | 'combined_record_mean', ax, runs, methods, args, 205 | lo='random', hi='record', 206 | agg=lambda x: np.nanmean(x, -1)) 207 | elif name == 'clip_record_mean': 208 | plot_combined( 209 | 'combined_clipped_record_mean', ax, runs, methods, args, 210 | lo='random', hi='record', clip=True, 211 | agg=lambda x: np.nanmean(x, -1)) 
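# Note on the remaining panels (approximate): 'seeds' plots how many runs are still
# reporting finite values per method at each step, while 'human_above' / 'human_below'
# count the tasks whose normalized score is at or above / at or below the human-gamer
# baseline (normalized score >= 1.0 or <= 1.0).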
212 | elif name == 'seeds': 213 | plot_combined( 214 | 'combined_seeds', ax, runs, methods, args, 215 | agg=lambda x: np.isfinite(x).sum(-1)) 216 | elif name == 'human_above': 217 | plot_combined( 218 | 'combined_above_human_gamer', ax, runs, methods, args, 219 | agg=lambda y: (y >= 1.0).astype(float).sum(-1)) 220 | elif name == 'human_below': 221 | plot_combined( 222 | 'combined_below_human_gamer', ax, runs, methods, args, 223 | agg=lambda y: (y <= 1.0).astype(float).sum(-1)) 224 | else: 225 | raise NotImplementedError(name) 226 | if args.xlim: 227 | for ax in axes[:-1].flatten(): 228 | ax.xaxis.get_offset_text().set_visible(False) 229 | if args.xlabel: 230 | for ax in axes[-1]: 231 | ax.set_xlabel(args.xlabel) 232 | if args.ylabel: 233 | for ax in axes[:, 0]: 234 | ax.set_ylabel(args.ylabel) 235 | for ax in axes.flatten()[len(tasks) + len(args.add):]: 236 | ax.axis('off') 237 | legend(fig, args.labels, ncol=args.legendcols, **LEGEND) 238 | return fig 239 | 240 | 241 | def plot(task, ax, runs, methods, args): 242 | assert runs 243 | try: 244 | title = task.split('_', 1)[1].replace('_', ' ').title() 245 | except IndexError: 246 | title = task.title() 247 | ax.set_title(title) 248 | xlim = [+np.inf, -np.inf] 249 | for index, method in enumerate(methods): 250 | relevant = [r for r in runs if r.method == method] 251 | if not relevant: 252 | continue 253 | if any(r.xs is None for r in relevant): 254 | baseline(index, method, ax, relevant, args) 255 | else: 256 | if args.agg == 'none': 257 | xs, ys = curve_lines(index, task, method, ax, relevant, args) 258 | else: 259 | xs, ys = curve_area(index, task, method, ax, relevant, args) 260 | if len(xs) == len(ys) == 0: 261 | print(f'Skipping empty: {task} {method}') 262 | continue 263 | xlim = [min(xlim[0], np.nanmin(xs)), max(xlim[1], np.nanmax(xs))] 264 | ax.ticklabel_format(axis='x', style='sci', scilimits=(0, 0)) 265 | steps = [1, 2, 2.5, 5, 10] 266 | ax.xaxis.set_major_locator(ticker.MaxNLocator(args.xticks, steps=steps)) 267 | ax.yaxis.set_major_locator(ticker.MaxNLocator(args.yticks, steps=steps)) 268 | if np.isfinite(xlim).all(): 269 | ax.set_xlim(args.xlim or xlim) 270 | if args.xlim: 271 | ticks = sorted({*ax.get_xticks(), *args.xlim}) 272 | ticks = [x for x in ticks if args.xlim[0] <= x <= args.xlim[1]] 273 | ax.set_xticks(ticks) 274 | if args.ylim: 275 | ax.set_ylim(args.ylim) 276 | if args.ylimticks: 277 | ticks = sorted({*ax.get_yticks(), *args.ylim}) 278 | ticks = [x for x in ticks if args.ylim[0] <= x <= args.ylim[1]] 279 | ax.set_yticks(ticks) 280 | 281 | 282 | def plot_combined( 283 | name, ax, runs, methods, args, agg, lo=None, hi=None, clip=False): 284 | tasks = sorted(set(run.task for run in runs if run.xs is not None)) 285 | seeds = list(set(run.seed for run in runs)) 286 | runs = [r for r in runs if r.task in tasks] # Discard unused baselines. 287 | # Bin all runs onto the same X steps. 288 | borders = sorted( 289 | [r.xs for r in runs if r.xs is not None], 290 | key=lambda x: np.nanmax(x))[-1] 291 | for index, run in enumerate(runs): 292 | if run.xs is None: 293 | continue 294 | xs, ys = bin_scores(run.xs, run.ys, borders, fill='last') 295 | runs[index] = run._replace(xs=xs, ys=ys) 296 | # Per-task normalization by low and high baseline. 
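# Worked example (sketch): with lo='random' and hi='human_gamer', a raw score y on a task
# is mapped to (y - random_score) / (human_gamer_score - random_score), so 0 corresponds to
# random play and 1 to human-gamer level; tasks missing either baseline are left
# unnormalized (a warning is printed).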
297 | if lo or hi: 298 | mins = collections.defaultdict(list) 299 | maxs = collections.defaultdict(list) 300 | [mins[r.task].append(r.ys) for r in load_baselines([re.compile(lo)])] 301 | [maxs[r.task].append(r.ys) for r in load_baselines([re.compile(hi)])] 302 | mins = {task: min(ys) for task, ys in mins.items() if task in tasks} 303 | maxs = {task: max(ys) for task, ys in maxs.items() if task in tasks} 304 | missing_baselines = [] 305 | for task in tasks: 306 | if task not in mins or task not in maxs: 307 | missing_baselines.append(task) 308 | if set(missing_baselines) == set(tasks): 309 | print(f'No baselines found to normalize any tasks in {name} plot.') 310 | else: 311 | for task in missing_baselines: 312 | print(f'No baselines found to normalize {task} in {name} plot.') 313 | for index, run in enumerate(runs): 314 | if run.task not in mins or run.task not in maxs: 315 | continue 316 | ys = (run.ys - mins[run.task]) / (maxs[run.task] - mins[run.task]) 317 | if clip: 318 | ys = np.minimum(ys, 1.0) 319 | runs[index] = run._replace(ys=ys) 320 | # Aggregate across tasks but not methods or seeds. 321 | combined = [] 322 | for method, seed in itertools.product(methods, seeds): 323 | relevant = [r for r in runs if r.method == method and r.seed == seed] 324 | if not relevant: 325 | continue 326 | if relevant[0].xs is None: 327 | xs, ys = None, np.array([r.ys for r in relevant]) 328 | else: 329 | xs, ys = stack_scores(*zip(*[(r.xs, r.ys) for r in relevant])) 330 | with warnings.catch_warnings(): # Ignore empty slice warnings. 331 | warnings.simplefilter('ignore', category=RuntimeWarning) 332 | combined.append(Run('combined', method, seed, xs, agg(ys))) 333 | plot(name, ax, combined, methods, args) 334 | 335 | 336 | def curve_lines(index, task, method, ax, runs, args): 337 | zorder = 10000 - 10 * index - 1 338 | for run in runs: 339 | color = args.colors[method] 340 | ax.plot(run.xs, run.ys, label=method, color=color, zorder=zorder) 341 | xs, ys = stack_scores(*zip(*[(r.xs, r.ys) for r in runs])) 342 | return xs, ys 343 | 344 | 345 | def curve_area(index, task, method, ax, runs, args): 346 | xs, ys = stack_scores(*zip(*[(r.xs, r.ys) for r in runs])) 347 | with warnings.catch_warnings(): # NaN buckets remain NaN. 
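# Clarifying note: ys holds one column per run (seed); the branches below pick the shaded
# band drawn around the aggregate curve: 'std1' shades mean +/- one standard deviation,
# while 'per0' / 'per5' / 'per25' shade the (0, 100), (5, 95) or (25, 75) percentile band
# around the median.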
348 | warnings.simplefilter('ignore', category=RuntimeWarning) 349 | if args.agg == 'std1': 350 | mean, std = np.nanmean(ys, -1), np.nanstd(ys, -1) 351 | lo, mi, hi = mean - std, mean, mean + std 352 | elif args.agg == 'per0': 353 | lo, mi, hi = [np.nanpercentile(ys, k, -1) for k in (0, 50, 100)] 354 | elif args.agg == 'per5': 355 | lo, mi, hi = [np.nanpercentile(ys, k, -1) for k in (5, 50, 95)] 356 | elif args.agg == 'per25': 357 | lo, mi, hi = [np.nanpercentile(ys, k, -1) for k in (25, 50, 75)] 358 | else: 359 | raise NotImplementedError(args.agg) 360 | color = args.colors[method] 361 | kw = dict(color=color, zorder=1000 - 10 * index, alpha=0.1, linewidths=0) 362 | mask = ~np.isnan(mi) 363 | xs, lo, mi, hi = xs[mask], lo[mask], mi[mask], hi[mask] 364 | ax.fill_between(xs, lo, hi, **kw) 365 | ax.plot(xs, mi, label=method, color=color, zorder=10000 - 10 * index - 1) 366 | return xs, mi 367 | 368 | 369 | def baseline(index, method, ax, runs, args): 370 | assert all(run.xs is None for run in runs) 371 | ys = np.array([run.ys for run in runs]) 372 | mean, std = ys.mean(), ys.std() 373 | color = args.colors[method] 374 | kw = dict(color=color, zorder=500 - 20 * index - 1, alpha=0.1, linewidths=0) 375 | ax.fill_between([-np.inf, np.inf], [mean - std] * 2, [mean + std] * 2, **kw) 376 | kw = dict(ls='--', color=color, zorder=5000 - 10 * index - 1) 377 | ax.axhline(mean, label=method, **kw) 378 | 379 | 380 | def legend(fig, mapping=None, **kwargs): 381 | entries = {} 382 | for ax in fig.axes: 383 | for handle, label in zip(*ax.get_legend_handles_labels()): 384 | if mapping and label in mapping: 385 | label = mapping[label] 386 | entries[label] = handle 387 | leg = fig.legend(entries.values(), entries.keys(), **kwargs) 388 | leg.get_frame().set_edgecolor('white') 389 | extent = leg.get_window_extent(fig.canvas.get_renderer()) 390 | extent = extent.transformed(fig.transFigure.inverted()) 391 | yloc, xloc = kwargs['loc'].split() 392 | y0 = dict(lower=extent.y1, center=0, upper=0)[yloc] 393 | y1 = dict(lower=1, center=1, upper=extent.y0)[yloc] 394 | x0 = dict(left=extent.x1, center=0, right=0)[xloc] 395 | x1 = dict(left=1, center=1, right=extent.x0)[xloc] 396 | fig.tight_layout(rect=[x0, y0, x1, y1], h_pad=0.5, w_pad=0.5) 397 | 398 | 399 | def save(fig, args): 400 | args.outdir.mkdir(parents=True, exist_ok=True) 401 | filename = args.outdir / 'curves.png' 402 | fig.savefig(filename, dpi=args.dpi) 403 | print('Saved to', filename) 404 | filename = args.outdir / 'curves.pdf' 405 | fig.savefig(filename) 406 | try: 407 | subprocess.call(['pdfcrop', str(filename), str(filename)]) 408 | except FileNotFoundError: 409 | print('Install texlive-extra-utils to crop PDF outputs.') 410 | 411 | 412 | def bin_scores(xs, ys, borders, reducer=np.nanmean, fill='nan'): 413 | order = np.argsort(xs) 414 | xs, ys = xs[order], ys[order] 415 | binned = [] 416 | with warnings.catch_warnings(): # Empty buckets become NaN. 
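# Worked example (sketch): with borders = [0, 1e5, 2e5, ...], every (x, y) point whose x
# lies in (start, stop] is reduced (np.nanmean by default) into one bucket whose new x is
# the bucket's right border; empty buckets become NaN, or repeat the previous bucket's
# value when fill='last'.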
417 | warnings.simplefilter('ignore', category=RuntimeWarning) 418 | for start, stop in zip(borders[:-1], borders[1:]): 419 | left = (xs <= start).sum() 420 | right = (xs <= stop).sum() 421 | if left < right: 422 | value = reducer(ys[left:right]) 423 | elif binned: 424 | value = {'nan': np.nan, 'last': binned[-1]}[fill] 425 | else: 426 | value = np.nan 427 | binned.append(value) 428 | return borders[1:], np.array(binned) 429 | 430 | 431 | def stack_scores(multiple_xs, multiple_ys, fill='last'): 432 | longest_xs = sorted(multiple_xs, key=lambda x: len(x))[-1] 433 | multiple_padded_ys = [] 434 | for xs, ys in zip(multiple_xs, multiple_ys): 435 | assert (longest_xs[:len(xs)] == xs).all(), (list(xs), list(longest_xs)) 436 | value = {'nan': np.nan, 'last': ys[-1]}[fill] 437 | padding = [value] * (len(longest_xs) - len(xs)) 438 | padded_ys = np.concatenate([ys, padding]) 439 | multiple_padded_ys.append(padded_ys) 440 | stacked_ys = np.stack(multiple_padded_ys, -1) 441 | return longest_xs, stacked_ys 442 | 443 | 444 | def load_jsonl(filename): 445 | try: 446 | with filename.open() as f: 447 | lines = list(f.readlines()) 448 | records = [] 449 | for index, line in enumerate(lines): 450 | try: 451 | records.append(json.loads(line)) 452 | except Exception: 453 | if index == len(lines) - 1: 454 | continue # Silently skip last line if it is incomplete. 455 | raise ValueError( 456 | f'Skipping invalid JSON line ({index+1}/{len(lines)+1}) in' 457 | f'{filename}: {line}') 458 | return pd.DataFrame(records) 459 | except ValueError as e: 460 | print('Invalid', filename, e) 461 | return None 462 | 463 | 464 | def save_runs(runs, filename): 465 | filename.parent.mkdir(parents=True, exist_ok=True) 466 | records = [] 467 | for run in runs: 468 | if run.xs is None: 469 | continue 470 | records.append(dict( 471 | task=run.task, method=run.method, seed=run.seed, 472 | xs=run.xs.tolist(), ys=run.ys.tolist())) 473 | runs = json.dumps(records) 474 | filename.write_text(runs) 475 | print('Saved', filename) 476 | 477 | 478 | def main(args): 479 | find_keys(args) 480 | runs = load_runs(args) 481 | save_runs(runs, args.outdir / 'runs.json') 482 | baselines = load_baselines(args.baselines, args.prefix) 483 | stats(runs, baselines) 484 | methods = order_methods(runs, baselines, args) 485 | if not runs: 486 | print('Noting to plot.') 487 | return 488 | # Adjust options based on loaded runs. 
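# Note: the 'auto' entry of --add is resolved here based on the loaded runs: it expands to
# the gamer/record normalized panels for Atari runs, to plain 'mean' and 'median'
# otherwise, and is dropped entirely when only a single task was loaded.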
489 | tasks = set(r.task for r in runs) 490 | if 'auto' in args.add: 491 | index = args.add.index('auto') 492 | del args.add[index] 493 | atari = any(run.task.startswith('atari_') for run in runs) 494 | if len(tasks) < 2: 495 | pass 496 | elif atari: 497 | args.add[index:index] = [ 498 | 'gamer_median', 'gamer_mean', 'record_mean', 'clip_record_mean', 499 | ] 500 | else: 501 | args.add[index:index] = ['mean', 'median'] 502 | args.cols = min(args.cols, len(tasks) + len(args.add)) 503 | args.legendcols = min(args.legendcols, args.cols) 504 | print('Plotting...') 505 | fig = figure(runs + baselines, methods, args) 506 | save(fig, args) 507 | 508 | 509 | def parse_args(): 510 | boolean = lambda x: bool(['False', 'True'].index(x)) 511 | parser = argparse.ArgumentParser() 512 | parser.add_argument('--indir', nargs='+', type=pathlib.Path, required=True) 513 | parser.add_argument('--indir-prefix', type=pathlib.Path) 514 | parser.add_argument('--outdir', type=pathlib.Path, required=True) 515 | parser.add_argument('--subdir', type=boolean, default=True) 516 | parser.add_argument('--xaxis', type=str, default='step') 517 | parser.add_argument('--yaxis', type=str, default='eval_return') 518 | parser.add_argument('--tasks', nargs='+', default=[r'.*']) 519 | parser.add_argument('--methods', nargs='+', default=[r'.*']) 520 | parser.add_argument('--baselines', nargs='+', default=DEFAULT_BASELINES) 521 | parser.add_argument('--prefix', type=boolean, default=False) 522 | parser.add_argument('--bins', type=float, default=-1) 523 | parser.add_argument('--agg', type=str, default='std1') 524 | parser.add_argument('--size', nargs=2, type=float, default=[2.5, 2.3]) 525 | parser.add_argument('--dpi', type=int, default=80) 526 | parser.add_argument('--cols', type=int, default=6) 527 | parser.add_argument('--xlim', nargs=2, type=float, default=None) 528 | parser.add_argument('--ylim', nargs=2, type=float, default=None) 529 | parser.add_argument('--ylimticks', type=boolean, default=True) 530 | parser.add_argument('--xlabel', type=str, default=None) 531 | parser.add_argument('--ylabel', type=str, default=None) 532 | parser.add_argument('--xticks', type=int, default=6) 533 | parser.add_argument('--yticks', type=int, default=5) 534 | parser.add_argument('--xmult', type=float, default=1) 535 | parser.add_argument('--labels', nargs='+', default=None) 536 | parser.add_argument('--palette', nargs='+', default=['contrast']) 537 | parser.add_argument('--legendcols', type=int, default=4) 538 | parser.add_argument('--colors', nargs='+', default={}) 539 | parser.add_argument('--maxval', type=float, default=0) 540 | parser.add_argument('--add', nargs='+', type=str, default=['auto', 'seeds']) 541 | args = parser.parse_args() 542 | if args.subdir: 543 | args.outdir /= args.indir[0].stem 544 | if args.indir_prefix: 545 | args.indir = [args.indir_prefix / indir for indir in args.indir] 546 | args.indir = [d.expanduser() for d in args.indir] 547 | args.outdir = args.outdir.expanduser() 548 | if args.labels: 549 | assert len(args.labels) % 2 == 0 550 | args.labels = {k: v for k, v in zip(args.labels[:-1], args.labels[1:])} 551 | if args.colors: 552 | assert len(args.colors) % 2 == 0 553 | args.colors = {k: v for k, v in zip(args.colors[:-1], args.colors[1:])} 554 | args.tasks = [re.compile(p) for p in args.tasks] 555 | args.methods = [re.compile(p) for p in args.methods] 556 | args.baselines = [re.compile(p) for p in args.baselines] 557 | if 'return' not in args.yaxis: 558 | args.baselines = [] 559 | if args.prefix is None: 560 | 
args.prefix = len(args.indir) > 1 561 | if len(args.palette) == 1 and args.palette[0] in PALETTES: 562 | args.palette = 10 * PALETTES[args.palette[0]] 563 | if len(args.add) == 1 and args.add[0] == 'none': 564 | args.add = [] 565 | return args 566 | 567 | 568 | if __name__ == '__main__': 569 | main(parse_args()) 570 | -------------------------------------------------------------------------------- /dreamerv2/common/ram_annotations.py: -------------------------------------------------------------------------------- 1 | """In gym, the RAM is represented as an 128-element array, where each element in the array can range from 0 to 255 2 | The atari_dict below is organized as so: 3 | key: the name of the game 4 | value: the game dictionary 5 | Game dictionary is organized as: 6 | key: state variable name 7 | value: the element in the RAM array where the value of that state variable is stored 8 | e.g. the value of the x coordinate of the player in asteroids is stored in the 73rd (counting up from 0) 9 | element of the RAM array (when the player in asteroids moves horizontally, ram_array[73] should change 10 | in value correspondingly) 11 | """ 12 | """ MZR player_direction values: 13 | 72: facing left, 14 | 40: facing left, climbing down ladder/rope 15 | 24: facing left, climbing up ladder/rope 16 | 128: facing right 17 | 32: facing right, climbing down ladder/rope 18 | 16: facing right climbing up ladder/rope """ 19 | 20 | atari_dict = { 21 | "asteroids": dict(enemy_asteroids_y=[3, 4, 5, 6, 7, 8, 9, 12, 13, 14, 15, 16, 17, 18, 19], 22 | enemy_asteroids_x=[21, 22, 23, 24, 25, 26, 27, 30, 31, 32, 33, 34, 35, 36, 37], 23 | player_x=73, 24 | player_y=74, 25 | num_lives_direction=60, 26 | player_score_high=61, 27 | player_score_low=62, 28 | player_missile_x1=83, 29 | player_missile_x2=84, 30 | player_missile_y1=86, 31 | player_missile_y2=87, 32 | player_missile1_direction=89, 33 | player_missile2_direction=90), 34 | 35 | "battlezone": dict( # red_enemy_x=75, 36 | blue_tank_facing_direction=46, # 17 left 21 forward 29 right 37 | blue_tank_size_y=47, # tank gets larger as it gets closer 38 | blue_tank_x=48, 39 | blue_tank2_facing_direction=52, 40 | blue_tank2_size_y=53, 41 | blue_tank2_x=54, 42 | num_lives=58, 43 | missile_y=105, 44 | compass_needles_angle=84, 45 | angle_of_tank=4, # as shown by what the mountains look like 46 | left_tread_position=59, # got to mod this number by 8 to get unique values 47 | right_tread_position=60, # got to mod this number by 8 to get unique values 48 | crosshairs_color=108, # 0 if black 46 if yellow 49 | score=29), 50 | 51 | "berzerk": dict(player_x=19, 52 | player_y=11, 53 | player_direction=14, 54 | player_missile_x=22, 55 | player_missile_y=23, 56 | player_missile_direction=21, 57 | robot_missile_direction=26, 58 | robot_missile_x=29, 59 | robot_missile_y=30, 60 | num_lives=90, 61 | robots_killed_count=91, 62 | game_level=92, 63 | enemy_evilOtto_x=46, 64 | enemy_evilOtto_y=89, 65 | enemy_robots_x=range(65, 73), 66 | enemy_robots_y=range(56, 65), 67 | player_score=range(93, 96)), 68 | 69 | "bowling": dict(ball_x=30, 70 | ball_y=41, 71 | player_x=29, 72 | player_y=40, 73 | frame_number_display=36, 74 | pin_existence=range(57, 67), 75 | score=33), 76 | 77 | "boxing": dict(player_x=32, 78 | player_y=34, 79 | enemy_x=33, 80 | enemy_y=35, 81 | enemy_score=19, 82 | clock=17, 83 | player_score=18), 84 | 85 | "breakout": dict(ball_x=99, 86 | ball_y=101, 87 | player_x=72, 88 | blocks_hit_count=77, 89 | block_bit_map=range(30), # see breakout bitmaps tab 90 | score=84), # 5 
for each hit 91 | 92 | "demonattack": dict(level=62, 93 | player_x=22, 94 | enemy_x1=17, 95 | enemy_x2=18, 96 | enemy_x3=19, 97 | missile_y=21, 98 | enemy_y1=69, 99 | enemy_y2=70, 100 | enemy_y3=71, 101 | num_lives=114), 102 | 103 | "freeway": dict(player_y=14, 104 | score=103, 105 | enemy_car_x=range(108, 118)), # which lane the car collided with player 106 | 107 | "frostbite": dict( 108 | top_row_iceflow_x=34, 109 | second_row_iceflow_x=33, 110 | third_row_iceflow_x=32, 111 | fourth_row_iceflow_x=31, 112 | enemy_bear_x=104, 113 | num_lives=76, 114 | igloo_blocks_count=77, # 255 is none and 15 is all " 115 | enemy_x=range(84, 88), # 84 bottom row - 87 top row 116 | player_x=102, 117 | player_y=100, 118 | player_direction=4, 119 | score=[72, 73, 74]), 120 | 121 | "hero": dict(player_x=27, 122 | player_y=31, 123 | power_meter=43, 124 | room_number=28, 125 | level_number=117, 126 | dynamite_count=50, 127 | score=[56, 57]), 128 | 129 | 130 | 131 | "montezumarevenge": dict(room_number=3, 132 | player_x=42, 133 | player_y=43, 134 | player_direction=52, # 72: facing left, 40: facing left, climbing down ladder/rope 24: facing left, climbing up ladder/rope 128: facing right 32: facing right, climbing down ladder/rope, 16: facing right climbing up ladder/rope 135 | enemy_skull_x=47, 136 | enemy_skull_y=46, 137 | key_monster_x=44, 138 | key_monster_y=45, 139 | level=57, 140 | num_lives=58, 141 | items_in_inventory_count=61, 142 | room_state=62, 143 | score_0=19, 144 | score_1=20, 145 | score_2=21), 146 | 147 | "mspacman": dict(enemy_sue_x=6, 148 | enemy_inky_x=7, 149 | enemy_pinky_x=8, 150 | enemy_blinky_x=9, 151 | enemy_sue_y=12, 152 | enemy_inky_y=13, 153 | enemy_pinky_y=14, 154 | enemy_blinky_y=15, 155 | player_x=10, 156 | player_y=16, 157 | fruit_x=11, 158 | fruit_y=17, 159 | ghosts_count=19, 160 | player_direction=56, 161 | dots_eaten_count=119, 162 | player_score=120, 163 | num_lives=123), 164 | 165 | "pitfall": dict(player_x=97, # 8-148 166 | player_y=105, # 21-86 except for when respawning then 0-255 with confusing wraparound 167 | enemy_logs_x=98, # 0-160 168 | enemy_scorpion_x=99, 169 | # player_y_on_ladder= 108, # 0-20 170 | # player_collided_with_rope= 5, #yes if bit 6 is 1 171 | bottom_of_rope_y=18, # 0-20 varies even when you can't see rope 172 | clock_sec=89, 173 | clock_min=88 174 | ), 175 | 176 | "pong": dict(player_y=51, 177 | player_x=46, 178 | enemy_y=50, 179 | enemy_x=45, 180 | ball_x=49, 181 | ball_y=54, 182 | enemy_score=13, 183 | player_score=14), 184 | 185 | "privateeye": dict(player_x=63, 186 | player_y=86, 187 | room_number=92, 188 | clock=[67, 69], 189 | player_direction=58, 190 | score=[73, 74], 191 | dove_x=48, 192 | dove_y=39), 193 | 194 | "qbert": dict(player_x=43, 195 | player_y=67, 196 | player_column=35, 197 | red_enemy_column=69, 198 | green_enemy_column=105, 199 | score=[89, 90, 91], # binary coded decimal score 200 | tile_color=[ 21, # row of 1 201 | 52, 54, # row of 2 202 | 83, 85, 87, # row of 3 203 | 98, 100, 102, 104, # row of 4 204 | 1, 3, 5, 7, 9, # row of 5 205 | 32, 34, 36, 38, 40, 42]), # row of 6 206 | 207 | "riverraid": dict(player_x=51, 208 | missile_x=117, 209 | missile_y=50, 210 | fuel_meter_high=55, # high value displayed 211 | fuel_meter_low=56 # low value 212 | ), 213 | 214 | "seaquest": dict(enemy_obstacle_x=range(30, 34), 215 | player_x=70, 216 | player_y=97, 217 | diver_or_enemy_missile_x=range(71, 75), 218 | player_direction=86, 219 | player_missile_direction=87, 220 | oxygen_meter_value=102, 221 | player_missile_x=103, 222 | score=[57, 58], 
223 | num_lives=59, 224 | divers_collected_count=62), 225 | 226 | "skiing": dict(player_x=25, 227 | clock_m=104, 228 | clock_s=105, 229 | clock_ms=106, 230 | score=107, 231 | object_y=range(87, 94)), # object_y_1 is y position of whatever topmost object on the screen is 232 | 233 | "spaceinvaders": dict(invaders_left_count=17, 234 | player_score=104, 235 | num_lives=73, 236 | player_x=28, 237 | enemies_x=26, 238 | missiles_y=9, 239 | enemies_y=24), 240 | 241 | "tennis": dict(enemy_x=27, 242 | enemy_y=25, 243 | enemy_score=70, 244 | ball_x=16, 245 | ball_y=17, 246 | player_x=26, 247 | player_y=24, 248 | player_score=69), 249 | 250 | "venture": dict(sprite0_y=20, 251 | sprite1_y=21, 252 | sprite2_y=22, 253 | sprite3_y=23, 254 | sprite4_y=24, 255 | sprite5_y=25, 256 | sprite0_x=79, 257 | sprite1_x=80, 258 | sprite2_x=81, 259 | sprite3_x=82, 260 | sprite4_x=83, 261 | sprite5_x=84, 262 | player_x=85, 263 | player_y=26, 264 | current_room=90, # The number of the room the player is currently in 0 to 9_ 265 | num_lives=70, 266 | score_1_2=71, 267 | score_3_4=72), 268 | 269 | "videopinball": dict(ball_x=67, 270 | ball_y=68, 271 | player_left_paddle_y=98, 272 | player_right_paddle_y=102, 273 | score_1=48, 274 | score_2=50), 275 | 276 | "yarsrevenge": dict(player_x=32, 277 | player_y=31, 278 | player_missile_x=38, 279 | player_missile_y=37, 280 | enemy_x=43, 281 | enemy_y=42, 282 | enemy_missile_x=47, 283 | enemy_missile_y=46) 284 | } 285 | 286 | # break up any lists (e.g. dict(clock=[67, 69]) -> dict(clock_0=67, clock_1=69) ) 287 | update_dict = {k: {} for k in atari_dict.keys()} 288 | 289 | remove_dict = {k: [] for k in atari_dict.keys()} 290 | 291 | for game, d in atari_dict.items(): 292 | for k, v in d.items(): 293 | if isinstance(v, range) or isinstance(v, list): 294 | for i, vi in enumerate(v): 295 | update_dict[game]["%s_%i" % (k, i)] = vi 296 | remove_dict[game].append(k) 297 | 298 | for k in atari_dict.keys(): 299 | atari_dict[k].update(update_dict[k]) 300 | for rk in remove_dict[k]: 301 | atari_dict[k].pop(rk) -------------------------------------------------------------------------------- /dreamerv2/common/recorder.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import json 3 | import pathlib 4 | 5 | import imageio 6 | import numpy as np 7 | 8 | 9 | class Recorder: 10 | 11 | def __init__( 12 | self, env, directory, save_stats=True, save_video=True, 13 | save_episode=True, video_size=(512, 512)): 14 | if directory and save_stats: 15 | env = StatsRecorder(env, directory) 16 | if directory and save_video: 17 | env = VideoRecorder(env, directory, video_size) 18 | if directory and save_episode: 19 | env = EpisodeRecorder(env, directory) 20 | if not directory: 21 | env = NoopRecorder(env) 22 | self._env = env 23 | 24 | def __getattr__(self, name): 25 | if name.startswith('__'): 26 | raise AttributeError(name) 27 | return getattr(self._env, name) 28 | 29 | class NoopRecorder: 30 | def __init__(self, env): 31 | self._env = env 32 | 33 | def reset(self): 34 | obs = self._env.reset() 35 | return obs 36 | 37 | def step(self, action, policy_idx=0): 38 | return self._env.step(action) 39 | 40 | def __getattr__(self, name): 41 | if name.startswith('__'): 42 | raise AttributeError(name) 43 | return getattr(self._env, name) 44 | 45 | class StatsRecorder: 46 | 47 | def __init__(self, env, directory): 48 | self._env = env 49 | self._directory = pathlib.Path(directory).expanduser() 50 | self._directory.mkdir(exist_ok=True, parents=True) 51 | 
self._file = (self._directory / 'stats.jsonl').open('a') 52 | self._length = None 53 | self._reward = None 54 | self._unlocked = None 55 | self._stats = None 56 | 57 | def __getattr__(self, name): 58 | if name.startswith('__'): 59 | raise AttributeError(name) 60 | return getattr(self._env, name) 61 | 62 | def reset(self): 63 | obs = self._env.reset() 64 | self._length = 0 65 | self._reward = 0 66 | self._unlocked = None 67 | self._stats = None 68 | return obs 69 | 70 | def step(self, action, policy_idx=0): 71 | obs, reward, done, info = self._env.step(action) 72 | self._length += 1 73 | self._reward += info['reward'] 74 | if done: 75 | self._stats = {'length': self._length, 'reward': round(self._reward, 1), 'policy_idx': policy_idx} 76 | for key, value in info['achievements'].items(): 77 | self._stats[f'achievement_{key}'] = value 78 | self._save() 79 | return obs, reward, done, info 80 | 81 | def _save(self): 82 | self._file.write(json.dumps(self._stats) + '\n') 83 | self._file.flush() 84 | 85 | 86 | class VideoRecorder: 87 | 88 | def __init__(self, env, directory, size=(512, 512)): 89 | if not hasattr(env, 'episode_name'): 90 | env = EpisodeName(env) 91 | self._env = env 92 | self._directory = pathlib.Path(directory).expanduser() 93 | self._directory.mkdir(exist_ok=True, parents=True) 94 | self._size = size 95 | self._frames = None 96 | 97 | def __getattr__(self, name): 98 | if name.startswith('__'): 99 | raise AttributeError(name) 100 | return getattr(self._env, name) 101 | 102 | def reset(self): 103 | obs = self._env.reset() 104 | self._frames = [self._env.render(self._size)] 105 | return obs 106 | 107 | def step(self, action): 108 | obs, reward, done, info = self._env.step(action) 109 | self._frames.append(self._env.render(self._size)) 110 | if done: 111 | self._save() 112 | return obs, reward, done, info 113 | 114 | def _save(self): 115 | filename = str(self._directory / (self._env.episode_name + '.mp4')) 116 | imageio.mimsave(filename, self._frames) 117 | 118 | 119 | class EpisodeRecorder: 120 | 121 | def __init__(self, env, directory): 122 | if not hasattr(env, 'episode_name'): 123 | env = EpisodeName(env) 124 | self._env = env 125 | self._directory = pathlib.Path(directory).expanduser() 126 | self._directory.mkdir(exist_ok=True, parents=True) 127 | self._episode = None 128 | 129 | def __getattr__(self, name): 130 | if name.startswith('__'): 131 | raise AttributeError(name) 132 | return getattr(self._env, name) 133 | 134 | def reset(self): 135 | obs = self._env.reset() 136 | self._episode = [{'image': obs}] 137 | return obs 138 | 139 | def step(self, action): 140 | # Transitions are defined from the environment perspective, meaning that a 141 | # transition contains the action and the resulting reward and next 142 | # observation produced by the environment in response to said action. 
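# Illustration (sketch): after reset() the buffer holds [{'image': o_0}], and each step
# appends {'action': a_t, 'image': o_{t+1}, 'reward': r_{t+1}, 'done': d_{t+1}, ...} plus
# any extra info, achievement and inventory keys; when the episode ends, _save() zero-fills
# keys missing from the first entry and writes one compressed .npz with an array per key.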
143 | obs, reward, done, info = self._env.step(action) 144 | transition = { 145 | 'action': action, 'image': obs, 'reward': reward, 'done': done, 146 | } 147 | for key, value in info.items(): 148 | if key in ('inventory', 'achievements'): 149 | continue 150 | transition[key] = value 151 | for key, value in info['achievements'].items(): 152 | transition[f'achievement_{key}'] = value 153 | for key, value in info['inventory'].items(): 154 | transition[f'ainventory_{key}'] = value 155 | self._episode.append(transition) 156 | if done: 157 | self._save() 158 | return obs, reward, done, info 159 | 160 | def _save(self): 161 | filename = str(self._directory / (self._env.episode_name + '.npz')) 162 | # Fill in zeros for keys missing at the first time step. 163 | for key, value in self._episode[1].items(): 164 | if key not in self._episode[0]: 165 | self._episode[0][key] = np.zeros_like(value) 166 | episode = { 167 | k: np.array([step[k] for step in self._episode]) 168 | for k in self._episode[0]} 169 | np.savez_compressed(filename, **episode) 170 | 171 | 172 | class EpisodeName: 173 | 174 | def __init__(self, env): 175 | self._env = env 176 | self._timestamp = None 177 | self._unlocked = None 178 | self._length = None 179 | 180 | def __getattr__(self, name): 181 | if name.startswith('__'): 182 | raise AttributeError(name) 183 | return getattr(self._env, name) 184 | 185 | def reset(self): 186 | obs = self._env.reset() 187 | self._timestamp = None 188 | self._unlocked = None 189 | self._length = 0 190 | return obs 191 | 192 | def step(self, action): 193 | obs, reward, done, info = self._env.step(action) 194 | self._length += 1 195 | if done: 196 | self._timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S') 197 | self._unlocked = sum(int(v >= 1) for v in info['achievements'].values()) 198 | return obs, reward, done, info 199 | 200 | @property 201 | def episode_name(self): 202 | return f'{self._timestamp}-ach{self._unlocked}-len{self._length}' -------------------------------------------------------------------------------- /dreamerv2/common/replay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
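# Module overview (approximate): completed episodes are stored as compressed .npz files
# named <timestamp>-<uuid>-<length>.npz. Replay.add_step() buffers transitions per worker
# until a step marked 'is_last' arrives, at which point the episode is saved and counted
# against the capacity limit (oldest episodes are evicted first). Replay.dataset(batch,
# length) wraps _generate_chunks() in a tf.data pipeline that yields fixed-length training
# chunks stitched from randomly sampled episode slices.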
6 | 7 | import collections 8 | import datetime 9 | import io 10 | import pathlib 11 | import uuid 12 | 13 | import numpy as np 14 | import tensorflow as tf 15 | 16 | 17 | class Replay: 18 | 19 | def __init__( 20 | self, directory, capacity=0, offline_init=False, ongoing=False, minlen=1, maxlen=0, 21 | prioritize_ends=False, multi_reward=False, offline_directory=None): 22 | 23 | self._capacity = capacity 24 | self._ongoing = ongoing 25 | self._minlen = minlen 26 | self._maxlen = maxlen 27 | self._prioritize_ends = prioritize_ends 28 | self._random = np.random.RandomState() 29 | self._eval_score = 0 30 | self.achievements = collections.defaultdict(list) 31 | self._solved_levels = 0 32 | self._multi_reward = multi_reward 33 | self._max_scores = 0 34 | self.rewards = [] 35 | self._mean_scores = 0 36 | 37 | self._directory = pathlib.Path(directory).expanduser() 38 | self._directory.mkdir(parents=True, exist_ok=True) 39 | 40 | if offline_init: 41 | self._total_episodes = 0 42 | self._total_steps = 0 43 | self._loaded_episodes = 0 44 | self._loaded_steps = 0 45 | self._complete_eps = {} 46 | 47 | if type(offline_directory) is not list: 48 | offline_directory = [offline_directory] 49 | 50 | for d in offline_directory: 51 | print(f"\nloading...{d}") 52 | path = pathlib.Path(d).expanduser() 53 | complete_eps, t_steps, t_eps = self.load_episodes(path, capacity, minlen) 54 | saved_eps = save_episodes(self._directory, complete_eps) 55 | self._complete_eps.update(saved_eps) 56 | self._enforce_limit() 57 | self._loaded_episodes += len(complete_eps) 58 | self._loaded_steps += sum(eplen(x) for x in complete_eps.values()) 59 | # filename -> key -> value_sequence 60 | self._complete_eps, _, _ = self.load_episodes(self._directory, capacity, minlen) 61 | # worker -> key -> value_sequence 62 | self._total_episodes, self._total_steps = count_episodes(directory) 63 | self._loaded_episodes = len(self._complete_eps) 64 | self._loaded_steps = sum(eplen(x) for x in self._complete_eps.values()) 65 | 66 | self._ongoing_eps = collections.defaultdict(lambda: collections.defaultdict(list)) 67 | 68 | @property 69 | def stats(self): 70 | return { 71 | 'total_steps': self._total_steps, 72 | 'total_episodes': self._total_episodes, 73 | 'loaded_steps': self._loaded_steps, 74 | 'loaded_episodes': self._loaded_episodes, 75 | 'running_score': self._eval_score, 76 | 'solved_levels': self._solved_levels, 77 | 'max_scores': self._max_scores, 78 | 'mean_scores': self._mean_scores 79 | } 80 | 81 | def add_step(self, transition, worker=0): 82 | episode = self._ongoing_eps[worker] 83 | for key, value in transition.items(): 84 | episode[key].append(value) 85 | if transition['is_last']: 86 | self.add_episode(episode) 87 | episode.clear() 88 | 89 | def add_episode(self, episode): 90 | length = eplen(episode) 91 | if 'log_achievement_collect_diamond' in episode.keys(): 92 | self.update_crafter_score(episode) 93 | if self._multi_reward: 94 | pass # in case we need to do something here 95 | elif 'reward' in episode.keys() and sum(episode['reward']) > 0: 96 | rew = sum(episode['reward']) 97 | self._solved_levels += 1 98 | self._max_scores = max(self._max_scores, rew) 99 | self.rewards.append(rew) 100 | self._mean_scores = np.mean(self.rewards) 101 | if length < self._minlen: 102 | print(f'Skipping short episode of length {length}.') 103 | return 104 | self._total_steps += length 105 | self._loaded_steps += length 106 | self._total_episodes += 1 107 | self._loaded_episodes += 1 108 | episode = {key: convert(value) for key, value in 
episode.items()} 109 | if self._multi_reward: 110 | episode['reward'] = reshape_rewards_dmc(episode) 111 | filename = save_episode(self._directory, episode) 112 | self._complete_eps[str(filename)] = episode 113 | self._enforce_limit() 114 | 115 | def dataset(self, batch, length): 116 | example = next(iter(self._generate_chunks(length))) 117 | dataset = tf.data.Dataset.from_generator( 118 | lambda: self._generate_chunks(length), 119 | {k: v.dtype for k, v in example.items()}, 120 | {k: v.shape for k, v in example.items()}) 121 | dataset = dataset.batch(batch, drop_remainder=True) 122 | dataset = dataset.prefetch(5) 123 | return dataset 124 | 125 | def _generate_chunks(self, length): 126 | sequence = self._sample_sequence() 127 | while True: 128 | chunk = collections.defaultdict(list) 129 | added = 0 130 | while added < length: 131 | needed = length - added 132 | adding = {k: v[:needed] for k, v in sequence.items()} 133 | sequence = {k: v[needed:] for k, v in sequence.items()} 134 | for key, value in adding.items(): 135 | chunk[key].append(value) 136 | added += len(adding['action']) 137 | if len(sequence['action']) < 1: 138 | sequence = self._sample_sequence() 139 | chunk = {k: np.concatenate(v) for k, v in chunk.items()} 140 | yield chunk 141 | 142 | def _sample_sequence(self): 143 | episodes = list(self._complete_eps.values()) 144 | if self._ongoing: 145 | episodes += [ 146 | x for x in self._ongoing_eps.values() 147 | if eplen(x) >= self._minlen] 148 | episode = self._random.choice(episodes) 149 | total = len(episode['action']) 150 | length = total 151 | if self._maxlen: 152 | length = min(length, self._maxlen) 153 | # Randomize length to avoid all chunks ending at the same time in case the 154 | # episodes are all of the same length. 155 | length -= np.random.randint(self._minlen) 156 | length = max(self._minlen, length) 157 | upper = total - length + 1 158 | if self._prioritize_ends: 159 | upper += self._minlen 160 | index = min(self._random.randint(upper), total - length) 161 | sequence = { 162 | k: convert(v[index: index + length]) 163 | for k, v in episode.items() if not k.startswith('log_')} 164 | sequence['is_first'] = np.zeros(len(sequence['action']), np.bool) 165 | sequence['is_first'][0] = True 166 | if self._maxlen: 167 | assert self._minlen <= len(sequence['action']) <= self._maxlen 168 | return sequence 169 | 170 | def _enforce_limit(self): 171 | if not self._capacity: 172 | return 173 | while self._loaded_episodes > 1 and self._loaded_steps > self._capacity: 174 | # Relying on Python preserving the insertion order of dicts. 175 | oldest, episode = next(iter(self._complete_eps.items())) 176 | self._loaded_steps -= eplen(episode) 177 | self._loaded_episodes -= 1 178 | del self._complete_eps[oldest] 179 | 180 | def update_crafter_score(self, episode): 181 | for key, val in episode.items(): 182 | if 'log_achievement' in key: 183 | self.achievements[key] += [int(any([x.item() for x in episode[key]]))] 184 | 185 | means = [np.mean(vals)*100 for vals in self.achievements.values()] 186 | self._eval_score = (np.exp(np.nanmean(np.log(1 + np.array(means)), -1)) - 1) 187 | 188 | def load_episodes(self, directory, capacity=None, minlen=1): 189 | # The returned directory from filenames to episodes is guaranteed to be in 190 | # temporally sorted order. 
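# Note: each episode's length is recovered from the trailing '-<length>.npz' part of its
# filename, so when a capacity is given only the most recent files whose combined length
# fits within that capacity are loaded back.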
191 | filenames = sorted(directory.glob('*.npz')) 192 | if capacity: 193 | num_steps = 0 194 | num_episodes = 0 195 | for filename in reversed(filenames): 196 | length = int(str(filename).split('-')[-1][:-4]) 197 | num_steps += length 198 | num_episodes += 1 199 | if num_steps >= capacity: 200 | break 201 | filenames = filenames[-num_episodes:] 202 | episodes = {} 203 | num_steps = 0 204 | num_episodes = 0 205 | for filename in filenames: 206 | try: 207 | with filename.open('rb') as f: 208 | episode = np.load(f) 209 | episode = {k: episode[k] for k in episode.keys()} 210 | for key, val in episode.items(): 211 | if 'log_achievement' in key: 212 | self.achievements[key] += [int(any([x.item() for x in episode[key]]))] 213 | if not self._multi_reward: 214 | if 'reward' in episode.keys() and sum(episode['reward']) > 0: 215 | rew = sum(episode['reward']) 216 | self._solved_levels += 1 217 | self._max_scores = max(self._max_scores, rew) 218 | self.rewards.append(rew) 219 | self._mean_scores = np.mean(self.rewards) 220 | num_steps += 1 221 | num_episodes += 1 222 | except Exception as e: 223 | print(f'Could not load episode {str(filename)}: {e}') 224 | continue 225 | if 'is_terminal' not in episode: 226 | episode['is_terminal'] = episode['discount'] == 0 227 | episodes[str(filename)] = episode 228 | return episodes, num_steps, num_episodes 229 | 230 | def count_episodes(directory): 231 | filenames = list(directory.glob('*.npz')) 232 | num_episodes = len(filenames) 233 | num_steps = sum(int(str(n).split('-')[-1][:-4]) - 1 for n in filenames) 234 | return num_episodes, num_steps 235 | 236 | 237 | def save_episode(directory, episode): 238 | timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S') 239 | identifier = str(uuid.uuid4().hex) 240 | length = eplen(episode) 241 | filename = directory / f'{timestamp}-{identifier}-{length}.npz' 242 | with io.BytesIO() as f1: 243 | np.savez_compressed(f1, **episode) 244 | f1.seek(0) 245 | with filename.open('wb') as f2: 246 | f2.write(f1.read()) 247 | return filename 248 | 249 | def save_episodes(directory, episodes): 250 | saved_eps = {} 251 | for _, ep in episodes.items(): 252 | filename = save_episode(directory, ep) 253 | saved_eps[str(filename)] = ep 254 | return saved_eps 255 | 256 | 257 | def convert(value): 258 | value = np.array(value) 259 | if np.issubdtype(value.dtype, np.floating): 260 | return value.astype(np.float32) 261 | elif np.issubdtype(value.dtype, np.signedinteger): 262 | return value.astype(np.int32) 263 | elif np.issubdtype(value.dtype, np.uint8): 264 | return value.astype(np.uint8) 265 | return value 266 | 267 | 268 | def reshape_rewards_dmc(episode): 269 | rew = np.concatenate([r.reshape(1, -1) for r in episode['reward'][1:]], 0) 270 | rew = np.concatenate((np.zeros(rew.shape[1]).reshape(1, rew.shape[1]), rew)) 271 | return rew 272 | 273 | def eplen(episode): 274 | return len(episode['action']) - 1 275 | 276 | -------------------------------------------------------------------------------- /dreamerv2/common/tfutils.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import pickle 3 | import re 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | from tensorflow.keras import mixed_precision as prec 8 | 9 | try: 10 | from tensorflow.python.distribute import values 11 | except Exception: 12 | from google3.third_party.tensorflow.python.distribute import values 13 | 14 | tf.tensor = tf.convert_to_tensor 15 | for base in (tf.Tensor, tf.Variable, values.PerReplica): 16 | 
base.mean = tf.math.reduce_mean 17 | base.std = tf.math.reduce_std 18 | base.var = tf.math.reduce_variance 19 | base.sum = tf.math.reduce_sum 20 | base.any = tf.math.reduce_any 21 | base.all = tf.math.reduce_all 22 | base.min = tf.math.reduce_min 23 | base.max = tf.math.reduce_max 24 | base.abs = tf.math.abs 25 | base.logsumexp = tf.math.reduce_logsumexp 26 | base.transpose = tf.transpose 27 | base.reshape = tf.reshape 28 | base.astype = tf.cast 29 | 30 | 31 | # values.PerReplica.dtype = property(lambda self: self.values[0].dtype) 32 | 33 | # tf.TensorHandle.__repr__ = lambda x: '' 34 | # tf.TensorHandle.__str__ = lambda x: '' 35 | # np.set_printoptions(threshold=5, edgeitems=0) 36 | 37 | 38 | class Module(tf.Module): 39 | 40 | def save(self, filename): 41 | values = tf.nest.map_structure(lambda x: x.numpy(), self.variables) 42 | amount = len(tf.nest.flatten(values)) 43 | count = int(sum(np.prod(x.shape) for x in tf.nest.flatten(values))) 44 | print(f'Save checkpoint with {amount} tensors and {count} parameters.') 45 | with pathlib.Path(filename).open('wb') as f: 46 | pickle.dump(values, f) 47 | 48 | def load(self, filename): 49 | with pathlib.Path(filename).open('rb') as f: 50 | values = pickle.load(f) 51 | amount = len(tf.nest.flatten(values)) 52 | count = int(sum(np.prod(x.shape) for x in tf.nest.flatten(values))) 53 | print(f'Load checkpoint with {amount} tensors and {count} parameters.') 54 | amount_agent = len(tf.nest.flatten(self.variables)) 55 | count_agent = int(sum(np.prod(x.shape) for x in tf.nest.flatten(self.variables))) 56 | print(f'Agent checkpoint has {amount_agent} tensors and {count_agent} parameters.') 57 | tf.nest.map_structure(lambda x, y: x.assign(y), self.variables, values) 58 | 59 | def get(self, name, ctor, *args, **kwargs): 60 | # Create or get layer by name to avoid mentioning it in the constructor. 61 | if not hasattr(self, '_modules'): 62 | self._modules = {} 63 | if name not in self._modules: 64 | self._modules[name] = ctor(*args, **kwargs) 65 | return self._modules[name] 66 | 67 | 68 | class Optimizer(tf.Module): 69 | 70 | def __init__( 71 | self, name, lr, eps=1e-4, clip=None, wd=None, 72 | opt='adam', wd_pattern=r'.*'): 73 | assert 0 <= wd < 1 74 | assert not clip or 1 <= clip 75 | self._name = name 76 | self._clip = clip 77 | self._wd = wd 78 | self._wd_pattern = wd_pattern 79 | self._opt = { 80 | 'adam': lambda: tf.optimizers.Adam(lr, epsilon=eps), 81 | 'nadam': lambda: tf.optimizers.Nadam(lr, epsilon=eps), 82 | 'adamax': lambda: tf.optimizers.Adamax(lr, epsilon=eps), 83 | 'sgd': lambda: tf.optimizers.SGD(lr), 84 | 'momentum': lambda: tf.optimizers.SGD(lr, 0.9), 85 | }[opt]() 86 | self._mixed = (prec.global_policy().compute_dtype == tf.float16) 87 | if self._mixed: 88 | self._opt = prec.LossScaleOptimizer(self._opt, dynamic=True) 89 | self._once = True 90 | 91 | @property 92 | def variables(self): 93 | return self._opt.variables() 94 | 95 | def __call__(self, tape, loss, modules): 96 | assert loss.dtype is tf.float32, (self._name, loss.dtype) 97 | assert len(loss.shape) == 0, (self._name, loss.shape) 98 | metrics = {} 99 | 100 | # Find variables. 101 | modules = modules if hasattr(modules, '__len__') else (modules,) 102 | varibs = tf.nest.flatten([module.variables for module in modules]) 103 | count = sum(np.prod(x.shape) for x in varibs) 104 | if self._once: 105 | print(f'Found {count} {self._name} parameters.') 106 | self._once = False 107 | 108 | # Check loss. 
109 | tf.debugging.check_numerics(loss, self._name + '_loss') 110 | metrics[f'{self._name}_loss'] = loss 111 | 112 | # Compute scaled gradient. 113 | if self._mixed: 114 | with tape: 115 | loss = self._opt.get_scaled_loss(loss) 116 | grads = tape.gradient(loss, varibs) 117 | if self._mixed: 118 | grads = self._opt.get_unscaled_gradients(grads) 119 | if self._mixed: 120 | metrics[f'{self._name}_loss_scale'] = self._opt.loss_scale 121 | 122 | # Distributed sync. 123 | context = tf.distribute.get_replica_context() 124 | if context: 125 | grads = context.all_reduce('mean', grads) 126 | 127 | # Gradient clipping. 128 | norm = tf.linalg.global_norm(grads) 129 | if not self._mixed: 130 | tf.debugging.check_numerics(norm, self._name + '_norm') 131 | if self._clip: 132 | grads, _ = tf.clip_by_global_norm(grads, self._clip, norm) 133 | metrics[f'{self._name}_grad_norm'] = norm 134 | 135 | # Weight decay. 136 | if self._wd: 137 | self._apply_weight_decay(varibs) 138 | 139 | # Apply gradients. 140 | self._opt.apply_gradients( 141 | zip(grads, varibs), 142 | experimental_aggregate_gradients=False) 143 | 144 | return metrics 145 | 146 | def _apply_weight_decay(self, varibs): 147 | nontrivial = (self._wd_pattern != r'.*') 148 | if nontrivial: 149 | print('Applied weight decay to variables:') 150 | for var in varibs: 151 | if re.search(self._wd_pattern, self._name + '/' + var.name): 152 | if nontrivial: 153 | print('- ' + self._name + '/' + var.name) 154 | var.assign((1 - self._wd) * var) 155 | -------------------------------------------------------------------------------- /dreamerv2/common/when.py: -------------------------------------------------------------------------------- 1 | class Every: 2 | 3 | def __init__(self, every): 4 | self._every = every 5 | self._last = None 6 | 7 | def __call__(self, step): 8 | step = int(step) 9 | if not self._every: 10 | return False 11 | if self._last is None: 12 | self._last = step 13 | return True 14 | if step >= self._last + self._every: 15 | self._last += self._every 16 | return True 17 | return False 18 | 19 | 20 | class Once: 21 | 22 | def __init__(self): 23 | self._once = True 24 | 25 | def __call__(self): 26 | if self._once: 27 | self._once = False 28 | return True 29 | return False 30 | 31 | 32 | class Until: 33 | 34 | def __init__(self, until): 35 | self._until = until 36 | 37 | def __call__(self, step): 38 | step = int(step) 39 | if not self._until: 40 | return True 41 | return step < self._until 42 | -------------------------------------------------------------------------------- /dreamerv2/configs.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
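# Structure note (approximate; the precise override semantics live in common/config.py):
# the 'defaults' block defines every option once, and the named blocks that follow (atari,
# crafter, dmc_vision, dmc_proprio, debug, minigrid) are partial overrides selected when
# launching training. Dotted keys such as 'model_opt.lr: 2e-4' appear to override a single
# nested value, and pattern keys such as '.*\.norm: layer' appear to override every
# matching nested key at once.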
6 | 7 | defaults: 8 | 9 | # Train Script 10 | logdir: /dev/null 11 | fix_seed: False 12 | seed: 0 13 | task: dmc_walker_walk 14 | method: single_disag 15 | envs: 1 16 | eval_envs: 1 17 | envs_parallel: none 18 | num_agents: 1 19 | render_size: [64, 64] 20 | dmc_camera: -1 21 | atari_grayscale: True 22 | time_limit: 0 23 | action_repeat: 1 24 | steps: 1e8 25 | log_every: 1e4 26 | eval_every: 1e5 27 | eval_eps: 10 28 | eval_type: coincidental 29 | prefill: 10000 30 | pretrain: 1 31 | train_every: 5 32 | train_steps: 1 33 | explorer_train_steps: 2e5 34 | explorer_reinit: False 35 | expl_until: 0 36 | replay: {capacity: 2e6, ongoing: False, minlen: 50, maxlen: 50, prioritize_ends: True} 37 | dataset: {batch: 16, length: 50} 38 | log_keys_video: ['image'] 39 | log_keys_sum: '^$' 40 | log_keys_mean: '^$' 41 | log_keys_max: '^$' 42 | precision: 16 43 | jit: True 44 | wandb_silent: False 45 | wandb_base_url: https://api.wandb.ai 46 | wandb_api_key: none 47 | wandb_entity: divwm 48 | wandb_project: dmc 49 | wandb_group: none 50 | xpid: none 51 | checkpoint: True 52 | load_pretrained: none 53 | replay_dir: none 54 | load_wm: none 55 | skip_wm_train: False 56 | 57 | offline_dir: none 58 | offline_model_train_steps: 25001 59 | offline_model_loaddir: none 60 | offline_lmbd: 5.0 61 | offline_penalty_type: none 62 | offline_model_save_every: 1000000 63 | offline_split_val: False 64 | offline_tune_lmbd: False 65 | offline_lmbd_cons: 1.5 66 | offline_model_dataset: {batch: 16, length: 50} 67 | offline_train_dataset: {batch: 16, length: 50} 68 | task_train_steps: 5001 69 | record_video: False 70 | 71 | # Agent 72 | clip_rewards: tanh 73 | expl_behavior: greedy 74 | expl_noise: 0.0 75 | eval_noise: 0.0 76 | eval_state_mean: False 77 | 78 | # World Model 79 | grad_heads: [decoder, reward, discount] 80 | pred_discount: True 81 | rssm: {ensemble: 1, hidden: 1024, deter: 1024, stoch: 32, discrete: 32, act: elu, norm: none, std_act: sigmoid2, min_std: 0.1} 82 | encoder: {mlp_keys: '.*', cnn_keys: '.*', act: elu, norm: none, cnn_depth: 48, cnn_kernels: [4, 4, 4, 4], mlp_layers: [400, 400, 400, 400]} 83 | decoder: {mlp_keys: '.*', cnn_keys: '.*', act: elu, norm: none, cnn_depth: 48, cnn_kernels: [5, 5, 6, 6], mlp_layers: [400, 400, 400, 400]} 84 | reward_head: {layers: 4, units: 400, act: elu, norm: none, dist: mse} 85 | discount_head: {layers: 4, units: 400, act: elu, norm: none, dist: binary} 86 | loss_scales: {kl: 1.0, reward: 1.0, discount: 1.0, proprio: 1.0} 87 | kl: {free: 0.0, forward: False, balance: 0.8, free_avg: True} 88 | model_opt: {opt: adam, lr: 1e-4, eps: 1e-5, clip: 100, wd: 1e-6} 89 | 90 | # Actor Critic 91 | actor: {layers: 4, units: 400, act: elu, norm: none, dist: auto, min_std: 0.1} 92 | critic: {layers: 4, units: 400, act: elu, norm: none, dist: mse} 93 | actor_opt: {opt: adam, lr: 8e-5, eps: 1e-5, clip: 100, wd: 1e-6} 94 | critic_opt: {opt: adam, lr: 2e-4, eps: 1e-5, clip: 100, wd: 1e-6} 95 | discount: 0.99 96 | discount_lambda: 0.95 97 | imag_horizon: 15 98 | actor_grad: auto 99 | actor_grad_mix: 0.1 100 | actor_ent: 2e-3 101 | slow_target: True 102 | slow_target_update: 100 103 | slow_target_fraction: 1 104 | slow_baseline: True 105 | reward_norm: {momentum: 1.0, scale: 1.0, eps: 1e-8} 106 | 107 | # Exploration 108 | expl_intr_scale: 1.0 109 | expl_extr_scale: 0.0 110 | expl_opt: {opt: adam, lr: 3e-4, eps: 1e-5, clip: 100, wd: 1e-6} 111 | expl_head: {layers: 4, units: 400, act: elu, norm: none, dist: mse} 112 | expl_reward_norm: {momentum: 1.0, scale: 1.0, eps: 1e-8} 113 | disag_target: 
stoch 114 | disag_log: False 115 | disag_models: 10 116 | disag_offset: 1 117 | disag_action_cond: True 118 | expl_model_loss: kl 119 | 120 | # Cascade 121 | cascade_alpha: 0.0 122 | cascade_feat: "deter" 123 | cascade_k: 5 124 | cascade_sample: 10 125 | 126 | atari: 127 | 128 | task: atari_pong 129 | encoder: {mlp_keys: '$^', cnn_keys: 'image'} 130 | decoder: {mlp_keys: '$^', cnn_keys: 'image'} 131 | time_limit: 27000 132 | action_repeat: 4 133 | steps: 5e7 134 | eval_every: 2.5e5 135 | prefill: 50000 136 | train_every: 16 137 | clip_rewards: tanh 138 | rssm: {hidden: 600, deter: 600} 139 | model_opt.lr: 2e-4 140 | actor_opt.lr: 4e-5 141 | critic_opt.lr: 1e-4 142 | actor_ent: 1e-3 143 | discount: 0.999 144 | loss_scales.kl: 0.1 145 | loss_scales.discount: 5.0 146 | 147 | crafter: 148 | 149 | task: crafter_reward 150 | encoder: {mlp_keys: '$^', cnn_keys: 'image'} 151 | decoder: {mlp_keys: '$^', cnn_keys: 'image'} 152 | log_keys_max: '^log_achievement_.*' 153 | log_keys_sum: '^log_reward$' 154 | rssm: {hidden: 1024, deter: 1024} 155 | discount: 0.999 156 | model_opt.lr: 1e-4 157 | actor_opt.lr: 1e-4 158 | critic_opt.lr: 1e-4 159 | actor_ent: 3e-3 160 | .*\.norm: layer 161 | 162 | dmc_vision: 163 | 164 | task: dmc_walker_walk 165 | encoder: {mlp_keys: '$^', cnn_keys: 'image'} 166 | decoder: {mlp_keys: '$^', cnn_keys: 'image'} 167 | action_repeat: 2 168 | imag_horizon: 5 169 | eval_every: 1e4 170 | prefill: 1000 171 | pretrain: 100 172 | clip_rewards: identity 173 | pred_discount: False 174 | replay.prioritize_ends: False 175 | grad_heads: [decoder, reward] 176 | rssm: {hidden: 200, deter: 200} 177 | model_opt.lr: 3e-4 178 | actor_opt.lr: 8e-5 179 | critic_opt.lr: 8e-5 180 | actor_ent: 1e-4 181 | kl.free: 1.0 182 | 183 | dmc_proprio: 184 | 185 | task: dmc_walker_walk 186 | encoder: {mlp_keys: '.*', cnn_keys: '$^'} 187 | decoder: {mlp_keys: '.*', cnn_keys: '$^'} 188 | action_repeat: 2 189 | eval_every: 1e4 190 | prefill: 1000 191 | pretrain: 100 192 | clip_rewards: identity 193 | pred_discount: False 194 | replay.prioritize_ends: False 195 | grad_heads: [decoder, reward] 196 | rssm: {hidden: 200, deter: 200} 197 | model_opt.lr: 3e-4 198 | actor_opt.lr: 8e-5 199 | critic_opt.lr: 8e-5 200 | actor_ent: 1e-4 201 | kl.free: 1.0 202 | 203 | debug: 204 | 205 | jit: False 206 | time_limit: 100 207 | eval_every: 300 208 | log_every: 300 209 | prefill: 100 210 | pretrain: 1 211 | train_steps: 1 212 | replay: {minlen: 10, maxlen: 30} 213 | dataset: {batch: 10, length: 10} 214 | 215 | minigrid: 216 | actor_ent: 3e-3 217 | log_keys_sum: 'reward' 218 | log_keys_mean: 'reward' 219 | log_keys_max: 'reward' 220 | replay: {capacity: 2e6, ongoing: False, minlen: 2, maxlen: 120, prioritize_ends: True} 221 | dataset: {batch: 16, length: 40} 222 | -------------------------------------------------------------------------------- /dreamerv2/expl.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import tensorflow as tf 8 | from tensorflow_probability import distributions as tfd 9 | 10 | import agent 11 | import common 12 | 13 | 14 | class Random(common.Module): 15 | 16 | def __init__(self, config, act_space, wm, tfstep, reward): 17 | self.config = config 18 | self.act_space = act_space 19 | discrete = hasattr(act_space, 'n') 20 | if self.config.actor.dist == 'auto': 21 | self.config = self.config.update({ 22 | 'actor.dist': 'onehot' if discrete else 'trunc_normal'}) 23 | 24 | def actor(self, feat): 25 | shape = feat.shape[:-1] + self.act_space.shape 26 | if self.config.actor.dist == 'onehot': 27 | return common.OneHotDist(tf.zeros(shape)) 28 | else: 29 | dist = tfd.Uniform(-tf.ones(shape), tf.ones(shape)) 30 | return tfd.Independent(dist, 1) 31 | 32 | def train(self, start, context, data): 33 | return None, {} 34 | 35 | 36 | class Plan2Explore(common.Module): 37 | 38 | def __init__(self, config, act_space, wm, tfstep, reward): 39 | self.config = config 40 | self.act_space = act_space 41 | self.tfstep = tfstep 42 | self.reward = reward 43 | self.wm = wm 44 | self._init_actors() 45 | 46 | stoch_size = config.rssm.stoch 47 | if config.rssm.discrete: 48 | stoch_size *= config.rssm.discrete 49 | size = { 50 | 'embed': 32 * config.encoder.cnn_depth, 51 | 'stoch': stoch_size, 52 | 'deter': config.rssm.deter, 53 | 'feat': config.rssm.stoch + config.rssm.deter, 54 | }[self.config.disag_target] 55 | self._networks = [ 56 | common.MLP(size, **config.expl_head) 57 | for _ in range(config.disag_models)] 58 | self.opt = common.Optimizer('expl', **config.expl_opt) 59 | self.extr_rewnorm = common.StreamNorm(**self.config.expl_reward_norm) 60 | 61 | def _init_actors(self): 62 | self.intr_rewnorm = common.StreamNorm(**self.config.expl_reward_norm) 63 | self.ac = [agent.ActorCritic(self.config, self.act_space, self.tfstep) for _ in range(self.config.num_agents)] 64 | if self.config.cascade_alpha > 0: 65 | self.intr_rewnorm_cascade = [common.StreamNorm(**self.config.expl_reward_norm) for _ in range(self.config.num_agents)] 66 | self.actor = [ac.actor for ac in self.ac] 67 | 68 | def train(self, start, context, data): 69 | metrics = {} 70 | stoch = start['stoch'] 71 | if self.config.rssm.discrete: 72 | stoch = tf.reshape( 73 | stoch, stoch.shape[:-2] + (stoch.shape[-2] * stoch.shape[-1])) 74 | target = { 75 | 'embed': context['embed'], 76 | 'stoch': stoch, 77 | 'deter': start['deter'], 78 | 'feat': context['feat'], 79 | }[self.config.disag_target] 80 | inputs = context['feat'] 81 | if self.config.disag_action_cond: 82 | action = tf.cast(data['action'], inputs.dtype) 83 | inputs = tf.concat([inputs, action], -1) 84 | metrics.update(self._train_ensemble(inputs, target)) 85 | gpu = tf.config.list_physical_devices('GPU') 86 | if gpu: 87 | tf.config.experimental.set_memory_growth(gpu[0], True) 88 | print(f"Before: {tf.config.experimental.get_memory_usage('GPU:0')}", flush=True) 89 | self.cascade = [] 90 | reward_func = self._intr_reward_incr 91 | print("training explorers", flush=True) 92 | [metrics.update(ac.train(self.wm, start, data['is_terminal'], reward_func)) for ac in self.ac] 93 | self.cascade = [] 94 | print("finished training explorers", flush=True) 95 | return None, metrics 96 | 97 | def _intr_reward(self, seq, rtn_meta=True): 98 | inputs = seq['feat'] 99 | if self.config.disag_action_cond: 100 | action = tf.cast(seq['action'], inputs.dtype) 101 | inputs = tf.concat([inputs, action], -1) 102 | preds = [head(inputs).mode() for head in self._networks] 103 | disag = 
tf.cast(tf.tensor(preds).std(0).mean(-1), tf.float16)
104 |     if self.config.disag_log:
105 |       disag = tf.math.log(disag)
106 |     reward = self.config.expl_intr_scale * self.intr_rewnorm(disag)[0]
107 |     if self.config.expl_extr_scale:
108 |       reward += self.config.expl_extr_scale * self.extr_rewnorm(
109 |           self.reward(seq))[0]
110 |     if rtn_meta:
111 |       return reward, {'Disagreement': [disag.mean()]}
112 |     else:
113 |       return reward
114 | 
115 |   @tf.function
116 |   def get_dists(self, obs, cascade):
117 |     # For each state in obs, compute the mean distance to its k nearest neighbours among the cascade states.
118 |     out = []
119 |     for idx in range(obs.shape[1]):
120 |       cascade = tf.reshape(cascade, [-1, cascade.shape[-1]])
121 |       ob = tf.reshape(obs[:, idx, :], [obs.shape[0], 1, obs.shape[-1]])
122 |       dists = tf.math.sqrt(tf.einsum('ijk, ijk->ij', cascade - ob, cascade - ob))
123 |       topk_mean = tf.negative(tf.math.top_k(tf.negative(dists), k=self.config.cascade_k)[0])
124 |       out += [tf.reshape(tf.math.reduce_mean(topk_mean, axis=-1), (1, -1))]
125 |     return tf.concat(out, axis=1)
126 | 
127 |   def get_cascade_entropy(self):
128 |     cascade = tf.concat(self.cascade, axis=0)
129 |     cascade = tf.reshape(cascade, [-1, cascade.shape[-1]])
130 |     entropy = tf.math.reduce_variance(cascade, axis=-1).mean()
131 |     self.entropy = entropy
132 |     return entropy
133 | 
134 |   def _intr_reward_incr(self, seq):
135 |     agent_idx = len(self.cascade)
136 |     ## disagreement
137 |     reward, met = self._intr_reward(seq)
138 |     # CASCADE
139 |     if self.config.cascade_alpha > 0:
140 |       ## reward = (1 - \alpha) * disagreement + \alpha * diversity
141 |       if len(self.cascade) == 0:
142 |         idxs = tf.range(tf.shape(seq[self.config.cascade_feat])[1])
143 |         size = min(seq[self.config.cascade_feat].shape[1], self.config.cascade_sample)
144 |         self.ridxs = tf.random.shuffle(idxs)[:size]
145 |         self.dist = None
146 |         self.entropy = 0
147 | 
148 |       self.cascade.append(tf.gather(seq[self.config.cascade_feat][-1], self.ridxs, axis=1))
149 |       cascade_reward = self.get_cascade_entropy()
150 |       cascade_reward = tf.concat([tf.cast(tf.zeros([seq[self.config.cascade_feat].shape[0] - 1, seq[self.config.cascade_feat].shape[1]]), tf.float16), tf.cast(tf.broadcast_to(cascade_reward, shape=(1, seq[self.config.cascade_feat].shape[1])), tf.float16)], axis=0)
151 |       cascade_reward = self.intr_rewnorm_cascade[agent_idx](cascade_reward)[0]
152 |       met.update({'Diversity': [cascade_reward.mean()]})
153 |       reward = reward * (1 - self.config.cascade_alpha) + self.config.cascade_alpha * cascade_reward
154 |     return reward, met
155 | 
156 |   def _train_ensemble(self, inputs, targets):
157 |     if self.config.disag_offset:
158 |       targets = targets[:, self.config.disag_offset:]
159 |       inputs = inputs[:, :-self.config.disag_offset]
160 |     targets = tf.stop_gradient(targets)
161 |     inputs = tf.stop_gradient(inputs)
162 |     with tf.GradientTape() as tape:
163 |       preds = [head(inputs) for head in self._networks]
164 |       loss = -sum([pred.log_prob(targets).mean() for pred in preds])
165 |     metrics = self.opt(tape, loss, self._networks)
166 |     return metrics
167 | 
168 | class ModelLoss(common.Module):
169 | 
170 |   def __init__(self, config, act_space, wm, tfstep, reward):
171 |     self.config = config
172 |     self.reward = reward
173 |     self.wm = wm
174 |     self.ac = agent.ActorCritic(config, act_space, tfstep)
175 |     self.actor = self.ac.actor
176 |     self.head = common.MLP([], **self.config.expl_head)
177 |     self.opt = common.Optimizer('expl', **self.config.expl_opt)
178 | 
179 |   def train(self, start, context, data):
180 |     metrics = {}
181 |     target = tf.cast(context[self.config.expl_model_loss], tf.float16)
182 |     with tf.GradientTape() as tape:
183 |       loss = -self.head(context['feat']).log_prob(target).mean()
184 |     metrics.update(self.opt(tape, loss, self.head))
185 |     metrics.update(self.ac.train(
186 |         self.wm, start, data['is_terminal'], self._intr_reward))
187 |     return None, metrics
188 | 
189 |   def _intr_reward(self, seq):
190 |     reward = self.config.expl_intr_scale * self.head(seq['feat']).mode()
191 |     if self.config.expl_extr_scale:
192 |       reward += self.config.expl_extr_scale * self.reward(seq)
193 |     return reward
194 | 
--------------------------------------------------------------------------------
/dreamerv2/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | import logging
8 | import os
9 | import pathlib
10 | import re
11 | import sys
12 | import warnings
13 | import pickle
14 | 
15 | try:
16 |   import rich.traceback
17 |   rich.traceback.install()
18 | except ImportError:
19 |   pass
20 | 
21 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
22 | logging.getLogger().setLevel('ERROR')
23 | warnings.filterwarnings('ignore', '.*box bound precision lowered.*')
24 | 
25 | sys.path.append(str(pathlib.Path(__file__).parent))
26 | sys.path.append(str(pathlib.Path(__file__).parent.parent))
27 | 
28 | import numpy as np
29 | 
30 | import agent
31 | import common
32 | 
33 | def run(config):
34 | 
35 |   logdir = pathlib.Path(config.logdir + config.xpid).expanduser()
36 |   logdir.mkdir(parents=True, exist_ok=True)
37 |   config.save(logdir / 'config.yaml')
38 |   print(config, '\n')
39 |   print('Logdir', logdir)
40 | 
41 |   import tensorflow as tf
42 |   tf.config.experimental_run_functions_eagerly(not config.jit)
43 |   message = 'No GPU found. Training will run on the CPU.'
44 | if len(tf.config.experimental.list_physical_devices('GPU')) == 0: 45 | print(message) 46 | else: 47 | for gpu in tf.config.experimental.list_physical_devices('GPU'): 48 | tf.config.experimental.set_memory_growth(gpu, True) 49 | assert config.precision in (16, 32), config.precision 50 | if config.precision == 16: 51 | from tensorflow.keras.mixed_precision import experimental as prec 52 | prec.set_policy(prec.Policy('mixed_float16')) 53 | 54 | ## Load the stats that we keep track of 55 | if (logdir / 'stats.pkl').exists(): 56 | stats = pickle.load(open(f"{logdir}/stats.pkl", "rb")) 57 | print("Loaded stats: ", stats) 58 | else: 59 | stats = { 60 | 'num_deployments': 0, 61 | 'num_trains': 0, 62 | 'num_evals': 0 63 | } 64 | pickle.dump(stats, open(f"{logdir}/stats.pkl", "wb")) 65 | 66 | multi_reward = config.task in common.DMC_TASK_IDS 67 | replay_dir = logdir / 'train_episodes' 68 | ## load dataset - we dont want to load offline again if we have already deployed 69 | if config.offline_dir == 'none' or stats['num_deployments'] > 0: 70 | train_replay = common.Replay(replay_dir, offline_init=False, 71 | multi_reward=multi_reward, **config.replay) 72 | else: 73 | train_replay = common.Replay(replay_dir, offline_init=True, 74 | multi_reward=multi_reward, offline_directory=config.offline_dir, **config.replay) 75 | eval_replay = common.Replay(logdir / 'eval_episodes', **dict( 76 | capacity=config.replay.capacity // 10, 77 | minlen=config.dataset.length, 78 | maxlen=config.dataset.length, 79 | multi_reward=multi_reward)) 80 | step = common.Counter(train_replay.stats['total_steps']) 81 | outputs = [ 82 | common.TerminalOutput(), 83 | common.JSONLOutput(logdir), 84 | common.TensorBoardOutput(logdir), 85 | ] 86 | logger = common.Logger(step, outputs, multiplier=config.action_repeat) 87 | 88 | def make_env(mode, seed=1): 89 | if '_' in config.task: 90 | suite, task = config.task.split('_', 1) 91 | else: 92 | suite, task = config.task, '' 93 | if suite == 'dmc': 94 | env = common.DMC( 95 | task, config.action_repeat, config.render_size, config.dmc_camera, save_path=logdir / 'videos') 96 | env = common.NormalizeAction(env) 97 | elif suite == 'atari': 98 | env = common.Atari( 99 | task, config.action_repeat, config.render_size, 100 | config.atari_grayscale, life_done=False, save_path=logdir / 'videos') # do not terminate on life loss 101 | env = common.OneHotAction(env) 102 | elif suite == 'crafter': 103 | assert config.action_repeat == 1 104 | outdir = logdir / 'crafter' if mode == 'train' else None 105 | reward = bool(['noreward', 'reward'].index(task)) or mode == 'eval' 106 | env = common.Crafter(outdir, reward, save_path=logdir / 'videos') 107 | env = common.OneHotAction(env) 108 | elif suite == 'minigrid': 109 | if mode == 'eval': 110 | env = common.make_minigrid_env(task, fix_seed=True, seed=seed) 111 | else: 112 | env = common.make_minigrid_env(task, fix_seed=False, seed=None) 113 | else: 114 | raise NotImplementedError(suite) 115 | env = common.TimeLimit(env, config.time_limit) 116 | return env 117 | 118 | def per_episode(ep, mode, task='none'): 119 | length = len(ep['reward']) - 1 120 | if task in common.DMC_TASK_IDS: 121 | scores = { 122 | key: np.sum([val[idx] for val in ep['reward'][1:]]) 123 | for idx, key in enumerate(common.DMC_TASK_IDS[task])} 124 | print_rews = f'{mode.title()} episode has {length} steps and returns ' 125 | print_rews += ''.join([f"{key}:{np.round(val,1)} " for key,val in scores.items()]) 126 | print(print_rews) 127 | for key,val in scores.items(): 128 | 
logger.scalar(f'{mode}_return_{key}', val) 129 | else: 130 | score = float(ep['reward'].astype(np.float64).sum()) 131 | print(f'{mode.title()} episode has {length} steps and return {score:.1f}.') 132 | logger.scalar(f'{mode}_return', score) 133 | logger.scalar(f'{mode}_length', length) 134 | for key, value in ep.items(): 135 | if re.match(config.log_keys_sum, key): 136 | logger.scalar(f'sum_{mode}_{key}', ep[key].sum()) 137 | if re.match(config.log_keys_mean, key): 138 | logger.scalar(f'mean_{mode}_{key}', ep[key].mean()) 139 | if re.match(config.log_keys_max, key): 140 | logger.scalar(f'max_{mode}_{key}', ep[key].max(0).mean()) 141 | replay = dict(train=train_replay, eval=eval_replay)[mode] 142 | logger.add(replay.stats, prefix=mode) 143 | logger.write() 144 | 145 | print('Create envs.\n') 146 | train_envs = [make_env('train') for _ in range(config.envs)] 147 | eval_envs = [make_env('eval') for _ in range(config.eval_envs)] 148 | 149 | act_space = train_envs[0].act_space 150 | obs_space = train_envs[0].obs_space 151 | train_driver = common.Driver(train_envs) 152 | train_driver.on_episode(lambda ep: per_episode(ep, mode='train', task=config.task)) 153 | train_driver.on_step(lambda tran, worker: step.increment()) 154 | train_driver.on_step(train_replay.add_step) 155 | train_driver.on_reset(train_replay.add_step) 156 | eval_driver = common.Driver(eval_envs) 157 | eval_driver.on_episode(eval_replay.add_episode) 158 | eval_driver.on_episode(lambda ep: per_episode(ep, mode='eval', task=config.task)) 159 | 160 | if stats['num_deployments'] == 0: 161 | if config.offline_dir == 'none': 162 | prefill = max(0, config.train_every - train_replay.stats['total_steps']) 163 | if prefill: 164 | print(f'Prefill dataset ({prefill} steps).') 165 | random_agent = common.RandomAgent(act_space) 166 | train_driver(random_agent, steps=prefill, episodes=1, policy_idx=-1) 167 | train_driver.reset() 168 | 169 | eval_driver(random_agent, episodes=1, policy_idx=-1) 170 | eval_driver.reset() 171 | stats['num_deployments'] += 1 172 | train_dataset = iter(train_replay.dataset(**config.offline_model_dataset)) 173 | 174 | print('Create agent.\n') 175 | agnt = agent.Agent(config, obs_space, act_space, step) 176 | train_agent = common.CarryOverState(agnt.train) 177 | 178 | # Attempt to load pretrained full model. 179 | # this can be used to test zero-shot performance on new tasks. 
180 |   if config.load_pretrained != "none":
181 |     print("\nLoading pretrained model...")
182 |     train_agent(next(train_dataset))
183 |     path = pathlib.Path(config.load_pretrained).expanduser()
184 |     agnt.load(path)
185 |     ## Assume we've done 1 full cycle
186 |     stats = {
187 |       'num_deployments': 1,
188 |       'num_trains': 1,
189 |       'num_evals': 1
190 |     }
191 |     print("\nSuccessfully loaded pretrained model.")
192 |   else:
193 |     print("\nInitializing agent...")
194 |     train_agent(next(train_dataset))
195 |     if (logdir / 'variables.pkl').exists():
196 |       print("\nStart loading model checkpoint...")
197 |       agnt.load(logdir / 'variables.pkl')
198 |     print("\nFinished initializing agent.")
199 | 
200 |   # Initialize policies
201 |   eval_policies = {}
202 |   tasks = ['']
203 |   if config.task in common.DMC_TASK_IDS:
204 |     tasks = common.DMC_TASK_IDS[config.task]
205 |   for task in tasks:
206 |     eval_policies[task] = lambda *args, task=task: agnt.policy(*args, mode='eval', goal=task)  # default arg binds the current task
207 |   expl_policies = {}
208 |   for idx in range(config.num_agents):
209 |     expl_policies[idx] = lambda *args, idx=idx: agnt.policy(*args, policy_idx=idx, mode='explore')  # default arg binds the current idx
210 | 
211 | 
212 |   ## each loop we do one of the following:
213 |   # 1. deploy explorers to collect data
214 |   # 2. train WM, explorers, task policies etc.
215 |   # 3. evaluate models
216 |   while step < config.steps:
217 |     print(f"\nMain loop step {step.value}")
218 |     should_deploy = stats['num_deployments'] <= stats['num_evals']
219 |     should_train_wm = stats['num_trains'] < stats['num_deployments']
220 |     should_eval = stats['num_evals'] < stats['num_trains']
221 | 
222 |     assert should_deploy + should_train_wm + should_eval == 1
223 | 
224 |     if should_deploy:
225 |       print("\n\nStart collecting data...", flush=True)
226 |       ## collect a batch of steps with the expl policy
227 |       ## need to increment steps here
228 |       num_steps = int(config.train_every / config.num_agents)
229 |       for idx in range(config.num_agents):
230 |         expl_policy = expl_policies[idx]
231 |         train_driver(expl_policy, steps=num_steps, policy_idx=idx)
232 |       stats['num_deployments'] += 1
233 | 
234 |     elif should_eval:
235 |       print('\n\nStart evaluation...', flush=True)
236 |       if int(step.value) % int(config.eval_every) != 0 or config.eval_type == 'none':
237 |         pass
238 |       elif config.eval_type == 'coincidental':
239 |         mets = common.eval(eval_driver, config, expl_policies, logdir)
240 |         for name, values in mets.items():
241 |           logger.scalar(name, np.array(values, np.float64).mean())
242 |         logger.write()
243 |       elif config.eval_type == 'labels':
244 |         tasks = ['']
245 |         if config.task in common.DMC_TASK_IDS:
246 |           tasks = common.DMC_TASK_IDS[config.task]
247 |         for idx, task in enumerate(tasks):
248 |           print("\n\nStart Evaluating " + task)
249 |           eval_policy = eval_policies[task]
250 |           eval_driver(eval_policy, episodes=config.eval_eps)
251 |           mets = common.get_stats(eval_driver, task=config.task, num_agents=config.num_agents, logdir=logdir)
252 |           rew = mets["eval_reward_" + task] if task != '' else mets["eval_reward"]
253 |           # logging
254 |           logger.scalar("eval_reward_" + task, np.mean(rew))
255 |           logger.write()
256 |       stats['num_evals'] += 1
257 | 
258 |     elif should_train_wm:
259 |       print('\n\nStart model training...')
260 |       should_pretrain = (stats['num_trains'] == 0 and config.offline_dir != "none")
261 |       if should_pretrain:
262 |         # Use all offline data for pretrain
263 |         batch_size = config.offline_model_dataset["batch"] * config.offline_model_dataset["length"]
264 |         model_train_steps = train_replay._loaded_steps // batch_size - 1
265 |       else:
266 |         model_train_steps = config.offline_model_train_steps
267 |       model_step = common.Counter(0)
268 |       while model_step < model_train_steps:
269 |         model_step.increment()
270 |         mets = train_agent(next(train_dataset))
271 |         # save model every 1000
272 |         if int(model_step.value) % 1000 == 0:
273 |           agnt.save(logdir / 'variables.pkl')
274 |       stats['num_trains'] += 1
275 | 
276 |     # save
277 |     pickle.dump(stats, open(f"{logdir}/stats.pkl", "wb"))
278 |     agnt.save(logdir / 'variables.pkl')
279 | 
280 |   # closing all envs
281 |   for env in train_envs + eval_envs:
282 |     try:
283 |       env.close()
284 |     except Exception:
285 |       pass
286 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | import argparse
8 | import dreamerv2.api as dv2
9 | from dreamerv2.train import run
10 | 
11 | def str2bool(v):
12 |     if isinstance(v, bool):
13 |         return v
14 |     if v.lower() in ('yes', 'true', 't', 'y', '1'):
15 |         return True
16 |     elif v.lower() in ('no', 'false', 'f', 'n', '0'):
17 |         return False
18 |     else:
19 |         raise argparse.ArgumentTypeError('Boolean value expected.')
20 | 
21 | def main(args):
22 | 
23 |     ## get defaults
24 |     config = dv2.defaults
25 |     if args.task:
26 |         if 'crafter' in args.task:
27 |             config = config.update(dv2.configs['crafter'])
28 |         elif 'minigrid' in args.task:
29 |             config = config.update(dv2.configs['minigrid'])
30 |         elif 'atari' in args.task:
31 |             config = config.update(dv2.configs['atari'])
32 |         elif 'dmc' in args.task:
33 |             config = config.update(dv2.configs['dmc_vision'])
34 | 
35 |     params = vars(args)
36 |     config = config.update(params)
37 | 
38 |     config = config.update({
39 |         'expl_behavior': 'Plan2Explore',
40 |         'pred_discount': False,
41 |         'grad_heads': ['decoder'],  # this means we don't learn the reward head
42 |         'expl_intr_scale': 1.0,
43 |         'expl_extr_scale': 0.0,
44 |         'discount': 0.99,
45 |     })
46 | 
47 |     run(config)
48 | 
49 | 
50 | if __name__ == "__main__":
51 |     parser = argparse.ArgumentParser(description='RL')
52 | 
53 |     # DreamerV2
54 |     parser.add_argument('--xpid', type=str, default=None, help='experiment id')
55 |     parser.add_argument('--steps', type=int, default=1e6, help='number of environment steps to train')
56 |     parser.add_argument('--train_every', type=int, default=1e5, help='number of environment steps to collect per deployment (split across the explorer population)')
57 |     parser.add_argument('--offline_model_train_steps', type=int, default=25001, help='=250 * train_every (in thousands) + 1. 
Default assumes 100k.') 58 | parser.add_argument('--task', type=str, default='crafter_noreward', help='environment to train on') 59 | parser.add_argument('--logdir', default='~/wm_logs/', help='directory to save agent logs') 60 | parser.add_argument('--num_agents', type=int, default=1, help='exploration population size.') 61 | parser.add_argument('--seed', type=int, default=100, help='seed for init NNs.') 62 | parser.add_argument('--envs', type=int, default=1, help='number of training envs.') 63 | parser.add_argument('--envs_parallel', type=str, default="none", help='how to parallelize.') 64 | parser.add_argument('--eval_envs', type=int, default=1, help='number of parallel eval envs.') 65 | parser.add_argument('--eval_eps', type=int, default=100, help='number of eval eps.') 66 | parser.add_argument('--eval_type', type=str, default='coincidental', help='how to evaluate the model.') 67 | parser.add_argument('--expl_behavior', type=str, default='Plan2Explore', help='algorithm for exploration: Plan2Explore or Random.') 68 | parser.add_argument('--load_pretrained', type=str, default='none', help='name of pretrained model') 69 | parser.add_argument('--offline_dir', type=str, default='none', help='directory to load offline dataset') 70 | 71 | # CASCADE 72 | parser.add_argument('--cascade_alpha', type=float, default=0, help='Cascade weight.') 73 | parser.add_argument('--cascade_feat', type=str, default="deter", help='Cascade features if state based.') 74 | parser.add_argument('--cascade_k', type=int, default=5, help='number of nearest neighbors to use in the mean dist.') 75 | parser.add_argument('--cascade_sample', type=int, default=100, help='max number of cascade states') 76 | 77 | args = parser.parse_args() 78 | main(args) 79 | --------------------------------------------------------------------------------
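Usage note: besides the command-line entry point in main.py above, a run can also be configured programmatically with the same config-update pattern that main.py uses, which can be convenient for scripted sweeps. The sketch below is illustrative only; the experiment id, log directory, and CASCADE hyperparameter values (population size, diversity weight, sample budget) are assumptions, not settings taken from the repository.

# Illustrative sketch: build a config the way main.py does and launch training.
# The xpid, logdir, and CASCADE hyperparameters below are hypothetical values.
import dreamerv2.api as dv2
from dreamerv2.train import run

# Start from the package defaults and layer on the pixel-based DMC preset,
# as main.py does for 'dmc' tasks.
config = dv2.defaults
config = config.update(dv2.configs['dmc_vision'])

# Task and exploration overrides; the reward-free setup (no reward head, no
# discount prediction) mirrors the update that main.py always applies.
config = config.update({
    'task': 'dmc_walker_walk',
    'logdir': '~/wm_logs/',           # outputs are written to logdir + xpid
    'xpid': 'cascade_walker_demo',    # hypothetical experiment id
    'expl_behavior': 'Plan2Explore',
    'pred_discount': False,
    'grad_heads': ['decoder'],        # reward head is not trained
    'num_agents': 4,                  # explorer population size (assumed)
    'cascade_alpha': 0.5,             # weight of the CASCADE diversity bonus (assumed)
    'cascade_sample': 100,            # max cascade states sampled per agent (assumed)
})

if __name__ == '__main__':
    run(config)

Applying the preset first and the overrides second mirrors the order of updates in main.py, so the later keys take precedence.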