├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── configs ├── bc_config.py ├── drq_config.py ├── pixel_bc_config.py ├── pixel_config.py ├── pixel_rm_config.py ├── pixel_rnd_config.py ├── rlpd_config.py ├── rlpd_pixels_config.py ├── rm_config.py ├── rnd_config.py ├── sac_config.py └── td_config.py ├── create_env.sh ├── data └── README.md ├── generate_adroit.sh ├── generate_all.sh ├── generate_antmaze.sh ├── generate_cog.sh ├── plotting ├── filter_data.py ├── make_plots.py ├── visualize_maze.py ├── visualize_reward.py └── wandb_dl.py ├── requirements.txt ├── rlpd ├── __init__.py ├── agents │ ├── __init__.py │ ├── agent.py │ ├── bc.py │ ├── drq │ │ ├── __init__.py │ │ ├── augmentations.py │ │ ├── bc.py │ │ ├── drq_learner.py │ │ ├── icvf.py │ │ ├── rm.py │ │ └── rnd.py │ ├── rm.py │ ├── rnd.py │ └── sac │ │ ├── __init__.py │ │ ├── sac_learner.py │ │ └── temperature.py ├── data │ ├── __init__.py │ ├── binary_datasets.py │ ├── cog_datasets.py │ ├── d4rl_datasets.py │ ├── dataset.py │ ├── memory_efficient_replay_buffer.py │ └── replay_buffer.py ├── distributions │ ├── __init__.py │ ├── tanh_deterministic.py │ ├── tanh_normal.py │ └── tanh_transformed.py ├── evaluation.py ├── gc_dataset.py ├── networks │ ├── __init__.py │ ├── encoders │ │ ├── __init__.py │ │ └── d4pg_encoder.py │ ├── ensemble.py │ ├── mlp.py │ ├── mlp_resnet.py │ ├── pixel_multiplexer.py │ └── state_action_value.py ├── types.py └── wrappers │ ├── __init__.py │ ├── frame_stack.py │ ├── pixels.py │ ├── repeat_action.py │ ├── single_precision.py │ └── universal_seed.py ├── run_antmaze.sh ├── submit.py ├── submit_all.sh ├── train_finetuning.py ├── train_finetuning_pixels.py └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | exp_data_cog/ 163 | data.zip 164 | core 165 | wandb 166 | logs 167 | sbatch 168 | .DS_Store 169 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team. All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership.
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to ExPLORe 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Meta's open source projects. 18 | 19 | Complete your CLA here: https://code.facebook.com/cla 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## Coding Style 30 | * [Black formatter](https://black.readthedocs.io/en/stable/) 31 | 32 | ## License 33 | By contributing to ExPLORe, you agree that your contributions will be licensed 34 | under the LICENSE file in the root directory of this source tree. 35 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Exploration from Prior Data by Labeling Optimistic Reward (ExPLORe) 2 | 3 | This is the code accompanying the NeurIPS 2023 paper [Accelerating Exploration with Unlabeled Prior Data](https://arxiv.org/abs/2311.05067). 4 | 5 | The code is built on https://github.com/ikostrikov/rlpd/, and the ICVF implementation is from https://github.com/dibyaghosh/icvf_release/. 6 | 7 | ExPLORe is licensed under CC-BY-NC; however, portions of the project (indicated by the header in each affected file) are available under separate license terms: 8 | - rlpd (https://github.com/ikostrikov/rlpd/) is licensed under the MIT license. 9 | - icvf (https://github.com/dibyaghosh/icvf_release) is licensed under the MIT license.
10 | 11 | # Installation (assumes CUDA 11) 12 | 13 | ```bash 14 | ./create_env.sh 15 | ``` 16 | 17 | ## D4RL Antmaze 18 | ```bash 19 | XLA_PYTHON_CLIENT_PREALLOCATE=false python train_finetuning.py \ 20 | --env_name=antmaze-large-diverse-v2 \ 21 | --max_steps=300000 \ 22 | --config.backup_entropy=False \ 23 | --config.num_min_qs=1 \ 24 | --offline_relabel_type=pred \ 25 | --use_rnd_offline=True \ 26 | --eval_episodes=10 \ 27 | --project_name=explore-antmaze \ 28 | --seed=0 29 | ``` 30 | 31 | ## Adroit Binary 32 | 33 | First, download and unzip the `.npy` files into `~/.datasets/awac-data/` from [here](https://drive.google.com/file/d/1SsVaQKZnY5UkuR78WrInp9XxTdKHbF0x/view). 34 | 35 | Make sure you have `mjrl` installed: 36 | ```bash 37 | git clone https://github.com/aravindr93/mjrl 38 | cd mjrl 39 | pip install -e . 40 | ``` 41 | 42 | Then, recursively clone `mj_envs` from this fork: 43 | ```bash 44 | git clone --recursive https://github.com/philipjball/mj_envs.git 45 | ``` 46 | 47 | Then sync the submodules (add the `--init` flag if you didn't recursively clone): 48 | ```bash 49 | cd mj_envs 50 | git submodule update --remote 51 | ``` 52 | 53 | Finally: 54 | ```bash 55 | pip install -e . 56 | ``` 57 | 58 | Now you can run the following in this directory: 59 | ```bash 60 | XLA_PYTHON_CLIENT_PREALLOCATE=false python train_finetuning.py \ 61 | --env_name=pen-binary-v0 \ 62 | --max_steps=1000000 \ 63 | --config.backup_entropy=False \ 64 | --offline_relabel_type=pred \ 65 | --use_rnd_offline=True \ 66 | --eval_episodes=10 \ 67 | --project_name=explore-adroit \ 68 | --seed=0 69 | ``` 70 | 71 | ```bash 72 | XLA_PYTHON_CLIENT_PREALLOCATE=false python train_finetuning.py \ 73 | --env_name=relocate-binary-v0 \ 74 | --reset_rm_every=1000 \ 75 | --max_steps=1000000 \ 76 | --config.backup_entropy=False \ 77 | --offline_relabel_type=pred \ 78 | --use_rnd_offline=True \ 79 | --project_name=explore-adroit \ 80 | --eval_episodes=10 \ 81 | --seed=0 82 | ``` 83 | 84 | ## COG 85 | Based on https://github.com/avisingh599/cog. 86 | 87 | First, install roboverse (https://github.com/avisingh599/roboverse) to set up the environment, then follow the instructions in `data/README.md` to obtain the dataset. 88 | 89 | ```bash 90 | XLA_PYTHON_CLIENT_PREALLOCATE=false python train_finetuning_pixels.py \ 91 | --env_name=Widow250PickTray-v0 \ 92 | --project_name=explore-cog \ 93 | --offline_relabel_type=pred \ 94 | --use_rnd_offline=True \ 95 | --use_icvf=True \ 96 | --dataset_subsample_ratio=0.1 \ 97 | --seed=0 98 | ``` 99 | 100 | ## For SLURM and reproducing the paper results 101 | 102 | `./generate_all.sh [partition] [conda environment] [resource constraint]` will generate all the sbatch scripts under `sbatch/` (with three seeds each; the paper uses 10 seeds for AntMaze and Adroit and 20 for COG), and `./submit_all.sh` will launch all of them. 103 | 104 | ### To reproduce main figures in the paper 105 | Assuming all the experiments above completed successfully with wandb tracking, the following steps generate paper-style figures.
106 | ``` 107 | cd plotting 108 | python wandb_dl.py --entity=[YOUR WANDB ENTITY/USERNAME] --domain=antmaze --project_name=release-explore-antmaze 109 | python wandb_dl.py --entity=[YOUR WANDB ENTITY/USERNAME] --domain=adroit --project_name=release-explore-adroit 110 | python wandb_dl.py --entity=[YOUR WANDB ENTITY/USERNAME] --domain=cog --project_name=release-explore-cog 111 | 112 | python make_plots.py --domain=all 113 | ``` 114 | 115 | ### To reproduce Figure 2 116 | ``` 117 | ./run_antmaze.sh 118 | cd plotting 119 | python visualize_maze.py 120 | ``` 121 | 122 | # Bibtex 123 | ``` 124 | @inproceedings{ 125 | li2023accelerating, 126 | title={Accelerating Exploration with Unlabeled Prior Data}, 127 | author={Qiyang Li and Jason Zhang and Dibya Ghosh and Amy Zhang and Sergey Levine}, 128 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems}, 129 | year={2023}, 130 | url={https://openreview.net/forum?id=Itorzn4Kwf} 131 | } 132 | ``` 133 | -------------------------------------------------------------------------------- /configs/bc_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import ml_collections 9 | 10 | 11 | def get_config(): 12 | config = ml_collections.ConfigDict() 13 | 14 | config.model_cls = "BCAgent" 15 | 16 | config.actor_lr = 3e-4 17 | config.hidden_dims = (256, 256, 256) 18 | 19 | return config 20 | -------------------------------------------------------------------------------- /configs/drq_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/configs/drq_config.py 3 | 4 | Original license information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE.
27 | 28 | """ 29 | 30 | 31 | from ml_collections.config_dict import config_dict 32 | 33 | from configs import pixel_config 34 | 35 | 36 | def get_config(): 37 | config = pixel_config.get_config() 38 | 39 | config.model_cls = "DrQLearner" 40 | 41 | config.actor_lr = 3e-4 42 | config.critic_lr = 3e-4 43 | config.temp_lr = 3e-4 44 | 45 | config.discount = 0.99 46 | 47 | config.num_qs = 2 48 | 49 | config.tau = 0.005 50 | config.init_temperature = 0.1 51 | config.backup_entropy = True 52 | config.target_entropy = config_dict.placeholder(float) 53 | 54 | config.bc_coeff = 0. 55 | 56 | return config 57 | -------------------------------------------------------------------------------- /configs/pixel_bc_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import ml_collections 9 | 10 | from configs import pixel_config 11 | 12 | 13 | def get_config(): 14 | config = pixel_config.get_config() 15 | 16 | config.model_cls = "PixelBCAgent" 17 | 18 | config.actor_lr = 3e-4 19 | config.hidden_dims = (256, 256, 256) 20 | 21 | return config 22 | -------------------------------------------------------------------------------- /configs/pixel_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/configs/pixel_config.py 3 | 4 | Original license information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | 28 | """ 29 | 30 | 31 | import ml_collections 32 | 33 | 34 | def get_config(): 35 | config = ml_collections.ConfigDict() 36 | 37 | config.hidden_dims = (256, 256) 38 | 39 | config.cnn_features = (32, 64, 128, 256) 40 | config.cnn_filters = (3, 3, 3, 3) 41 | config.cnn_strides = (2, 2, 2, 2) 42 | config.cnn_padding = "VALID" 43 | config.latent_dim = 50 44 | config.encoder = "d4pg" 45 | 46 | return config -------------------------------------------------------------------------------- /configs/pixel_rm_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved.
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | 9 | from ml_collections.config_dict import config_dict 10 | 11 | from configs import pixel_config 12 | 13 | def get_config(): 14 | config = pixel_config.get_config() 15 | 16 | config.model_cls = "PixelRM" 17 | config.lr = 3e-4 18 | config.hidden_dims = (256, 256) 19 | return config 20 | 21 | -------------------------------------------------------------------------------- /configs/pixel_rnd_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | 9 | from ml_collections.config_dict import config_dict 10 | 11 | from configs import pixel_config 12 | 13 | def get_config(): 14 | config = pixel_config.get_config() 15 | 16 | config.model_cls = "PixelRND" 17 | config.lr = 3e-4 18 | config.hidden_dims = (256, 256) 19 | config.coeff = 1. 20 | return config 21 | -------------------------------------------------------------------------------- /configs/rlpd_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/configs/rlpd_config.py 3 | 4 | Original license information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | 28 | """ 29 | 30 | 31 | from configs import sac_config 32 | 33 | 34 | def get_config(): 35 | config = sac_config.get_config() 36 | 37 | config.num_qs = 10 38 | config.num_min_qs = 2 39 | config.critic_layer_norm = True 40 | 41 | return config 42 | -------------------------------------------------------------------------------- /configs/rlpd_pixels_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/configs/rlpd_pixels_config.py 3 | 4 | Original license information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J.
Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | 28 | """ 29 | 30 | 31 | from ml_collections.config_dict import config_dict 32 | 33 | from configs import drq_config 34 | 35 | 36 | def get_config(): 37 | config = drq_config.get_config() 38 | 39 | config.num_qs = 10 40 | config.num_min_qs = 1 41 | 42 | config.critic_layer_norm = True 43 | config.backup_entropy = False 44 | 45 | return config 46 | -------------------------------------------------------------------------------- /configs/rm_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import ml_collections 9 | 10 | def get_config(): 11 | config = ml_collections.ConfigDict() 12 | 13 | config.model_cls = "RM" 14 | config.lr = 3e-4 15 | config.hidden_dims = (256, 256, 256) 16 | return config 17 | -------------------------------------------------------------------------------- /configs/rnd_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import ml_collections 9 | 10 | def get_config(): 11 | config = ml_collections.ConfigDict() 12 | 13 | config.model_cls = "RND" 14 | config.lr = 3e-4 15 | config.hidden_dims = (256, 256, 256) 16 | config.coeff = 1. 17 | return config 18 | -------------------------------------------------------------------------------- /configs/sac_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/configs/sac_config.py 3 | 4 | Original license information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J.
Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | 28 | """ 29 | 30 | 31 | from ml_collections.config_dict import config_dict 32 | 33 | from configs import td_config 34 | 35 | 36 | def get_config(): 37 | config = td_config.get_config() 38 | 39 | config.model_cls = "SACLearner" 40 | 41 | config.temp_lr = 3e-4 42 | 43 | config.init_temperature = 1.0 44 | config.target_entropy = config_dict.placeholder(float) 45 | 46 | config.backup_entropy = True 47 | 48 | return config 49 | -------------------------------------------------------------------------------- /configs/td_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/configs/td_config.py 3 | 4 | Original license information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE.
27 | 28 | """ 29 | 30 | 31 | import ml_collections 32 | 33 | 34 | def get_config(): 35 | config = ml_collections.ConfigDict() 36 | 37 | config.actor_lr = 3e-4 38 | config.critic_lr = 3e-4 39 | 40 | config.hidden_dims = (256, 256, 256) 41 | 42 | config.discount = 0.99 43 | config.num_qs = 2 44 | config.tau = 0.005 45 | config.bc_coeff = 0.0 46 | 47 | return config 48 | -------------------------------------------------------------------------------- /create_env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | 9 | CONDA_DIR="$(conda info --base)" 10 | source "${CONDA_DIR}/etc/profile.d/conda.sh" 11 | 12 | conda create -n explore python=3.10 13 | conda activate explore 14 | pip install --upgrade pip 15 | pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 16 | pip install -r requirements.txt 17 | pip install "cython<3" patchelf 18 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | First, download the data from here: https://drive.google.com/drive/folders/1jxBQE1adsFT1sWsfatbhiZG6Zkf3EW0Q 2 | 3 | After downloading the data, organize it for each environment (Pick and Place, Closed Drawer, and Blocked Drawer 1) into the structure shown in the code block below. In the context of COG, the "prior data" consists of trajectories from the first subtask of the compositional task, and the "task data" consists of trajectories from the second subtask, which completes the task. The drawer environments each have their own "prior data" but share the "task data" file, "drawer_task.npy". 4 | 5 | ``` 6 | [ENV_NAME]/PRIOR_DATA.npy 7 | [ENV_NAME]/TASK_DATA.npy 8 | [ENV_NAME]/successful/prior_success.npy 9 | [ENV_NAME]/successful/task_success.npy 10 | ``` 11 | In our code, we assume ENV_NAME is the following: pickplace for the Pick and Place task (Widow250PickTray-v0), closeddrawer_small for the Closed Drawer task (Widow250DoubleDrawerOpenGraspNeutral-v0), and blockeddrawer1_small for the Blocked Drawer 1 task (Widow250DoubleDrawerCloseOpenGraspNeutral-v0). PRIOR_DATA.npy and TASK_DATA.npy can be named as you wish. 12 | 13 | We also provide some successful trajectories used for evaluation and visualization purposes. These trajectories can be found under `./data/successful_trajectories`. Place them in a `successful` folder in the respective environment when setting up the data. 14 | 15 | The drawer prior and task datasets may be large. The code below was used to subsample the drawer-environment datasets and to zero out the rewards. Note that even if you do not subsample, it is still important to zero out the rewards in the "prior data": the reward labels in the datasets can be inconsistent, and the prior data sometimes contains +1 rewards for completing the first subtask, which must be zeroed because those trajectories do not complete the compositional task. The rewards in the "task data" align with the compositional task and are fine to keep. This curation of reward labels matters primarily for the RLPD baseline.
16 | 17 | ``` 18 | import numpy as np 19 | 20 | path = 'NPYFILEPATH' 21 | data = np.load(path, allow_pickle=True) 22 | for i in range(len(data)): 23 | for j in range(len(data[i]['rewards'])): 24 | data[i]['rewards'][j] *= 0 25 | data_small = data[np.random.choice(range(len(data)), size=SMALLER_SIZE_HERE, replace=False)]  # e.g., SMALLER_SIZE_HERE = 2500 26 | np.save('ZEROED_SMALL_PATH', data_small) 27 | ``` 28 | -------------------------------------------------------------------------------- /generate_adroit.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | # Ours 9 | python submit.py -j 2 --name adroit-main --partition $1 --conda_env_name $2 --constraint $3 python train_finetuning.py \ 10 | --env_name=pen-binary-v0,door-binary-v0,relocate-binary-v0 \ 11 | --max_steps=1000000 \ 12 | --config.backup_entropy=False \ 13 | --project_name=release-explore-adroit \ 14 | --offline_relabel_type=pred \ 15 | --seed=0,1,2 \ 16 | --use_rnd_offline=True,False 17 | 18 | # Reset reward function for relocate 19 | python submit.py -j 2 --name adroit-main-relocate-reset --partition $1 --conda_env_name $2 --constraint $3 python train_finetuning.py \ 20 | --env_name=relocate-binary-v0 \ 21 | --reset_rm_every=1000 \ 22 | --max_steps=1000000 \ 23 | --config.backup_entropy=False \ 24 | --project_name=release-explore-adroit \ 25 | --offline_relabel_type=pred \ 26 | --seed=0,1,2 \ 27 | --use_rnd_offline=True,False 28 | 29 | # BC + JSRL 30 | python submit.py -j 2 --name adroit-bc-jsrl --partition $1 --conda_env_name $2 --constraint $3 python train_finetuning.py \ 31 | --env_name=pen-binary-v0,door-binary-v0,relocate-binary-v0 \ 32 | --max_steps=1000000 \ 33 | --config.backup_entropy=False \ 34 | --project_name=release-explore-adroit \ 35 | --bc_pretrain_rollin=0.5 \ 36 | --offline_ratio=0.0 \ 37 | --seed=0,1,2 \ 38 | --bc_pretrain_steps=100000 39 | 40 | # Naive + BC 41 | python submit.py -j 2 --name adroit-bc --partition $1 --conda_env_name $2 --constraint $3 python train_finetuning.py \ 42 | --env_name=pen-binary-v0,door-binary-v0,relocate-binary-v0 \ 43 | --max_steps=1000000 \ 44 | --config.backup_entropy=False \ 45 | --project_name=release-explore-adroit \ 46 | --seed=0,1,2 \ 47 | --offline_relabel_type=pred \ 48 | --offline_ratio=0.5 \ 49 | --config.bc_coeff=0.01 50 | 51 | # Oracle 52 | python submit.py -j 2 --name adroit-oracle --partition $1 --conda_env_name $2 --constraint $3 python train_finetuning.py \ 53 | --env_name=pen-binary-v0,door-binary-v0,relocate-binary-v0 \ 54 | --max_steps=1000000 \ 55 | --config.backup_entropy=False \ 56 | --project_name=release-explore-adroit \ 57 | --offline_relabel_type=gt \ 58 | --offline_ratio=0.5 \ 59 | --seed=0,1,2 60 | 61 | # Online, Online + RND 62 | python submit.py -j 2 --name adroit-online --partition $1 --conda_env_name $2 --constraint $3 python train_finetuning.py \ 63 | --env_name=pen-binary-v0,door-binary-v0,relocate-binary-v0 \ 64 | --max_steps=1000000 \ 65 | --config.backup_entropy=False \ 66 | --project_name=release-explore-adroit \ 67 | --offline_relabel_type=gt \ 68 | --offline_ratio=0 \ 69 | --seed=0,1,2 \ 70 | --use_rnd_online=True,False 71 | 72 | # Min 73 | python submit.py -j 2 --name adroit-min --partition $1 --conda_env_name $2 --constraint $3 python train_finetuning.py \ 74 | --env_name=pen-binary-v0,door-binary-v0,relocate-binary-v0 \ 75 |
--max_steps=1000000 \ 76 | --config.backup_entropy=False \ 77 | --project_name=release-explore-adroit \ 78 | --offline_relabel_type=min \ 79 | --offline_ratio=0.5 \ 80 | --seed=0,1,2 81 | -------------------------------------------------------------------------------- /generate_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | 9 | mkdir -p sbatch 10 | ./generate_antmaze.sh $1 $2 $3 11 | ./generate_adroit.sh $1 $2 $3 12 | ./generate_cog.sh $1 $2 $3 13 | -------------------------------------------------------------------------------- /generate_antmaze.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | # Ours 9 | python submit.py -j 3 --name antmaze-main --partition $1 --conda_env_name $2 --constraint $3 python train_finetuning.py \ 10 | --env_name=antmaze-umaze-v2,antmaze-umaze-diverse-v2,antmaze-medium-diverse-v2,antmaze-medium-play-v2,antmaze-large-play-v2,antmaze-large-diverse-v2 \ 11 | --max_steps=300000 \ 12 | --config.backup_entropy=False \ 13 | --config.num_min_qs=1 \ 14 | --project_name=release-explore-antmaze \ 15 | --offline_relabel_type=pred \ 16 | --use_rnd_offline=True,False \ 17 | --seed=0,1,2 18 | 19 | # BC + JSRL 20 | python submit.py -j 3 --name antmaze-bc-jsrl --partition $1 --conda_env_name $2 --constraint $3 python train_finetuning.py \ 21 | --env_name=antmaze-umaze-v2,antmaze-umaze-diverse-v2,antmaze-medium-diverse-v2,antmaze-medium-play-v2,antmaze-large-play-v2,antmaze-large-diverse-v2 \ 22 | --max_steps=300000 \ 23 | --config.backup_entropy=False \ 24 | --config.num_min_qs=1 \ 25 | --project_name=release-explore-antmaze \ 26 | --bc_pretrain_rollin=0.9 \ 27 | --offline_ratio=0.0 \ 28 | --bc_pretrain_steps=5000 \ 29 | --seed=0,1,2 30 | 31 | # Naive + BC 32 | python submit.py -j 3 --name antmaze-bc --partition $1 --conda_env_name $2 --constraint $3 python train_finetuning.py \ 33 | --env_name=antmaze-umaze-v2,antmaze-umaze-diverse-v2,antmaze-medium-diverse-v2,antmaze-medium-play-v2,antmaze-large-play-v2,antmaze-large-diverse-v2 \ 34 | --max_steps=300000 \ 35 | --config.backup_entropy=False \ 36 | --config.num_min_qs=1 \ 37 | --project_name=release-explore-antmaze \ 38 | --offline_relabel_type=pred \ 39 | --offline_ratio=0.5 \ 40 | --config.bc_coeff=0.01 \ 41 | --seed=0,1,2 42 | 43 | # Oracle 44 | python submit.py -j 3 --name antmaze-oracle --partition $1 --conda_env_name $2 --constraint $3 python train_finetuning.py \ 45 | --env_name=antmaze-umaze-v2,antmaze-umaze-diverse-v2,antmaze-medium-diverse-v2,antmaze-medium-play-v2,antmaze-large-play-v2,antmaze-large-diverse-v2 \ 46 | --max_steps=300000 \ 47 | --config.backup_entropy=False \ 48 | --config.num_min_qs=1 \ 49 | --project_name=release-explore-antmaze \ 50 | --offline_relabel_type=gt \ 51 | --offline_ratio=0.5 \ 52 | --seed=0,1,2 53 | 54 | # Online, Online + RND 55 | python submit.py -j 3 --name antmaze-online --partition $1 --conda_env_name $2 --constraint $3 python train_finetuning.py \ 56 |
--env_name=antmaze-umaze-v2,antmaze-umaze-diverse-v2,antmaze-medium-diverse-v2,antmaze-medium-play-v2,antmaze-large-play-v2,antmaze-large-diverse-v2 \ 57 | --max_steps=300000 \ 58 | --config.backup_entropy=False \ 59 | --config.num_min_qs=1 \ 60 | --project_name=release-explore-antmaze \ 61 | --offline_relabel_type=gt \ 62 | --offline_ratio=0 \ 63 | --use_rnd_online=True,False \ 64 | --seed=0,1,2 65 | 66 | # Min 67 | python submit.py -j 3 --name antmaze-min --partition $1 --conda_env_name $2 --constraint $3 python train_finetuning.py \ 68 | --env_name=antmaze-umaze-v2,antmaze-umaze-diverse-v2,antmaze-medium-diverse-v2,antmaze-medium-play-v2,antmaze-large-play-v2,antmaze-large-diverse-v2 \ 69 | --max_steps=300000 \ 70 | --config.backup_entropy=False \ 71 | --config.num_min_qs=1 \ 72 | --project_name=release-explore-antmaze \ 73 | --offline_relabel_type=min \ 74 | --offline_ratio=0.5 \ 75 | --seed=0,1,2 76 | -------------------------------------------------------------------------------- /generate_cog.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | # Ours-ICVF 9 | python submit.py -j 2 --name cog-main-ours --partition $1 --conda_env_name $2 --constraint $3 python train_finetuning_pixels.py \ 10 | --env_name=Widow250PickTray-v0,Widow250DoubleDrawerOpenGraspNeutral-v0,Widow250DoubleDrawerCloseOpenGraspNeutral-v0 \ 11 | --project_name=release-explore-cog \ 12 | --offline_relabel_type=pred \ 13 | --seed=0,1,2 \ 14 | --use_rnd_offline=True,False \ 15 | --use_icvf=True,False \ 16 | --checkpoint_model=True 17 | 18 | # Oracle 19 | python submit.py -j 3 --name cog-main-oracle --partition $1 --conda_env_name $2 --constraint $3 python train_finetuning_pixels.py \ 20 | --env_name=Widow250PickTray-v0,Widow250DoubleDrawerOpenGraspNeutral-v0,Widow250DoubleDrawerCloseOpenGraspNeutral-v0 \ 21 | --project_name=release-explore-cog \ 22 | --offline_relabel_type=gt \ 23 | --seed=0,1,2 24 | 25 | # BC + JSRL 26 | python submit.py -j 3 --name cog-bc-jsrl --partition $1 --conda_env_name $2 --constraint $3 python train_finetuning_pixels.py \ 27 | --env_name=Widow250PickTray-v0,Widow250DoubleDrawerOpenGraspNeutral-v0,Widow250DoubleDrawerCloseOpenGraspNeutral-v0 \ 28 | --project_name=release-explore-cog \ 29 | --bc_pretrain_rollin=0.5 \ 30 | --bc_pretrain_steps=100000 \ 31 | --offline_ratio=0.0 \ 32 | --seed=0,1,2 33 | 34 | # Naive + BC 35 | python submit.py -j 3 --name cog-bc --partition $1 --conda_env_name $2 --constraint $3 python train_finetuning_pixels.py \ 36 | --env_name=Widow250PickTray-v0,Widow250DoubleDrawerOpenGraspNeutral-v0,Widow250DoubleDrawerCloseOpenGraspNeutral-v0 \ 37 | --project_name=release-explore-cog \ 38 | --seed=0,1,2 \ 39 | --offline_relabel_type=pred \ 40 | --offline_ratio=0.5 \ 41 | --config.bc_coeff=0.01 42 | 43 | # Online, Online + RND 44 | python submit.py -j 3 --name cog-online --partition $1 --conda_env_name $2 --constraint $3 python train_finetuning_pixels.py \ 45 | --env_name=Widow250PickTray-v0,Widow250DoubleDrawerOpenGraspNeutral-v0,Widow250DoubleDrawerCloseOpenGraspNeutral-v0 \ 46 | --project_name=release-explore-cog \ 47 | --offline_ratio=0 \ 48 | --seed=0,1,2 \ 49 | --use_rnd_online=True,False 50 | 51 | # Min 52 | python submit.py -j 3 --name cog-min --partition $1 --conda_env_name $2 --constraint $3 python 
train_finetuning_pixels.py \ 53 | --env_name=Widow250PickTray-v0,Widow250DoubleDrawerOpenGraspNeutral-v0,Widow250DoubleDrawerCloseOpenGraspNeutral-v0 \ 54 | --project_name=release-explore-cog \ 55 | --offline_relabel_type=min \ 56 | --offline_ratio=0.5 \ 57 | --seed=0,1,2 58 | -------------------------------------------------------------------------------- /plotting/filter_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import pickle as pkl 8 | from wandb_dl import Data 9 | from itertools import chain 10 | 11 | import numpy as np 12 | 13 | def config_match(conf_src, conf_tgt): 14 | for k, v in conf_tgt.items(): 15 | if v == "null": 16 | if k not in conf_src: 17 | print("skipping", k) 18 | continue 19 | 20 | if k not in conf_src: return False 21 | if type(v) == tuple: 22 | min_v, max_v = v 23 | if conf_src[k] >= max_v: return False 24 | if conf_src[k] <= min_v: return False 25 | else: 26 | if conf_src[k] != v: return False 27 | return True 28 | 29 | 30 | def normalize_performance(x, env_name): 31 | if env_name in ["door-binary-v0", "relocate-binary-v0"]: 32 | return (x + 200.) / 200. 33 | if env_name in ["pen-binary-v0"]: 34 | return (x + 100.) / 100. 35 | if env_name in ["Widow250PickTray-v0"]: 36 | return x / 40. 37 | if env_name in ["Widow250DoubleDrawerOpenGraspNeutral-v0"]: 38 | return x / 50. 39 | if env_name in ["Widow250DoubleDrawerCloseOpenGraspNeutral-v0"]: 40 | return x / 80. 41 | return x 42 | 43 | def filter(data, conf): 44 | matched_runs = [] 45 | indices = [] 46 | 47 | if type(data) == list: 48 | iters = chain(*[d._runs for d in data])  # flatten the runs of each Data object into one iterator 49 | else: 50 | iters = data._runs 51 | for index, run in enumerate(iters): 52 | if config_match(run["config"], conf): 53 | matched_runs.append(run) 54 | indices.append(index) 55 | return matched_runs, indices 56 | 57 | def interp(x1, x2, y1, y2, x): 58 | dy = y2 - y1 59 | dx = x2 - x1 60 | return (x - x1) / dx * dy + y1 61 | 62 | def get_runs_with_metrics(data, conf, step_metric, value_metric, anc_steps, min_length_v=30, num_seeds=20, patch_name="arxiv-patch", extra_args="", is_pixel=False): 63 | metric_list = [] 64 | runs, indices = filter(data, conf) 65 | min_length = None 66 | 67 | ids = [] 68 | seeds = set() 69 | for run, idt in zip(runs, indices): 70 | 71 | if value_metric not in run: continue 72 | if step_metric not in run: continue 73 | 74 | steps_rs, steps = run[step_metric] 75 | values_rs, values = run[value_metric] 76 | 77 | env_steps = [] 78 | step_index = 0 79 | failed = False 80 | for value_r, value in zip(values_rs, values): 81 | step_r = steps_rs[step_index] 82 | while step_r < value_r and step_index < len(steps_rs) - 1: 83 | step_index += 1 84 | step_r = steps_rs[step_index] 85 | if step_r < value_r: 86 | values = values[:len(env_steps)] 87 | break 88 | env_steps.append(steps[step_index]) 89 | 90 | if failed: continue 91 | adjusted_values = [] 92 | step_index = 0 93 | for anc_step in anc_steps: 94 | env_step = env_steps[step_index] 95 | while env_step < anc_step and step_index < len(values) - 1: 96 | step_index += 1 97 | env_step = env_steps[step_index] 98 | if env_step < anc_step: 99 | failed = True 100 | break 101 | if step_index == 0: 102 | adjusted_values.append(values[step_index]) 103 | else: 104 | adjusted_values.append(interp( 105 | env_steps[step_index - 1],
env_steps[step_index], 106 | values[step_index - 1], values[step_index], 107 | anc_step, 108 | )) 109 | if failed: continue 110 | 111 | if len(adjusted_values) < min_length_v: 112 | continue 113 | 114 | seed = run["config"]["seed"] 115 | if seed in seeds: 116 | continue 117 | seeds.add(seed) 118 | 119 | metric_list.append(adjusted_values) 120 | 121 | if min_length is None or min_length > len(adjusted_values): 122 | min_length = len(adjusted_values) 123 | 124 | if len(metric_list) == 0: 125 | return None 126 | 127 | additional_str = "" 128 | for k, v in conf.items(): 129 | additional_str += f" --{k}={v}" 130 | additional_str += f" {extra_args}" 131 | for desired_seed in range(num_seeds): 132 | if desired_seed not in seeds: 133 | if is_pixel: 134 | cmd_file = "train_finetuning_pixels" 135 | else: 136 | cmd_file = "train_finetuning" 137 | print(f"python {cmd_file}.py --project_name={patch_name} --seed={desired_seed}" + additional_str) 138 | 139 | metrics = np.stack([metric[:min_length] for metric in metric_list], axis=0) 140 | metrics = normalize_performance(metrics, conf["env_name"]) 141 | return anc_steps[:min_length].copy(), metrics, ids 142 | 143 | if __name__ == '__main__': 144 | pass 145 | -------------------------------------------------------------------------------- /plotting/visualize_maze.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import numpy as np 9 | import sys 10 | sys.path.append('../') 11 | 12 | import matplotlib.pyplot as plt 13 | from mpl_toolkits.axes_grid1 import ImageGrid 14 | from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas 15 | from mpl_toolkits.axes_grid1 import make_axes_locatable 16 | 17 | from rlpd.data import Dataset 18 | from visualize import * 19 | 20 | env_name = 'antmaze-medium-diverse-v2' 21 | viz_env, _ = get_env_and_dataset(env_name) 22 | 23 | seed=0 24 | buffer_paths = { 25 | 'Online': f'../exp_data/online-s{seed}/buffers/buffer.npz', 26 | 'Online + RND': f'../exp_data/online_rnd-s{seed}/buffers/buffer.npz', 27 | 'Naive': f'../exp_data/naive-s{seed}/buffers/buffer.npz', 28 | 'Ours': f'../exp_data/ours-s{seed}/buffers/buffer.npz', 29 | } 30 | 31 | buffers = {} 32 | for name, path in buffer_paths.items(): 33 | with open(path, 'rb') as f: 34 | buffers[name] = np.load(f)["observations"] 35 | print('min_buffer_length:', min([len(buf) for name, buf in buffers.items()])) 36 | 37 | cutoff = 60000 38 | 39 | # https://stackoverflow.com/questions/13784201/how-to-have-one-colorbar-for-all-subplots 40 | def plot_online_coverage(env, buffers, colorbar_scale=1000): 41 | n_pts = cutoff // 200 42 | 43 | fig = plt.figure(tight_layout=True) 44 | axs = ImageGrid( 45 | fig, 111, nrows_ncols=(1, len(buffers)), 46 | cbar_location='right', cbar_mode='single', cbar_size='5%', cbar_pad=0.05) 47 | 48 | canvas = FigureCanvas(fig) 49 | 50 | for i, (name, buffer) in enumerate(buffers.items()): 51 | axs[i].set_title(name) 52 | axs[i].axis('off') 53 | axs[i].set_box_aspect(1) 54 | env.draw(axs[i]) 55 | 56 | ## add buffer pts 57 | obs = buffer[:cutoff] 58 | idxs = np.arange(len(obs)) 59 | idxs = np.sort(np.random.choice(idxs, size=n_pts, replace=False)) 60 | x, y = obs[idxs, 0], obs[idxs, 1] 61 | scatter = axs[i].scatter(x, y, c=idxs // colorbar_scale, **dict(alpha=0.75, s=5, cmap='viridis', 
marker='o')) 62 | 63 | axs[-1].cax.colorbar(scatter, label="Env Steps $\\left(\\times 10^3\\right)$", 64 | ticks=range(0, cutoff // colorbar_scale, 15)) 65 | axs[-1].cax.toggle_label(True) 66 | 67 | image = get_canvas_image(canvas) 68 | plt.savefig(f'../plotting/figures/antmaze-exploration-{cutoff}.pdf', bbox_inches="tight") 69 | # plt.savefig(f'../plotting/figures/antmaze-exploration-{cutoff}.png', bbox_inches="tight") 70 | plt.close(fig) 71 | return image 72 | 73 | img = plot_online_coverage(viz_env, buffers) 74 | -------------------------------------------------------------------------------- /plotting/visualize_reward.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | 9 | import types 10 | import sys 11 | sys.path.append('../') 12 | 13 | import numpy as np 14 | 15 | import matplotlib.pyplot as plt 16 | from matplotlib.offsetbox import (DrawingArea, OffsetImage, AnnotationBbox) 17 | 18 | import roboverse 19 | 20 | from visualize import * 21 | 22 | from flax.training import checkpoints 23 | from flax.core import frozen_dict 24 | 25 | 26 | 27 | from rlpd.agents import PixelRND, PixelRM 28 | from rlpd.wrappers import wrap_pixels 29 | from gym.wrappers import FilterObservation, TimeLimit, RecordEpisodeStatistics 30 | 31 | from collections import defaultdict 32 | 33 | ###### LOAD SUCCESSFUL PRIOR TRAJECTORY ###### 34 | 35 | successful_task1_path = '../data/closeddrawer_small/successful/prior_success.npy' 36 | successful_task2_path = '../data/closeddrawer_small/successful/task_success.npy' 37 | 38 | def dict_to_list(D): 39 | # https://stackoverflow.com/questions/5558418/list-of-dicts-to-from-dict-of-lists 40 | return [dict(zip(D, t)) for t in zip(*D.values())] 41 | 42 | def make_data_dict(tran): 43 | return dict( 44 | observations={"pixels": np.array(tran["observations"]["image"])[..., None]}, 45 | actions=np.array(tran["actions"]), 46 | next_observations={"pixels": np.array(tran["next_observations"]["image"])[..., None]}, 47 | rewards=np.array(tran["rewards"]), 48 | masks=1-np.array(tran["terminals"], dtype=float), 49 | dones=np.array(tran["agent_infos"]["done"]) 50 | ) 51 | 52 | t1 = np.load(successful_task1_path, allow_pickle=True) 53 | t2 = np.load(successful_task2_path, allow_pickle=True) 54 | successful_t1_trajs = [] 55 | successful_t2_trajs = [] 56 | 57 | for traj in t1: 58 | trans = dict_to_list(traj) 59 | trans = [make_data_dict(tran) for tran in trans] 60 | successful_t1_trajs.append(trans) 61 | 62 | for traj in t2: 63 | trans = dict_to_list(traj) 64 | trans = [make_data_dict(tran) for tran in trans] 65 | successful_t2_trajs.append(trans) 66 | 67 | successful_trajs = [successful_t1_trajs[i] + successful_t2_trajs[i] \ 68 | for i in range(min(len(successful_t1_trajs), len(successful_t2_trajs)))] 69 | images = [] 70 | for traj in successful_trajs: 71 | images.append([]) 72 | for tran in traj: 73 | images[-1].append(tran['observations']['pixels'].squeeze()) 74 | 75 | ###### RECREATE TRAIN STATE ###### 76 | 77 | def wrap(env): 78 | return wrap_pixels( 79 | env, 80 | action_repeat=1, 81 | image_size=48, 82 | num_stack=1, 83 | camera_id=0, 84 | ) 85 | 86 | def render(env, *args, **kwargs): 87 | return env.render_obs() 88 | 89 | env_name = "Widow250DoubleDrawerOpenGraspNeutral-v0" 90 | 91 | env =
roboverse.make(env_name, transpose_image=False) 92 | env.render = types.MethodType(render, env) 93 | env = FilterObservation(env, ['image']) 94 | env = TimeLimit(env, max_episode_steps=50) 95 | env, pixel_keys = wrap(env) 96 | env = RecordEpisodeStatistics(env, deque_size=1) 97 | env.seed(0) 98 | 99 | rnd_kwargs = dict( 100 | cnn_features = (32, 64, 128, 256), 101 | cnn_filters = (3, 3, 3, 3), 102 | cnn_strides = (2, 2, 2, 2), 103 | cnn_padding = "VALID", 104 | latent_dim = 50, 105 | encoder = "d4pg", 106 | lr=3e-4, 107 | hidden_dims=(256, 256), 108 | coeff=1. 109 | ) 110 | 111 | rnd_base = PixelRND.create( 112 | 0, env.observation_space, env.action_space, pixel_keys=pixel_keys, **rnd_kwargs 113 | ) 114 | 115 | rm_kwargs = dict( 116 | cnn_features = (32, 64, 128, 256), 117 | cnn_filters = (3, 3, 3, 3), 118 | cnn_strides = (2, 2, 2, 2), 119 | cnn_padding = "VALID", 120 | latent_dim = 50, 121 | encoder = "d4pg", 122 | lr = 3e-4, 123 | hidden_dims = (256, 256), 124 | ) 125 | 126 | rm_base = PixelRM.create( 127 | 0, env.observation_space, env.action_space, pixel_keys=pixel_keys, **rm_kwargs 128 | ) 129 | 130 | ###### EVALUATE AND COLLECT REWARDS ###### 131 | seeds = list(range(20)) 132 | env_step = 25000 133 | 134 | rm = PixelRM.create( 135 | 0, env.observation_space, env.action_space, pixel_keys=pixel_keys, **rm_kwargs 136 | ) 137 | icvf_rm = PixelRM.create( 138 | 1, env.observation_space, env.action_space, pixel_keys=pixel_keys, **rm_kwargs 139 | ) 140 | rnd = PixelRND.create( 141 | 2, env.observation_space, env.action_space, pixel_keys=pixel_keys, **rnd_kwargs 142 | ) 143 | icvf_rnd = PixelRND.create( 144 | 3, env.observation_space, env.action_space, pixel_keys=pixel_keys, **rnd_kwargs 145 | ) 146 | 147 | # seeds = [] 148 | icvf_rnd_rewards_ind_seed = [] 149 | rnd_rewards_ind_seed = [] 150 | icvf_rm_rewards_ind_seed = [] 151 | rm_rewards_ind_seed = [] 152 | for i, seed in enumerate(seeds): 153 | icvf_rnd_path = f"../exp_data_cog/{env_name}-s{seed}-icvf_True-ours_True/checkpoints/" 154 | rnd_path = f"../exp_data_cog/{env_name}-s{seed}-icvf_False-ours_True/checkpoints/" 155 | icvf_rm_path = f"../exp_data_cog/{env_name}-s{seed}-icvf_True-ours_True/checkpoints/" 156 | rm_path = f"../exp_data_cog/{env_name}-s{seed}-icvf_False-ours_True/checkpoints/" 157 | 158 | icvf_rnd = checkpoints.restore_checkpoint(icvf_rnd_path, target=icvf_rnd, prefix="rnd_checkpoint_", step=env_step) 159 | rnd = checkpoints.restore_checkpoint(rnd_path, target=rnd, prefix="rnd_checkpoint_", step=env_step) 160 | icvf_rm = checkpoints.restore_checkpoint(icvf_rm_path, target=icvf_rm, prefix="rm_checkpoint_", step=env_step) 161 | rm = checkpoints.restore_checkpoint(rm_path, target=rm, prefix="rm_checkpoint_", step=env_step) 162 | 163 | icvf_rnd_rewards_list = defaultdict(list) 164 | rnd_rewards_list = defaultdict(list) 165 | icvf_rm_rewards_list = defaultdict(list) 166 | rm_rewards_list = defaultdict(list) 167 | 168 | for t, successful_traj in enumerate(successful_trajs): 169 | icvf_rnd_rewards, rnd_rewards, icvf_rm_rewards, rm_rewards = \ 170 | [], [], [], [] 171 | for tran in successful_traj: 172 | icvf_rnd_rewards.append(icvf_rnd.get_reward(frozen_dict.freeze(tran)).item()) 173 | rnd_rewards.append(rnd.get_reward(frozen_dict.freeze(tran)).item()) 174 | icvf_rm_rewards.append(icvf_rm.get_reward(frozen_dict.freeze(tran)).item()) 175 | rm_rewards.append(rm.get_reward(frozen_dict.freeze(tran)).item()) 176 | 177 | icvf_rnd_rewards_list[t].append(np.array(icvf_rnd_rewards)) 178 | rnd_rewards_list[t].append(np.array(rnd_rewards)) 
179 | icvf_rm_rewards_list[t].append(np.array(icvf_rm_rewards)) 180 | rm_rewards_list[t].append(np.array(rm_rewards)) 181 | 182 | icvf_rnd_rewards = [] 183 | rnd_rewards = [] 184 | icvf_rm_rewards = [] 185 | rm_rewards = [] 186 | for t in range(len(successful_trajs)): 187 | icvf_rnd_rewards.append(np.stack(icvf_rnd_rewards_list[t], axis=0)) 188 | rnd_rewards.append(np.stack(rnd_rewards_list[t], axis=0)) 189 | icvf_rm_rewards.append(np.stack(icvf_rm_rewards_list[t], axis=0)) 190 | rm_rewards.append(np.stack(rm_rewards_list[t], axis=0)) 191 | 192 | ###### MAKING PLOTS ###### 193 | 194 | def plot_reward(icvf_rewards, norm_rewards, images, t): 195 | n, T = norm_rewards.shape 196 | def plot_single(ax, rewards, label): 197 | # normalize 198 | mean_traj = rewards.mean(axis=1, keepdims=True) 199 | std_traj = rewards.std(axis=1, keepdims=True) 200 | rewards = (rewards - mean_traj) / (std_traj + 1e-5) 201 | 202 | mean = rewards.mean(axis=0) 203 | sterr = rewards.std(axis=0) / np.sqrt(n) 204 | 205 | ax.plot(range(T), mean, label=label, linewidth=10) 206 | ax.fill_between(range(T), mean - sterr, mean + sterr, alpha=0.25) 207 | 208 | fig, ax = plt.subplots(figsize=(15, 5)) 209 | plot_single(ax, icvf_rewards, 'Ours + ICVF') 210 | plot_single(ax, norm_rewards, 'Ours') 211 | 212 | for i in range(0, T, 5): 213 | image = images[i] 214 | imagebox = OffsetImage(image, zoom=1.7) 215 | imagebox.image.axes = ax 216 | 217 | ab = AnnotationBbox( 218 | imagebox, (i, 0), 219 | xybox=(0, -30), 220 | xycoords=("data", "axes fraction"), 221 | boxcoords="offset points", 222 | box_alignment=(.5, 1), 223 | bboxprops={"edgecolor": "none"} 224 | ) 225 | 226 | ax.add_artist(ab) 227 | 228 | plt.legend(fontsize=20) 229 | plt.xticks(ticks=range(0, T, 5), fontsize=24) 230 | plt.tick_params(left=False, labelleft=False) 231 | plt.title(f'At {env_step // 1000}k Environment Steps', fontsize=28) 232 | plt.xlabel('Trajectory Steps', labelpad=-55, fontsize=24) 233 | plt.ylabel('Normalized Reward', fontsize=24) 234 | plt.tight_layout() 235 | plt.subplots_adjust(bottom=0.3) 236 | plt.savefig(f'icvf_reward_effect-{env_step}-{t}.pdf') 237 | 238 | for t in range(len(images)): 239 | plt.clf() 240 | plot_reward(icvf_rnd_rewards[t] + icvf_rm_rewards[t], 241 | rnd_rewards[t] + rm_rewards[t], 242 | images[t], t) 243 | -------------------------------------------------------------------------------- /plotting/wandb_dl.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
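#
# Downloads the logged metrics for every run in a W&B project and pickles
# them to data/<domain>.pkl for the plotting scripts. Example invocation
# (the entity and project names below are placeholders):
#
#   python wandb_dl.py --entity my-team --project_name explore-cog --domain cog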
6 | 7 | 8 | 9 | import pickle as pkl 10 | import pandas as pd 11 | import matplotlib.pyplot as plt 12 | from collections import defaultdict 13 | import wandb 14 | import numpy as np 15 | 16 | from matplotlib import rc 17 | import matplotlib.ticker as mticker 18 | import seaborn as sns 19 | 20 | import wandb 21 | import tqdm 22 | import copy 23 | import argparse 24 | 25 | import os 26 | 27 | api = wandb.Api() 28 | 29 | def find_next_hp(config_str): 30 | index = config_str.find("\n") 31 | if index == -1: 32 | return "" 33 | if config_str[index:index + 2] == "\n-": 34 | return find_next_hp(config_str[index + 2:]) 35 | return config_str[index + 1:] 36 | 37 | def parse(config_str): 38 | config = {} 39 | while True: 40 | config_str_nxt = find_next_hp(config_str) 41 | 42 | curr_seg = config_str[:len(config_str)-len(config_str_nxt)] 43 | k, v = curr_seg.split(": ") 44 | config[k] = v[:-1] 45 | 46 | if config_str_nxt == "": 47 | break 48 | config_str = config_str_nxt 49 | return config 50 | 51 | def post_process(config): 52 | config = copy.deepcopy(config) 53 | items = list(config.items()) 54 | for k, v in items: 55 | if type(v) == str and v.find("\n") != -1: 56 | config.pop(k) 57 | for kk, vv in parse(v).items(): 58 | ak = k + "." + kk 59 | config[ak] = vv 60 | 61 | return config 62 | 63 | 64 | def save(path, data): 65 | with open(path, 'wb') as f: 66 | pkl.dump(data, f) 67 | 68 | def load(path): 69 | with open(path, 'rb') as f: 70 | return pkl.load(f) 71 | 72 | class Data: 73 | def __init__(self, proj_names, keys, entity): 74 | self._runs = [] 75 | self.keys = keys 76 | self.entity = entity 77 | 78 | for proj_name in proj_names: 79 | self._collect(proj_name) 80 | 81 | def _collect(self, proj_name): 82 | runs = api.runs(f"{self.entity}/{proj_name}") 83 | for run in tqdm.tqdm(runs): 84 | config = post_process(run.config) 85 | config["project_name"] = proj_name 86 | entry = {"config": config} 87 | for key in self.keys: 88 | df = run.history(keys=[key]) 89 | if len(df.columns) == 0: continue 90 | 91 | steps = df[df.columns[0]].to_numpy() 92 | values = df[key].to_numpy() 93 | entry[key] = (steps, values) 94 | self._runs.append(entry) 95 | 96 | if __name__ == '__main__': 97 | 98 | parser = argparse.ArgumentParser( 99 | prog='Wandb experiments downloader', 100 | description='Download experiment data from weights and biases server' 101 | ) 102 | 103 | parser.add_argument('--entity', type=str) 104 | parser.add_argument('--project_name', type=str) 105 | parser.add_argument('--domain', type=str, help="one of [antmaze, adroit, cog]") 106 | args = parser.parse_args() 107 | 108 | assert args.domain in ["antmaze", "adroit", "cog"] 109 | keys = ["env_step", "evaluation/return"] 110 | if args.domain == "antmaze": 111 | keys.append("coverage") 112 | data = Data([args.project_name], keys=keys, entity=args.entity) 113 | os.makedirs('data', exist_ok=True) 114 | with open(f'data/{args.domain}.pkl', 'wb') as handle: 115 | pkl.dump(data, handle, protocol=pkl.HIGHEST_PROTOCOL) 116 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gym >= 0.21.0, < 0.24.0 2 | numpy >= 1.20.2 3 | flax >= 0.5.2 4 | ml_collections >= 0.1.0 5 | tqdm >= 4.60.0 6 | optax >= 0.1.3 7 | absl-py >= 0.12.0 8 | scipy >= 1.6.0 9 | wandb >= 0.12.14 10 | tensorflow-probability >= 0.17.0 11 | matplotlib == 3.4.3 12 | mujoco >= 2.2.2 13 | moviepy >= 1.0.3 14 | imageio >= 2.21.3 15 | d4rl >= 1.1 16 | plotly >= 5.10.0 17 | 
pandas >= 2.1.0 18 | distrax >= 0.1.4 19 | -------------------------------------------------------------------------------- /rlpd/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ExPLORe/fff95eb0e1aa16c1410a2daefed068dfeb5c3620/rlpd/__init__.py -------------------------------------------------------------------------------- /rlpd/agents/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/agents/__init__.py 3 | 4 | Original license information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | 28 | """ 29 | 30 | from rlpd.agents.drq.drq_learner import DrQLearner 31 | from rlpd.agents.sac.sac_learner import SACLearner 32 | from rlpd.agents.rm import RM 33 | from rlpd.agents.rnd import RND 34 | from rlpd.agents.bc import BCAgent 35 | from rlpd.agents.drq.bc import PixelBCAgent 36 | from rlpd.agents.drq.rm import PixelRM 37 | from rlpd.agents.drq.rnd import PixelRND 38 | -------------------------------------------------------------------------------- /rlpd/agents/agent.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/agents/agent.py. 3 | 4 | Original license information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | 28 | """ 29 | 30 | from functools import partial 31 | 32 | import jax 33 | import jax.numpy as jnp 34 | import numpy as np 35 | from flax import struct 36 | from flax.training.train_state import TrainState 37 | 38 | from rlpd.types import PRNGKey 39 | 40 | 41 | @partial(jax.jit, static_argnames="apply_fn") 42 | def _sample_actions(rng, apply_fn, params, observations: np.ndarray) -> np.ndarray: 43 | key, rng = jax.random.split(rng) 44 | dist = apply_fn({"params": params}, observations) 45 | return dist.sample(seed=key), rng 46 | 47 | 48 | @partial(jax.jit, static_argnames="apply_fn") 49 | def _eval_actions(apply_fn, params, observations: np.ndarray) -> np.ndarray: 50 | dist = apply_fn({"params": params}, observations) 51 | return dist.mode() 52 | 53 | 54 | class Agent(struct.PyTreeNode): 55 | actor: TrainState 56 | rng: PRNGKey 57 | 58 | def eval_actions(self, observations: np.ndarray) -> np.ndarray: 59 | actions = _eval_actions(self.actor.apply_fn, self.actor.params, observations) 60 | return np.asarray(actions) 61 | 62 | def sample_actions(self, observations: np.ndarray) -> np.ndarray: 63 | actions, new_rng = _sample_actions( 64 | self.rng, self.actor.apply_fn, self.actor.params, observations 65 | ) 66 | return np.asarray(actions), self.replace(rng=new_rng) 67 | 68 | @jax.jit 69 | def sample(self, observations): 70 | dist = self.actor.apply_fn({"params": self.actor.params}, observations) 71 | key, rng = jax.random.split(self.rng) 72 | actions = dist.sample(seed=key) 73 | return actions, self.replace(rng=rng) 74 | -------------------------------------------------------------------------------- /rlpd/agents/bc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | 5 | This source code is licensed under the license found in the 6 | LICENSE file in the root directory of this source tree. 
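
This module implements a state-based behavior-cloning (BC) agent: a
tanh-squashed Gaussian policy fit to dataset actions by maximum likelihood.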
7 | """ 8 | 9 | 10 | from functools import partial 11 | from typing import Dict, Optional, Sequence, Tuple 12 | 13 | import flax 14 | import gym 15 | import jax 16 | import jax.numpy as jnp 17 | import numpy as np 18 | import optax 19 | from flax import struct 20 | from flax.training.train_state import TrainState 21 | from flax.core import FrozenDict 22 | 23 | from rlpd.agents.agent import Agent 24 | from rlpd.data.dataset import DatasetDict 25 | from rlpd.distributions import TanhNormal 26 | from rlpd.networks import MLP 27 | 28 | 29 | class BCAgent(Agent): 30 | @classmethod 31 | def create( 32 | cls, 33 | seed: int, 34 | observation_space: gym.Space, 35 | action_space: gym.Space, 36 | actor_lr: float = 3e-4, 37 | hidden_dims: Sequence[int] = (256, 256), 38 | use_pnorm: bool = False, 39 | ): 40 | """ 41 | An implementation of the version of Soft-Actor-Critic described in https://arxiv.org/abs/1812.05905 42 | """ 43 | 44 | action_dim = action_space.shape[-1] 45 | observations = observation_space.sample() 46 | 47 | rng = jax.random.PRNGKey(seed) 48 | rng, actor_key = jax.random.split(rng, 2) 49 | 50 | actor_base_cls = partial( 51 | MLP, hidden_dims=hidden_dims, activate_final=True, use_pnorm=use_pnorm 52 | ) 53 | actor_def = TanhNormal(actor_base_cls, action_dim) 54 | actor_params = FrozenDict(actor_def.init(actor_key, observations)["params"]) 55 | actor = TrainState.create( 56 | apply_fn=actor_def.apply, 57 | params=actor_params, 58 | tx=optax.adam(learning_rate=actor_lr), 59 | ) 60 | 61 | return cls( 62 | rng=rng, 63 | actor=actor, 64 | ) 65 | 66 | def update_actor(self, batch: DatasetDict) -> Tuple[Agent, Dict[str, float]]: 67 | def actor_loss_fn(actor_params) -> Tuple[jnp.ndarray, Dict[str, float]]: 68 | dist = self.actor.apply_fn({"params": actor_params}, batch["observations"]) 69 | actor_loss = -dist.log_prob(batch["actions"]).mean() 70 | return actor_loss, {"pretrain_bc_loss": actor_loss} 71 | 72 | grads, actor_info = jax.grad(actor_loss_fn, has_aux=True)(self.actor.params) 73 | actor = self.actor.apply_gradients(grads=grads) 74 | 75 | return self.replace(actor=actor), actor_info 76 | 77 | @partial(jax.jit, static_argnames="utd_ratio") 78 | def update(self, batch: DatasetDict, utd_ratio: int): 79 | 80 | new_agent = self 81 | for i in range(utd_ratio): 82 | 83 | def slice(x): 84 | assert x.shape[0] % utd_ratio == 0 85 | batch_size = x.shape[0] // utd_ratio 86 | return x[batch_size * i : batch_size * (i + 1)] 87 | 88 | mini_batch = jax.tree_util.tree_map(slice, batch) 89 | new_agent, actor_info = new_agent.update_actor(mini_batch) 90 | 91 | return new_agent, actor_info 92 | -------------------------------------------------------------------------------- /rlpd/agents/drq/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ExPLORe/fff95eb0e1aa16c1410a2daefed068dfeb5c3620/rlpd/agents/drq/__init__.py -------------------------------------------------------------------------------- /rlpd/agents/drq/augmentations.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copied from https://github.com/ikostrikov/rlpd/blob/main/rlpd/agents/drq/augmentations.py 3 | 4 | Original lincense information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. 
Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | """ 28 | 29 | import jax 30 | import jax.numpy as jnp 31 | 32 | 33 | def random_crop(key, img, padding): 34 | crop_from = jax.random.randint(key, (2,), 0, 2 * padding + 1) 35 | crop_from = jnp.concatenate([crop_from, jnp.zeros((2,), dtype=jnp.int32)]) 36 | padded_img = jnp.pad( 37 | img, ((padding, padding), (padding, padding), (0, 0), (0, 0)), mode="edge" 38 | ) 39 | return jax.lax.dynamic_slice(padded_img, crop_from, img.shape) 40 | 41 | 42 | def batched_random_crop(key, obs, pixel_key, padding=4): 43 | imgs = obs[pixel_key] 44 | keys = jax.random.split(key, imgs.shape[0]) 45 | imgs = jax.vmap(random_crop, (0, 0, None))(keys, imgs, padding) 46 | return obs.copy(add_or_replace={pixel_key: imgs}) 47 | -------------------------------------------------------------------------------- /rlpd/agents/drq/bc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | 5 | This source code is licensed under the license found in the 6 | LICENSE file in the root directory of this source tree. 7 | """ 8 | 9 | from functools import partial 10 | from typing import Dict, Optional, Sequence, Tuple 11 | 12 | import flax 13 | import gym 14 | import jax 15 | import jax.numpy as jnp 16 | import numpy as np 17 | import optax 18 | from flax import struct 19 | from flax.training.train_state import TrainState 20 | from flax.core import FrozenDict 21 | 22 | from rlpd.agents.agent import Agent 23 | from rlpd.agents.sac.temperature import Temperature 24 | from rlpd.data.dataset import DatasetDict 25 | from rlpd.distributions import TanhNormal 26 | from rlpd.networks import ( 27 | MLP, 28 | PixelMultiplexer, 29 | ) 30 | 31 | from rlpd.networks.encoders import D4PGEncoder 32 | from rlpd.agents.bc import BCAgent 33 | 34 | 35 | class PixelBCAgent(BCAgent): 36 | @classmethod 37 | def create( 38 | cls, 39 | seed: int, 40 | observation_space: gym.Space, 41 | action_space: gym.Space, 42 | actor_lr: float = 3e-4, 43 | cnn_features: Sequence[int] = (32, 32, 32, 32), 44 | cnn_filters: Sequence[int] = (3, 3, 3, 3), 45 | cnn_strides: Sequence[int] = (2, 1, 1, 1), 46 | cnn_padding: str = "VALID", 47 | latent_dim: int = 50, 48 | hidden_dims: Sequence[int] = (256, 256), 49 | pixel_keys: Tuple[str, ...] = ("pixels",), 50 | depth_keys: Tuple[str, ...] 
= (), 51 | encoder: str = "d4pg", 52 | ): 53 | assert encoder == "d4pg" 54 | action_dim = action_space.shape[-1] 55 | observations = observation_space.sample() 56 | 57 | rng = jax.random.PRNGKey(seed) 58 | rng, actor_key = jax.random.split(rng, 2) 59 | 60 | encoder_cls = partial( 61 | D4PGEncoder, 62 | features=cnn_features, 63 | filters=cnn_filters, 64 | strides=cnn_strides, 65 | padding=cnn_padding, 66 | ) 67 | actor_base_cls = partial(MLP, hidden_dims=hidden_dims, activate_final=True) 68 | actor_cls = partial(TanhNormal, base_cls=actor_base_cls, action_dim=action_dim) 69 | actor_def = PixelMultiplexer( 70 | encoder_cls=encoder_cls, 71 | network_cls=actor_cls, 72 | latent_dim=latent_dim, 73 | stop_gradient=False, 74 | pixel_keys=pixel_keys, 75 | depth_keys=depth_keys, 76 | ) 77 | actor_params = FrozenDict(actor_def.init(actor_key, observations)["params"]) 78 | actor = TrainState.create( 79 | apply_fn=actor_def.apply, 80 | params=actor_params, 81 | tx=optax.adam(learning_rate=actor_lr), 82 | ) 83 | 84 | return cls( 85 | rng=rng, 86 | actor=actor, 87 | ) 88 | -------------------------------------------------------------------------------- /rlpd/agents/drq/drq_learner.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/agents/drq/drq_learner.py 3 | 4 | Original lincense information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | 28 | """ 29 | 30 | from functools import partial 31 | from itertools import zip_longest 32 | from typing import Callable, Optional, Sequence, Tuple 33 | 34 | import gym 35 | import jax 36 | import optax 37 | from flax import struct 38 | from flax.training.train_state import TrainState 39 | from flax.core.frozen_dict import FrozenDict 40 | 41 | from rlpd.agents.drq.augmentations import batched_random_crop 42 | from rlpd.agents.sac.sac_learner import SACLearner 43 | from rlpd.agents.sac.temperature import Temperature 44 | from rlpd.data.dataset import DatasetDict 45 | from rlpd.distributions import TanhNormal 46 | from rlpd.networks import MLP, Ensemble, PixelMultiplexer, StateActionValue 47 | from rlpd.networks.encoders import D4PGEncoder 48 | 49 | 50 | # Helps to minimize CPU to GPU transfer. 
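# If "next_observations" lacks a pixel key, the packed array in "observations"
# is assumed to hold num_stack + 1 frames along its last axis, and _unpack
# recovers both views by slicing: e.g., with 3 stacked frames, a batch of
# shape (B, H, W, C, 4) yields observations from frames [..., :-1] and
# next_observations from frames [..., 1:].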
51 | def _unpack(batch): 52 | # Assuming that if next_observation is missing, it's combined with observation: 53 | for pixel_key in batch["observations"].keys(): 54 | if pixel_key not in batch["next_observations"]: 55 | obs_pixels = batch["observations"][pixel_key][..., :-1] 56 | next_obs_pixels = batch["observations"][pixel_key][..., 1:] 57 | 58 | obs = batch["observations"].copy(add_or_replace={pixel_key: obs_pixels}) 59 | next_obs = batch["next_observations"].copy( 60 | add_or_replace={pixel_key: next_obs_pixels} 61 | ) 62 | 63 | batch = batch.copy( 64 | add_or_replace={"observations": obs, "next_observations": next_obs} 65 | ) 66 | 67 | return batch 68 | 69 | 70 | def _share_encoder(source, target): 71 | replacers = {} 72 | 73 | for k, v in source.params.items(): 74 | if "encoder" in k: 75 | replacers[k] = v 76 | 77 | # Use critic conv layers in actor: 78 | new_params = FrozenDict(target.params).copy(add_or_replace=replacers) 79 | return target.replace(params=new_params) 80 | 81 | 82 | class DrQLearner(SACLearner): 83 | data_augmentation_fn: Callable = struct.field(pytree_node=False) 84 | 85 | @classmethod 86 | def create( 87 | cls, 88 | seed: int, 89 | observation_space: gym.Space, 90 | action_space: gym.Space, 91 | actor_lr: float = 3e-4, 92 | critic_lr: float = 3e-4, 93 | temp_lr: float = 3e-4, 94 | cnn_features: Sequence[int] = (32, 32, 32, 32), 95 | cnn_filters: Sequence[int] = (3, 3, 3, 3), 96 | cnn_strides: Sequence[int] = (2, 1, 1, 1), 97 | cnn_padding: str = "VALID", 98 | latent_dim: int = 50, 99 | encoder: str = "d4pg", 100 | hidden_dims: Sequence[int] = (256, 256), 101 | discount: float = 0.99, 102 | tau: float = 0.005, 103 | num_qs: int = 2, 104 | num_min_qs: Optional[int] = None, 105 | critic_dropout_rate: Optional[float] = None, 106 | critic_layer_norm: bool = False, 107 | target_entropy: Optional[float] = None, 108 | init_temperature: float = 1.0, 109 | backup_entropy: bool = True, 110 | pixel_keys: Tuple[str, ...] = ("pixels",), 111 | depth_keys: Tuple[str, ...] 
= (), 112 | bc_coeff: float = 0, 113 | ): 114 | """ 115 | An implementation of the version of Soft-Actor-Critic described in https://arxiv.org/abs/1812.05905 116 | """ 117 | 118 | action_dim = action_space.shape[-1] 119 | observations = observation_space.sample() 120 | actions = action_space.sample() 121 | 122 | if target_entropy is None: 123 | target_entropy = -action_dim / 2 124 | 125 | rng = jax.random.PRNGKey(seed) 126 | rng, actor_key, critic_key, temp_key = jax.random.split(rng, 4) 127 | 128 | if encoder == "d4pg": 129 | encoder_cls = partial( 130 | D4PGEncoder, 131 | features=cnn_features, 132 | filters=cnn_filters, 133 | strides=cnn_strides, 134 | padding=cnn_padding, 135 | ) 136 | else: 137 | raise NotImplementedError 138 | 139 | actor_base_cls = partial(MLP, hidden_dims=hidden_dims, activate_final=True) 140 | actor_cls = partial(TanhNormal, base_cls=actor_base_cls, action_dim=action_dim) 141 | actor_def = PixelMultiplexer( 142 | encoder_cls=encoder_cls, 143 | network_cls=actor_cls, 144 | latent_dim=latent_dim, 145 | stop_gradient=True, 146 | pixel_keys=pixel_keys, 147 | depth_keys=depth_keys, 148 | ) 149 | actor_params = FrozenDict(actor_def.init(actor_key, observations)["params"]) 150 | actor = TrainState.create( 151 | apply_fn=actor_def.apply, 152 | params=actor_params, 153 | tx=optax.adam(learning_rate=actor_lr), 154 | ) 155 | 156 | critic_base_cls = partial( 157 | MLP, 158 | hidden_dims=hidden_dims, 159 | activate_final=True, 160 | dropout_rate=critic_dropout_rate, 161 | use_layer_norm=critic_layer_norm, 162 | ) 163 | critic_cls = partial(StateActionValue, base_cls=critic_base_cls) 164 | critic_cls = partial(Ensemble, net_cls=critic_cls, num=num_qs) 165 | critic_def = PixelMultiplexer( 166 | encoder_cls=encoder_cls, 167 | network_cls=critic_cls, 168 | latent_dim=latent_dim, 169 | pixel_keys=pixel_keys, 170 | depth_keys=depth_keys, 171 | ) 172 | critic_params = FrozenDict( 173 | critic_def.init(critic_key, observations, actions)["params"] 174 | ) 175 | critic = TrainState.create( 176 | apply_fn=critic_def.apply, 177 | params=critic_params, 178 | tx=optax.adam(learning_rate=critic_lr), 179 | ) 180 | target_critic = TrainState.create( 181 | apply_fn=critic_def.apply, 182 | params=critic_params, 183 | tx=optax.GradientTransformation(lambda _: None, lambda _: None), 184 | ) 185 | 186 | temp_def = Temperature(init_temperature) 187 | temp_params = FrozenDict(temp_def.init(temp_key)["params"]) 188 | temp = TrainState.create( 189 | apply_fn=temp_def.apply, 190 | params=temp_params, 191 | tx=optax.adam(learning_rate=temp_lr), 192 | ) 193 | 194 | def data_augmentation_fn(rng, observations): 195 | for pixel_key, depth_key in zip_longest(pixel_keys, depth_keys): 196 | key, rng = jax.random.split(rng) 197 | observations = batched_random_crop(key, observations, pixel_key) 198 | if depth_key is not None: 199 | observations = batched_random_crop(key, observations, depth_key) 200 | return observations 201 | 202 | return cls( 203 | rng=rng, 204 | actor=actor, 205 | critic=critic, 206 | target_critic=target_critic, 207 | temp=temp, 208 | target_entropy=target_entropy, 209 | tau=tau, 210 | discount=discount, 211 | num_qs=num_qs, 212 | num_min_qs=num_min_qs, 213 | backup_entropy=backup_entropy, 214 | data_augmentation_fn=data_augmentation_fn, 215 | bc_coeff=bc_coeff, 216 | ) 217 | 218 | @partial(jax.jit, static_argnames="utd_ratio") 219 | def update(self, batch: DatasetDict, utd_ratio: int): 220 | new_agent = self 221 | 222 | if "pixels" not in batch["next_observations"]: 223 | batch = 
_unpack(batch) 224 | 225 | actor = _share_encoder(source=new_agent.critic, target=new_agent.actor) 226 | new_agent = new_agent.replace(actor=actor) 227 | 228 | rng, key = jax.random.split(new_agent.rng) 229 | observations = self.data_augmentation_fn(key, batch["observations"]) 230 | rng, key = jax.random.split(rng) 231 | next_observations = self.data_augmentation_fn(key, batch["next_observations"]) 232 | batch = batch.copy( 233 | add_or_replace={ 234 | "observations": observations, 235 | "next_observations": next_observations, 236 | } 237 | ) 238 | 239 | new_agent = new_agent.replace(rng=rng) 240 | 241 | return SACLearner.update(new_agent, batch, utd_ratio) 242 | -------------------------------------------------------------------------------- /rlpd/agents/drq/icvf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from the official ICVF codebase: https://github.com/dibyaghosh/icvf_release/ 3 | 4 | Original license information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2023 Dibya Ghosh 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE.
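
This module trains a pixel-based intent-conditioned value function (ICVF).
The network produces three heads, phi(s), psi(g), and T(z), and the value of
state s for goal g under intent z is parameterized as

    V(s, g, z) = -|| LayerNorm(phi(s) * T(z)) - LayerNorm(psi(g) * T(z)) ||

(see get_v inside _update below); training minimizes an expectile-weighted
TD error toward this value.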
27 | """ 28 | 29 | from functools import partial 30 | from itertools import zip_longest 31 | from typing import Callable, Optional, Sequence, Tuple, Dict 32 | 33 | import flax 34 | import gym 35 | import jax 36 | import jax.numpy as jnp 37 | import optax 38 | from flax import struct 39 | from flax.training.train_state import TrainState 40 | from flax.core import FrozenDict 41 | 42 | from rlpd.agents.drq.augmentations import batched_random_crop 43 | from rlpd.data.dataset import DatasetDict 44 | from rlpd.networks import MLP, PixelMultiplexer 45 | from rlpd.types import PRNGKey 46 | from rlpd.networks.encoders import D4PGEncoder 47 | 48 | from rlpd.agents.drq.drq_learner import _unpack 49 | import gym 50 | import numpy as np 51 | 52 | import flax.linen as nn 53 | import jax.numpy as jnp 54 | 55 | from rlpd.networks import default_init 56 | 57 | 58 | class ICVF(nn.Module): 59 | base_cls: nn.Module 60 | feature_dim: int 61 | 62 | @nn.compact 63 | def __call__(self, observations: jnp.ndarray, *args, **kwargs) -> jnp.ndarray: 64 | inputs = observations 65 | phi = self.base_cls(name="phi")(inputs, *args, **kwargs) 66 | psi = self.base_cls(name="psi")(inputs, *args, **kwargs) 67 | T = self.base_cls(name="T")(inputs, *args, **kwargs) 68 | return { 69 | "phi": phi, 70 | "psi": psi, 71 | "T": T, 72 | } 73 | 74 | 75 | def apply_layernorm(x): 76 | net_def = nn.LayerNorm(use_bias=False, use_scale=False) 77 | return net_def.apply({"params": {}}, x) 78 | 79 | 80 | class PixelICVF(struct.PyTreeNode): 81 | rng: PRNGKey 82 | net: TrainState 83 | target_net: TrainState 84 | data_augmentation_fn: Callable = struct.field(pytree_node=False) 85 | 86 | @classmethod 87 | def create( 88 | cls, 89 | seed: int, 90 | observation_space: gym.Space, 91 | action_space: gym.Space, 92 | lr: float = 3e-4, 93 | cnn_features: Sequence[int] = (32, 32, 32, 32), 94 | cnn_filters: Sequence[int] = (3, 3, 3, 3), 95 | cnn_strides: Sequence[int] = (2, 1, 1, 1), 96 | cnn_padding: str = "VALID", 97 | latent_dim: int = 50, 98 | feature_dim: int = 256, 99 | encoder: str = "d4pg", 100 | hidden_dims: Sequence[int] = (256, 256), 101 | pixel_keys: Tuple[str, ...] = ("pixels",), 102 | depth_keys: Tuple[str, ...] 
= (), 103 | **kwargs, 104 | ): 105 | print("Got additional kwargs: ", kwargs) 106 | 107 | observations = observation_space.sample() 108 | actions = action_space.sample() 109 | 110 | rng = jax.random.PRNGKey(seed) 111 | rng, key1, key2 = jax.random.split(rng, 3) 112 | 113 | if encoder == "d4pg": 114 | encoder_cls = partial( 115 | D4PGEncoder, 116 | features=cnn_features, 117 | filters=cnn_filters, 118 | strides=cnn_strides, 119 | padding=cnn_padding, 120 | ) 121 | else: 122 | raise NotImplementedError 123 | rnd_base_cls = partial( 124 | MLP, 125 | hidden_dims=hidden_dims, 126 | activate_final=True, 127 | ) 128 | rnd_cls = partial(ICVF, base_cls=rnd_base_cls, feature_dim=feature_dim) 129 | net_def = PixelMultiplexer( 130 | encoder_cls=encoder_cls, 131 | network_cls=rnd_cls, 132 | latent_dim=latent_dim, 133 | pixel_keys=pixel_keys, 134 | depth_keys=depth_keys, 135 | ) 136 | params = FrozenDict(net_def.init(key1, observations)["params"]) 137 | net = TrainState.create( 138 | apply_fn=net_def.apply, 139 | params=params, 140 | tx=optax.adam(learning_rate=lr), 141 | ) 142 | target_net = TrainState.create( 143 | apply_fn=net_def.apply, 144 | params=params, 145 | tx=optax.adam(learning_rate=lr), 146 | ) 147 | 148 | def data_augmentation_fn(rng, observations): 149 | for pixel_key, depth_key in zip_longest(pixel_keys, depth_keys): 150 | key, rng = jax.random.split(rng) 151 | observations = batched_random_crop(key, observations, pixel_key) 152 | if depth_key is not None: 153 | observations = batched_random_crop(key, observations, depth_key) 154 | return observations 155 | 156 | return cls( 157 | rng=rng, 158 | net=net, 159 | target_net=target_net, 160 | data_augmentation_fn=data_augmentation_fn, 161 | ) 162 | 163 | def _update(self, batch: DatasetDict) -> Tuple[struct.PyTreeNode, Dict[str, float]]: 164 | def loss_fn(params) -> Tuple[jnp.ndarray, Dict[str, float]]: 165 | def get_v(params, s, g, z): 166 | phi = self.net.apply_fn({"params": params}, s)["phi"] 167 | psi = self.net.apply_fn({"params": params}, g)["psi"] 168 | T = self.net.apply_fn({"params": params}, z)["T"] 169 | phi_T = apply_layernorm(phi * T) 170 | psi_T = apply_layernorm(psi * T) 171 | return -1 * optax.safe_norm(phi_T - psi_T, 1e-3, axis=-1) 172 | 173 | V = get_v( 174 | params, batch["observations"], batch["goals"], batch["desired_goals"] 175 | ) 176 | nV = get_v( 177 | self.target_net.params, 178 | batch["next_observations"], 179 | batch["goals"], 180 | batch["desired_goals"], 181 | ) 182 | target_V = batch["rewards"] + 0.99 * batch["masks"] * nV 183 | 184 | V_z = get_v( 185 | self.target_net.params, 186 | batch["next_observations"], 187 | batch["desired_goals"], 188 | batch["desired_goals"], 189 | ) 190 | nV_z = get_v( 191 | self.target_net.params, 192 | batch["next_observations"], 193 | batch["desired_goals"], 194 | batch["desired_goals"], 195 | ) 196 | adv = batch["desired_rewards"] + 0.99 * batch["desired_masks"] * nV_z - V_z 197 | 198 | def expectile_fn(adv, loss, expectile): 199 | weight = jnp.where(adv >= 0, expectile, 1 - expectile) 200 | return weight * loss 201 | 202 | def masked_mean(x, mask): 203 | mask = (mask > 0).astype(jnp.float32) 204 | return jnp.sum(x * mask) / (1e-5 + jnp.sum(mask)) 205 | 206 | loss = expectile_fn(adv, jnp.square(V - target_V), 0.9).mean() 207 | return loss, { 208 | "icvf_loss": loss, 209 | "V_success": masked_mean(V, 1.0 - batch["masks"]), 210 | "V_failure": masked_mean(V, batch["masks"]), 211 | } 212 | 213 | grads, info = jax.grad(loss_fn, has_aux=True)(self.net.params) 214 | net = 
self.net.apply_gradients(grads=grads) 215 | target_params = optax.incremental_update( 216 | self.net.params, self.target_net.params, 0.005 217 | ) 218 | target_net = self.target_net.replace(params=target_params) 219 | return self.replace(net=net, target_net=target_net), info 220 | 221 | @partial(jax.jit, static_argnames="utd_ratio") 222 | def update(self, batch: DatasetDict, utd_ratio: int): 223 | 224 | # if "pixels" not in batch["next_observations"]: 225 | # batch = _unpack(batch) 226 | 227 | rng, key = jax.random.split(self.rng) 228 | observations = self.data_augmentation_fn(key, batch["observations"]) 229 | rng, key = jax.random.split(rng) 230 | next_observations = self.data_augmentation_fn(key, batch["next_observations"]) 231 | goals = self.data_augmentation_fn(key, batch["goals"]) 232 | desired_goals = self.data_augmentation_fn(key, batch["desired_goals"]) 233 | 234 | batch = batch.copy( 235 | add_or_replace={ 236 | "observations": observations, 237 | "next_observations": next_observations, 238 | "goals": goals, 239 | "desired_goals": desired_goals, 240 | } 241 | ) 242 | new_self = self.replace(rng=rng) 243 | 244 | for i in range(utd_ratio): 245 | 246 | def slice(x): 247 | assert x.shape[0] % utd_ratio == 0 248 | batch_size = x.shape[0] // utd_ratio 249 | return x[batch_size * i : batch_size * (i + 1)] 250 | 251 | mini_batch = jax.tree_util.tree_map(slice, batch) 252 | new_self, info = new_self._update(mini_batch) 253 | 254 | return new_self, info 255 | -------------------------------------------------------------------------------- /rlpd/agents/drq/rm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | 5 | This source code is licensed under the license found in the 6 | LICENSE file in the root directory of this source tree. 7 | """ 8 | 9 | 10 | from functools import partial 11 | from itertools import zip_longest 12 | from typing import Callable, Optional, Sequence, Tuple, Dict 13 | 14 | import flax 15 | import gym 16 | import jax 17 | import jax.numpy as jnp 18 | import optax 19 | from flax import struct 20 | from flax.training.train_state import TrainState 21 | from flax.core import FrozenDict 22 | 23 | from rlpd.agents.drq.augmentations import batched_random_crop 24 | from rlpd.networks import MLP, PixelMultiplexer, StateValue 25 | from rlpd.networks.encoders import D4PGEncoder 26 | from rlpd.data.dataset import DatasetDict 27 | from rlpd.agents.drq.drq_learner import _unpack 28 | from rlpd.types import PRNGKey 29 | import gym 30 | 31 | 32 | class PixelRM(struct.PyTreeNode): 33 | rng: PRNGKey 34 | r_net: TrainState 35 | m_net: TrainState 36 | data_augmentation_fn: Callable = struct.field(pytree_node=False) 37 | 38 | @classmethod 39 | def create( 40 | cls, 41 | seed: int, 42 | observation_space: gym.Space, 43 | action_space: gym.Space, 44 | lr: float = 3e-4, 45 | hidden_dims: Sequence[int] = (256, 256), 46 | cnn_features: Sequence[int] = (32, 32, 32, 32), 47 | cnn_filters: Sequence[int] = (3, 3, 3, 3), 48 | cnn_strides: Sequence[int] = (2, 1, 1, 1), 49 | cnn_padding: str = "VALID", 50 | latent_dim: int = 50, 51 | encoder: str = "d4pg", 52 | pixel_keys: Tuple[str, ...] = ("pixels",), 53 | depth_keys: Tuple[str, ...] 
= (), 54 | ): 55 | 56 | observations = observation_space.sample() 57 | actions = action_space.sample() 58 | 59 | rng = jax.random.PRNGKey(seed) 60 | rng, key = jax.random.split(rng) 61 | 62 | if encoder == "d4pg": 63 | encoder_cls = partial( 64 | D4PGEncoder, 65 | features=cnn_features, 66 | filters=cnn_filters, 67 | strides=cnn_strides, 68 | padding=cnn_padding, 69 | ) 70 | else: 71 | raise NotImplementedError 72 | base_cls = partial( 73 | MLP, 74 | hidden_dims=hidden_dims, 75 | activate_final=True, 76 | ) 77 | net_cls = partial(StateValue, base_cls=base_cls) 78 | ucb_def = PixelMultiplexer( 79 | encoder_cls=encoder_cls, 80 | network_cls=net_cls, 81 | latent_dim=latent_dim, 82 | pixel_keys=pixel_keys, 83 | depth_keys=depth_keys, 84 | ) 85 | r_params = FrozenDict(ucb_def.init(key, observations)["params"]) 86 | r_net = TrainState.create( 87 | apply_fn=ucb_def.apply, 88 | params=r_params, 89 | tx=optax.adam(learning_rate=lr), 90 | ) 91 | 92 | m_params = FrozenDict(ucb_def.init(key, observations)["params"]) 93 | m_net = TrainState.create( 94 | apply_fn=ucb_def.apply, 95 | params=m_params, 96 | tx=optax.adam(learning_rate=lr), 97 | ) 98 | 99 | def data_augmentation_fn(rng, observations): 100 | for pixel_key, depth_key in zip_longest(pixel_keys, depth_keys): 101 | key, rng = jax.random.split(rng) 102 | observations = batched_random_crop(key, observations, pixel_key) 103 | if depth_key is not None: 104 | observations = batched_random_crop(key, observations, depth_key) 105 | return observations 106 | 107 | return cls( 108 | rng=rng, 109 | r_net=r_net, 110 | m_net=m_net, 111 | data_augmentation_fn=data_augmentation_fn, 112 | ) 113 | 114 | def _update(self, batch: DatasetDict) -> Tuple[struct.PyTreeNode, Dict[str, float]]: 115 | def r_loss_fn(r_params) -> Tuple[jnp.ndarray, Dict[str, float]]: 116 | rs = self.r_net.apply_fn({"params": r_params}, batch["observations"]) 117 | 118 | loss = ((rs - batch["rewards"]) ** 2.0).mean() 119 | return loss, {"r_loss": loss} 120 | 121 | grads, r_info = jax.grad(r_loss_fn, has_aux=True)(self.r_net.params) 122 | r_net = self.r_net.apply_gradients(grads=grads) 123 | 124 | def m_loss_fn(m_params) -> Tuple[jnp.ndarray, Dict[str, float]]: 125 | ms = self.m_net.apply_fn({"params": m_params}, batch["observations"]) 126 | 127 | loss = optax.sigmoid_binary_cross_entropy(ms, batch["masks"]).mean() 128 | return loss, {"m_loss": loss} 129 | 130 | grads, m_info = jax.grad(m_loss_fn, has_aux=True)(self.m_net.params) 131 | m_net = self.m_net.apply_gradients(grads=grads) 132 | 133 | return self.replace(r_net=r_net, m_net=m_net), {**r_info, **m_info} 134 | 135 | @partial(jax.jit, static_argnames="utd_ratio") 136 | def update(self, batch: DatasetDict, utd_ratio: int): 137 | 138 | if "pixels" not in batch["next_observations"]: 139 | batch = _unpack(batch) 140 | 141 | rng, key = jax.random.split(self.rng) 142 | observations = self.data_augmentation_fn(key, batch["observations"]) 143 | rng, key = jax.random.split(rng) 144 | next_observations = self.data_augmentation_fn(key, batch["next_observations"]) 145 | batch = batch.copy( 146 | add_or_replace={ 147 | "observations": observations, 148 | "next_observations": next_observations, 149 | } 150 | ) 151 | new_self = self.replace(rng=rng) 152 | 153 | for i in range(utd_ratio): 154 | 155 | def slice(x): 156 | assert x.shape[0] % utd_ratio == 0 157 | batch_size = x.shape[0] // utd_ratio 158 | return x[batch_size * i : batch_size * (i + 1)] 159 | 160 | mini_batch = jax.tree_util.tree_map(slice, batch) 161 | new_self, info = 
new_self._update(mini_batch) 162 | 163 | return new_self, info 164 | 165 | @jax.jit 166 | def get_reward(self, batch): 167 | if "pixels" not in batch["next_observations"]: 168 | batch = _unpack(batch) 169 | 170 | rewards = self.r_net.apply_fn( 171 | {"params": self.r_net.params}, batch["observations"] 172 | ) 173 | return rewards 174 | 175 | @jax.jit 176 | def get_mask(self, batch): 177 | if "pixels" not in batch["next_observations"]: 178 | batch = _unpack(batch) 179 | 180 | logits = self.m_net.apply_fn( 181 | {"params": self.m_net.params}, batch["observations"] 182 | ) 183 | return jax.nn.sigmoid(logits) 184 | -------------------------------------------------------------------------------- /rlpd/agents/drq/rnd.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | 5 | This source code is licensed under the license found in the 6 | LICENSE file in the root directory of this source tree. 7 | """ 8 | 9 | from functools import partial 10 | from itertools import zip_longest 11 | from typing import Callable, Optional, Sequence, Tuple, Dict 12 | 13 | import flax 14 | import gym 15 | import jax 16 | import jax.numpy as jnp 17 | import optax 18 | from flax import struct 19 | from flax.training.train_state import TrainState 20 | from flax.core import FrozenDict 21 | 22 | from rlpd.agents.drq.augmentations import batched_random_crop 23 | from rlpd.data.dataset import DatasetDict 24 | from rlpd.networks import MLP, PixelMultiplexer, StateFeature 25 | from rlpd.types import PRNGKey 26 | from rlpd.networks.encoders import D4PGEncoder 27 | 28 | from rlpd.agents.drq.drq_learner import _unpack 29 | 30 | 31 | class PixelRND(struct.PyTreeNode): 32 | rng: PRNGKey 33 | net: TrainState 34 | frozen_net: TrainState 35 | coeff: float = struct.field(pytree_node=False) 36 | data_augmentation_fn: Callable = struct.field(pytree_node=False) 37 | 38 | @classmethod 39 | def create( 40 | cls, 41 | seed: int, 42 | observation_space: gym.Space, 43 | action_space: gym.Space, 44 | lr: float = 3e-4, 45 | coeff: float = 1.0, 46 | cnn_features: Sequence[int] = (32, 32, 32, 32), 47 | cnn_filters: Sequence[int] = (3, 3, 3, 3), 48 | cnn_strides: Sequence[int] = (2, 1, 1, 1), 49 | cnn_padding: str = "VALID", 50 | latent_dim: int = 50, 51 | feature_dim: int = 256, 52 | encoder: str = "d4pg", 53 | hidden_dims: Sequence[int] = (256, 256), 54 | pixel_keys: Tuple[str, ...] = ("pixels",), 55 | depth_keys: Tuple[str, ...] 
= (), 56 | ): 57 | 58 | observations = observation_space.sample() 59 | actions = action_space.sample() 60 | 61 | rng = jax.random.PRNGKey(seed) 62 | rng, key1, key2 = jax.random.split(rng, 3) 63 | 64 | if encoder == "d4pg": 65 | encoder_cls = partial( 66 | D4PGEncoder, 67 | features=cnn_features, 68 | filters=cnn_filters, 69 | strides=cnn_strides, 70 | padding=cnn_padding, 71 | ) 72 | else: 73 | raise NotImplementedError 74 | rnd_base_cls = partial( 75 | MLP, 76 | hidden_dims=hidden_dims, 77 | activate_final=True, 78 | ) 79 | rnd_cls = partial(StateFeature, base_cls=rnd_base_cls, feature_dim=feature_dim) 80 | net_def = PixelMultiplexer( 81 | encoder_cls=encoder_cls, 82 | network_cls=rnd_cls, 83 | latent_dim=latent_dim, 84 | pixel_keys=pixel_keys, 85 | depth_keys=depth_keys, 86 | ) 87 | params = FrozenDict(net_def.init(key1, observations)["params"]) 88 | net = TrainState.create( 89 | apply_fn=net_def.apply, 90 | params=params, 91 | tx=optax.adam(learning_rate=lr), 92 | ) 93 | frozen_params = FrozenDict(net_def.init(key2, observations)["params"]) 94 | frozen_net = TrainState.create( 95 | apply_fn=net_def.apply, 96 | params=frozen_params, 97 | tx=optax.adam(learning_rate=lr), 98 | ) 99 | 100 | def data_augmentation_fn(rng, observations): 101 | for pixel_key, depth_key in zip_longest(pixel_keys, depth_keys): 102 | key, rng = jax.random.split(rng) 103 | observations = batched_random_crop(key, observations, pixel_key) 104 | if depth_key is not None: 105 | observations = batched_random_crop(key, observations, depth_key) 106 | return observations 107 | 108 | return cls( 109 | rng=rng, 110 | net=net, 111 | frozen_net=frozen_net, 112 | coeff=coeff, 113 | data_augmentation_fn=data_augmentation_fn, 114 | ) 115 | 116 | @jax.jit 117 | def update(self, batch: DatasetDict) -> Tuple[struct.PyTreeNode, Dict[str, float]]: 118 | 119 | rng, key = jax.random.split(self.rng) 120 | observations = self.data_augmentation_fn(key, batch["observations"]) 121 | rng, key = jax.random.split(rng) 122 | next_observations = self.data_augmentation_fn(key, batch["next_observations"]) 123 | batch = batch.copy( 124 | add_or_replace={ 125 | "observations": observations, 126 | "next_observations": next_observations, 127 | } 128 | ) 129 | new_self = self.replace(rng=rng) 130 | 131 | def loss_fn(params) -> Tuple[jnp.ndarray, Dict[str, float]]: 132 | feats = new_self.net.apply_fn({"params": params}, batch["observations"]) 133 | frozen_feats = new_self.frozen_net.apply_fn( 134 | {"params": new_self.frozen_net.params}, batch["observations"] 135 | ) 136 | 137 | loss = ((feats - frozen_feats) ** 2.0).mean() 138 | return loss, {"rnd_loss": loss} 139 | 140 | grads, info = jax.grad(loss_fn, has_aux=True)(new_self.net.params) 141 | net = new_self.net.apply_gradients(grads=grads) 142 | 143 | return new_self.replace(net=net), info 144 | 145 | @jax.jit 146 | def get_reward(self, batch): 147 | if "pixels" not in batch["next_observations"]: 148 | batch = _unpack(batch) 149 | feats = self.net.apply_fn({"params": self.net.params}, batch["observations"]) 150 | frozen_feats = self.net.apply_fn( 151 | {"params": self.frozen_net.params}, batch["observations"] 152 | ) 153 | return jnp.mean((feats - frozen_feats) ** 2.0, axis=-1) * self.coeff 154 | -------------------------------------------------------------------------------- /rlpd/agents/rm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 
4 | 5 | This source code is licensed under the license found in the 6 | LICENSE file in the root directory of this source tree. 7 | """ 8 | 9 | 10 | from functools import partial 11 | from typing import Dict, Optional, Sequence, Tuple 12 | 13 | import flax 14 | import gym 15 | import jax 16 | import jax.numpy as jnp 17 | import optax 18 | from flax import struct 19 | from flax.training.train_state import TrainState 20 | from flax.core import FrozenDict 21 | 22 | from rlpd.data.dataset import DatasetDict 23 | from rlpd.networks import ( 24 | MLP, 25 | StateActionValue, 26 | ) 27 | from rlpd.types import PRNGKey 28 | 29 | 30 | class RM(struct.PyTreeNode): 31 | rng: PRNGKey 32 | init_r_net: TrainState 33 | init_m_net: TrainState 34 | r_net: TrainState 35 | m_net: TrainState 36 | 37 | @classmethod 38 | def create( 39 | cls, 40 | seed: int, 41 | observation_space: gym.Space, 42 | action_space: gym.Space, 43 | lr: float = 3e-4, 44 | hidden_dims: Sequence[int] = (256, 256), 45 | ): 46 | 47 | observations = observation_space.sample() 48 | actions = action_space.sample() 49 | 50 | rng = jax.random.PRNGKey(seed) 51 | rng, key1, key2 = jax.random.split(rng, 3) 52 | 53 | base_cls = partial(MLP, hidden_dims=hidden_dims, activate_final=True) 54 | net_def = StateActionValue(base_cls) 55 | r_params = FrozenDict(net_def.init(key1, observations, actions)["params"]) 56 | r_net = TrainState.create( 57 | apply_fn=net_def.apply, 58 | params=r_params, 59 | tx=optax.adam(learning_rate=lr), 60 | ) 61 | 62 | m_params = FrozenDict(net_def.init(key2, observations, actions)["params"]) 63 | m_net = TrainState.create( 64 | apply_fn=net_def.apply, 65 | params=m_params, 66 | tx=optax.adam(learning_rate=lr), 67 | ) 68 | 69 | return cls( 70 | rng=rng, 71 | init_r_net=r_net, 72 | init_m_net=m_net, 73 | r_net=r_net, 74 | m_net=m_net, 75 | ) 76 | 77 | @jax.jit 78 | def reset(self): 79 | return self.replace(r_net=self.init_r_net, m_net=self.init_m_net) 80 | 81 | def _update(self, batch: DatasetDict) -> Tuple[struct.PyTreeNode, Dict[str, float]]: 82 | def r_loss_fn(r_params) -> Tuple[jnp.ndarray, Dict[str, float]]: 83 | rs = self.r_net.apply_fn( 84 | {"params": r_params}, batch["observations"], batch["actions"] 85 | ) 86 | 87 | loss = ((rs - batch["rewards"]) ** 2.0).mean() 88 | return loss, {"r_loss": loss} 89 | 90 | grads, r_info = jax.grad(r_loss_fn, has_aux=True)(self.r_net.params) 91 | r_net = self.r_net.apply_gradients(grads=grads) 92 | 93 | def m_loss_fn(m_params) -> Tuple[jnp.ndarray, Dict[str, float]]: 94 | ms = self.m_net.apply_fn( 95 | {"params": m_params}, batch["observations"], batch["actions"] 96 | ) 97 | 98 | loss = optax.sigmoid_binary_cross_entropy(ms, batch["masks"]).mean() 99 | return loss, {"m_loss": loss} 100 | 101 | grads, m_info = jax.grad(m_loss_fn, has_aux=True)(self.m_net.params) 102 | m_net = self.m_net.apply_gradients(grads=grads) 103 | 104 | return self.replace(r_net=r_net, m_net=m_net), {**r_info, **m_info} 105 | 106 | @partial(jax.jit, static_argnames="utd_ratio") 107 | def update(self, batch: DatasetDict, utd_ratio: int): 108 | 109 | new_self = self 110 | for i in range(utd_ratio): 111 | 112 | def slice(x): 113 | assert x.shape[0] % utd_ratio == 0 114 | batch_size = x.shape[0] // utd_ratio 115 | return x[batch_size * i : batch_size * (i + 1)] 116 | 117 | mini_batch = jax.tree_util.tree_map(slice, batch) 118 | new_self, info = new_self._update(mini_batch) 119 | 120 | return new_self, info 121 | 122 | @jax.jit 123 | def evaluate(self, batch: DatasetDict): 124 | rewards = 
self.get_reward(batch["observations"], batch["actions"]) 125 | masks = self.get_mask(batch["observations"], batch["actions"]) 126 | info = { 127 | "val_r_loss": ((rewards - batch["rewards"]) ** 2.0).mean(), 128 | "val_m_loss": optax.sigmoid_binary_cross_entropy( 129 | masks, batch["masks"] 130 | ).mean(), 131 | } 132 | return info 133 | 134 | @jax.jit 135 | def get_reward(self, observations, actions): 136 | rewards = self.r_net.apply_fn( 137 | {"params": self.r_net.params}, observations, actions 138 | ) 139 | return rewards 140 | 141 | @jax.jit 142 | def get_mask(self, observations, actions): 143 | logits = self.m_net.apply_fn( 144 | {"params": self.m_net.params}, observations, actions 145 | ) 146 | return jax.nn.sigmoid(logits) 147 | -------------------------------------------------------------------------------- /rlpd/agents/rnd.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | 5 | This source code is licensed under the license found in the 6 | LICENSE file in the root directory of this source tree. 7 | """ 8 | 9 | 10 | from functools import partial 11 | from typing import Dict, Optional, Sequence, Tuple 12 | 13 | import flax 14 | import gym 15 | import jax 16 | import jax.numpy as jnp 17 | import optax 18 | from flax import struct 19 | from flax.training.train_state import TrainState 20 | from flax.core import FrozenDict 21 | 22 | from rlpd.data.dataset import DatasetDict 23 | from rlpd.networks import ( 24 | MLP, 25 | StateActionFeature, 26 | ) 27 | from rlpd.types import PRNGKey 28 | 29 | 30 | class RND(struct.PyTreeNode): 31 | rng: PRNGKey 32 | net: TrainState 33 | init_net: TrainState 34 | frozen_net: TrainState 35 | coeff: float = struct.field(pytree_node=False) 36 | 37 | @classmethod 38 | def create( 39 | cls, 40 | seed: int, 41 | observation_space: gym.Space, 42 | action_space: gym.Space, 43 | lr: float = 3e-4, 44 | coeff: float = 1.0, 45 | hidden_dims: Sequence[int] = (256, 256), 46 | feature_dim: int = 256, 47 | ): 48 | 49 | observations = observation_space.sample() 50 | actions = action_space.sample() 51 | 52 | rng = jax.random.PRNGKey(seed) 53 | rng, key1, key2 = jax.random.split(rng, 3) 54 | 55 | net_cls = partial(MLP, hidden_dims=hidden_dims, activate_final=True) 56 | net_def = StateActionFeature(base_cls=net_cls, feature_dim=feature_dim) 57 | params = FrozenDict(net_def.init(key1, observations, actions)["params"]) 58 | net = TrainState.create( 59 | apply_fn=net_def.apply, 60 | params=params, 61 | tx=optax.adam(learning_rate=lr), 62 | ) 63 | frozen_params = FrozenDict(net_def.init(key2, observations, actions)["params"]) 64 | frozen_net = TrainState.create( 65 | apply_fn=net_def.apply, 66 | params=frozen_params, 67 | tx=optax.adam(learning_rate=lr), 68 | ) 69 | return cls( 70 | rng=rng, 71 | init_net=net, 72 | net=net, 73 | frozen_net=frozen_net, 74 | coeff=coeff, 75 | ) 76 | 77 | @jax.jit 78 | def reset(self): 79 | return self.replace(net=self.init_net) 80 | 81 | @jax.jit 82 | def update(self, batch: DatasetDict) -> Tuple[struct.PyTreeNode, Dict[str, float]]: 83 | def loss_fn(params) -> Tuple[jnp.ndarray, Dict[str, float]]: 84 | feats = self.net.apply_fn( 85 | {"params": params}, batch["observations"], batch["actions"] 86 | ) 87 | frozen_feats = self.frozen_net.apply_fn( 88 | {"params": self.frozen_net.params}, 89 | batch["observations"], 90 | batch["actions"], 91 | ) 92 | loss = ((feats - frozen_feats) ** 2.0).mean() 93 | return loss, {"rnd_loss": 
loss} 94 | 95 | grads, info = jax.grad(loss_fn, has_aux=True)(self.net.params) 96 | net = self.net.apply_gradients(grads=grads) 97 | 98 | return self.replace(net=net), info 99 | 100 | @jax.jit 101 | def get_reward(self, observations, actions): 102 | feats = self.net.apply_fn({"params": self.net.params}, observations, actions) 103 | frozen_feats = self.net.apply_fn( 104 | {"params": self.frozen_net.params}, observations, actions 105 | ) 106 | return jnp.mean((feats - frozen_feats) ** 2.0, axis=-1) * self.coeff 107 | -------------------------------------------------------------------------------- /rlpd/agents/sac/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ExPLORe/fff95eb0e1aa16c1410a2daefed068dfeb5c3620/rlpd/agents/sac/__init__.py -------------------------------------------------------------------------------- /rlpd/agents/sac/temperature.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/agents/sac/temperature.py 3 | 4 | Original lincense information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | """ 28 | 29 | 30 | import flax.linen as nn 31 | import jax.numpy as jnp 32 | 33 | 34 | class Temperature(nn.Module): 35 | initial_temperature: float = 1.0 36 | 37 | @nn.compact 38 | def __call__(self) -> jnp.ndarray: 39 | log_temp = self.param( 40 | "log_temp", 41 | init_fn=lambda key: jnp.full((), jnp.log(self.initial_temperature)), 42 | ) 43 | return jnp.exp(log_temp) 44 | -------------------------------------------------------------------------------- /rlpd/data/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/data/__init__.py 3 | 4 | Original lincense information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. 
Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | 28 | """ 29 | 30 | 31 | from rlpd.data.memory_efficient_replay_buffer import MemoryEfficientReplayBuffer 32 | from rlpd.data.replay_buffer import ReplayBuffer 33 | from rlpd.data.dataset import Dataset 34 | -------------------------------------------------------------------------------- /rlpd/data/binary_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/data/binary_datasets.py 3 | 4 | Original lincense information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 
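A minimal usage sketch for the RND module in rlpd/agents/rnd.py above; the spaces, batch shapes, and values below are illustrative assumptions, not taken from this repo:

import gym
import numpy as np

from rlpd.agents.rnd import RND

obs_space = gym.spaces.Box(-1.0, 1.0, (17,), np.float32)   # assumed shapes
act_space = gym.spaces.Box(-1.0, 1.0, (6,), np.float32)
rnd = RND.create(seed=0, observation_space=obs_space, action_space=act_space, coeff=1.0)

batch = {
    "observations": np.zeros((256, 17), np.float32),
    "actions": np.zeros((256, 6), np.float32),
}
rnd, info = rnd.update(batch)  # fits the predictor net to the frozen net
bonus = rnd.get_reward(batch["observations"], batch["actions"])  # (256,) novelty bonus

Because the predictor is trained only on visited data, its error (the bonus) stays high on unvisited state-action pairs, and reset() restores the untrained predictor stored at creation time.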
27 | 28 | """ 29 | 30 | 31 | import os 32 | 33 | import gym 34 | 35 | try: 36 | import mj_envs 37 | except: 38 | pass 39 | import numpy as np 40 | 41 | from rlpd.data.dataset import Dataset 42 | 43 | AWAC_DATA_DIR = "~/.datasets/awac-data" 44 | 45 | 46 | def process_expert_dataset(expert_datset): 47 | """This is a mess, but works""" 48 | all_observations = [] 49 | all_next_observations = [] 50 | all_actions = [] 51 | all_rewards = [] 52 | all_terminals = [] 53 | 54 | for x in expert_datset: 55 | all_observations.append( 56 | np.vstack([xx["state_observation"] for xx in x["observations"]]) 57 | ) 58 | all_next_observations.append( 59 | np.vstack([xx["state_observation"] for xx in x["next_observations"]]) 60 | ) 61 | all_actions.append(np.vstack([xx for xx in x["actions"]])) 62 | # for some reason rewards has an extra entry, so in rlkit they just remove the last entry: https://github.com/rail-berkeley/rlkit/blob/354f14c707cc4eb7ed876215dd6235c6b30a2e2b/rlkit/demos/source/dict_to_mdp_path_loader.py#L84 63 | all_rewards.append(x["rewards"][:-1]) 64 | all_terminals.append(x["terminals"]) 65 | 66 | return { 67 | "observations": np.concatenate(all_observations, dtype=np.float32), 68 | "next_observations": np.concatenate(all_next_observations, dtype=np.float32), 69 | "actions": np.concatenate(all_actions, dtype=np.float32), 70 | "rewards": np.concatenate(all_rewards, dtype=np.float32), 71 | "terminals": np.concatenate(all_terminals, dtype=np.float32), 72 | } 73 | 74 | 75 | def process_bc_dataset(bc_dataset): 76 | final_bc_dataset = {k: [] for k in bc_dataset[0] if "info" not in k} 77 | 78 | for x in bc_dataset: 79 | for k in final_bc_dataset: 80 | final_bc_dataset[k].append(x[k]) 81 | 82 | return { 83 | k: np.concatenate(v, dtype=np.float32).squeeze() 84 | for k, v in final_bc_dataset.items() 85 | } 86 | 87 | 88 | class BinaryDataset(Dataset): 89 | def __init__( 90 | self, 91 | env: gym.Env, 92 | clip_to_eps: bool = True, 93 | eps: float = 1e-5, 94 | remove_terminals=True, 95 | include_bc_data=True, 96 | ): 97 | # import pdb; pdb.set_trace() 98 | env_prefix = env.spec.id.split("-")[0] 99 | 100 | expert_dataset = np.load( 101 | os.path.join( 102 | os.path.expanduser(AWAC_DATA_DIR), f"{env_prefix}2_sparse.npy" 103 | ), 104 | allow_pickle=True, 105 | ) 106 | 107 | # this seems super random, but I grabbed it from here: https://github.com/rail-berkeley/rlkit/blob/c81509d982b4d52a6239e7bfe7d2540e3d3cd986/rlkit/launchers/experiments/awac/awac_rl.py#L124 and here https://github.com/rail-berkeley/rlkit/blob/354f14c707cc4eb7ed876215dd6235c6b30a2e2b/rlkit/demos/source/dict_to_mdp_path_loader.py#L153 108 | dataset_split = 0.9 109 | last_train_idx = int(dataset_split * len(expert_dataset)) 110 | 111 | dataset_dict = process_expert_dataset(expert_dataset[:last_train_idx]) 112 | 113 | if include_bc_data: 114 | bc_dataset = np.load( 115 | os.path.join( 116 | os.path.expanduser(AWAC_DATA_DIR), f"{env_prefix}_bc_sparse4.npy" 117 | ), 118 | allow_pickle=True, 119 | ) 120 | 121 | # this seems super random, but I grabbed it from here: https://github.com/rail-berkeley/rlkit/blob/c81509d982b4d52a6239e7bfe7d2540e3d3cd986/rlkit/launchers/experiments/awac/awac_rl.py#L124 and here https://github.com/rail-berkeley/rlkit/blob/354f14c707cc4eb7ed876215dd6235c6b30a2e2b/rlkit/demos/source/dict_to_mdp_path_loader.py#L153 122 | bc_dataset_split = 0.9 123 | bc_dataset = bc_dataset[: int(bc_dataset_split * len(bc_dataset))] 124 | bc_dataset = process_bc_dataset(bc_dataset) 125 | 126 | dataset_dict = { 127 | k: 
np.concatenate([dataset_dict[k], bc_dataset[k]]) 128 | for k in dataset_dict 129 | } 130 | 131 | if clip_to_eps: 132 | lim = 1 - eps 133 | dataset_dict["actions"] = np.clip(dataset_dict["actions"], -lim, lim) 134 | 135 | dones = np.full_like(dataset_dict["rewards"], False, dtype=bool) 136 | 137 | for i in range(len(dones) - 1): 138 | if ( 139 | np.linalg.norm( 140 | dataset_dict["observations"][i + 1] 141 | - dataset_dict["next_observations"][i] 142 | ) 143 | > 1e-6 144 | or dataset_dict["terminals"][i] == 1.0 145 | ): 146 | dones[i] = True 147 | 148 | if remove_terminals: 149 | dataset_dict["terminals"] = np.zeros_like(dataset_dict["terminals"]) 150 | 151 | dones[-1] = True 152 | 153 | dataset_dict["masks"] = 1.0 - dataset_dict["terminals"] 154 | del dataset_dict["terminals"] 155 | 156 | for k, v in dataset_dict.items(): 157 | dataset_dict[k] = v.astype(np.float32) 158 | 159 | dataset_dict["dones"] = dones 160 | 161 | super().__init__(dataset_dict) 162 | -------------------------------------------------------------------------------- /rlpd/data/cog_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | 5 | This source code is licensed under the license found in the 6 | LICENSE file in the root directory of this source tree. 7 | """ 8 | 9 | 10 | import os 11 | 12 | import numpy as np 13 | import gym 14 | 15 | from rlpd.data import MemoryEfficientReplayBuffer 16 | 17 | def dict_to_list(D): 18 | # https://stackoverflow.com/questions/5558418/list-of-dicts-to-from-dict-of-lists 19 | return [dict(zip(D, t)) for t in zip(*D.values())] 20 | 21 | class COGDataset(MemoryEfficientReplayBuffer): 22 | def __init__( 23 | self, 24 | env: gym.Env, 25 | dataset_path: str, 26 | capacity: int = 500_000, 27 | subsample_ratio: float = 1.0, 28 | pixel_keys: tuple = ("pixels",), 29 | np_rng = None, 30 | load_successes: bool = True, 31 | ): 32 | self.np_rng = np_rng 33 | super().__init__( 34 | env.observation_space, 35 | env.action_space, 36 | capacity=capacity, 37 | pixel_keys=pixel_keys 38 | ) 39 | self.successful_offline_prior_trajs = [] 40 | self.successful_offline_task_trajs = [] 41 | 42 | self._load_data_from_dir(dataset_path, subsample_ratio) 43 | 44 | self.load_successes = load_successes 45 | if self.load_successes: 46 | self._load_successful_traj(dataset_path) 47 | 48 | def load_successful_traj(self): 49 | assert self.load_successes, "did not load successful trajectories upon making this dataset" 50 | prior_idx = self.np_rng.integers(len(self.successful_offline_prior_trajs)) 51 | task_idx = self.np_rng.integers(len(self.successful_offline_task_trajs)) 52 | prior_traj = self.successful_offline_prior_trajs[prior_idx] 53 | task_traj = self.successful_offline_task_trajs[task_idx] 54 | return prior_traj + task_traj 55 | 56 | def _load_data_from_dir(self, dataset_path, subsample_ratio=1.0): 57 | print("subsample ratio:", subsample_ratio * subsample_ratio) # sub-sampled twice 58 | for f in os.listdir(dataset_path): 59 | full_path = os.path.join(dataset_path, f) 60 | if f.endswith('.npy'): 61 | print("*"*20, "\nloading data from:", full_path) 62 | data = np.load(full_path, allow_pickle=True) 63 | print("prior subsampling # trajs:", len(data)) 64 | data = self._subsample_data(data, subsample_ratio) 65 | self._load_data(data, subsample_ratio) 66 | print("post subsampling # trajs:", len(self)) 67 | 68 | def _subsample_data(self, data, r=1.0): 69 | assert 0 <= r <= 1 70 | n = len(data) 
71 | idxs = self.np_rng.choice(n, size=int(n*r), replace=False) 72 | return data[idxs] 73 | 74 | def _load_data(self, data, subsample_ratio=1.0): 75 | cutoff = int(len(data) * subsample_ratio) 76 | for i, traj in enumerate(data): 77 | if i > cutoff: 78 | break 79 | trans = dict_to_list(traj) 80 | for tran in trans: 81 | data_dict = self._make_data_dict(tran) 82 | self.insert(data_dict) 83 | 84 | def _load_successful_traj(self, dataset_path): 85 | # load successful offline trajectories for visualizations / evaluation 86 | prior_data = np.load(os.path.join(dataset_path, 'successful', 'prior_success.npy'), allow_pickle=True) 87 | task_data = np.load(os.path.join(dataset_path, 'successful', 'task_success.npy'), allow_pickle=True) 88 | 89 | for traj in prior_data: 90 | trans = dict_to_list(traj) 91 | trans = [self._make_data_dict(tran) for tran in trans] 92 | self.successful_offline_prior_trajs.append(trans) 93 | 94 | for traj in task_data: 95 | trans = dict_to_list(traj) 96 | trans = [self._make_data_dict(tran) for tran in trans] 97 | self.successful_offline_task_trajs.append(trans) 98 | 99 | def _make_data_dict(self, tran): 100 | return dict( 101 | observations={"pixels": np.array(tran["observations"]["image"])[..., None]}, 102 | actions=np.array(tran["actions"]), 103 | next_observations={"pixels": np.array(tran["next_observations"]["image"])[..., None]}, 104 | rewards=np.array(tran["rewards"]), 105 | masks=1-np.array(tran["terminals"], dtype=float), 106 | dones=np.array(tran["agent_infos"]["done"]) 107 | ) 108 | -------------------------------------------------------------------------------- /rlpd/data/d4rl_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/data/d4rl_datasets.py 3 | 4 | Original lincense information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 
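A usage sketch for the COGDataset above; the environment, dataset directory, and ratio are assumptions for illustration (a frame-stacked pixel env with a {"pixels": ...} observation space and .npy trajectory files are presumed to exist):

import numpy as np

from rlpd.data.cog_datasets import COGDataset

np_rng = np.random.default_rng(0)
ds = COGDataset(
    env=env,                        # assumed pixel-based COG env
    dataset_path="data/cog/task",   # assumed directory of .npy trajectory files
    capacity=500_000,
    subsample_ratio=0.5,            # applied twice while loading, so ~25% of trajectories kept
    np_rng=np_rng,
)
batch = ds.sample(256, pack_obs_and_next_obs=True)
demo = ds.load_successful_traj()    # one successful prior + task trajectory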
27 | 28 | """ 29 | 30 | 31 | import d4rl 32 | import gym 33 | import numpy as np 34 | 35 | from rlpd.data.dataset import Dataset 36 | import copy 37 | 38 | 39 | def filter_antmaze(tran, env, np_rng, mode="all"): 40 | if "large" in env: 41 | right_cutoff = 20 42 | bottom_cutoff = 10 43 | block_cutoff = 5 44 | elif "medium" in env: 45 | right_cutoff = 10 46 | bottom_cutoff = 10 47 | block_cutoff = 5 48 | else: 49 | raise NotImplementedError 50 | 51 | x, y = tran["observations"][:2] 52 | delta = (tran["next_observations"] - tran["observations"])[:2] 53 | 54 | # observation based filters 55 | if mode == "all": 56 | return True 57 | elif mode == "right": 58 | return x < right_cutoff 59 | elif mode == "bottom": 60 | return y < bottom_cutoff 61 | elif mode == "no_corner": 62 | return x < right_cutoff or y < bottom_cutoff 63 | elif mode == "stripes": 64 | return (x // block_cutoff) % 2 == 0 65 | elif mode == "checkers": 66 | return (x // block_cutoff) % 2 == (y // block_cutoff) % 2 67 | elif mode == "subopt": 68 | assert "large" in env 69 | return ( 70 | (x < 5 and y < 10) 71 | or (x < 25 and 5 < y < 10) 72 | or (15 < x and y < 5) 73 | or (30 < x and y < 20) 74 | or (25 < x and 15 < y) 75 | ) 76 | elif mode == "10perc": 77 | return np_rng.random() < 0.1 78 | elif mode == "1perc": 79 | return np_rng.random() < 0.01 80 | elif mode == "01perc": 81 | return np_rng.random() < 0.001 82 | elif mode == "001perc": 83 | return np_rng.random() < 0.0001 84 | 85 | # action based filters 86 | elif mode == "southwest": 87 | return np.dot(delta, np.ones(delta.shape)) < 0 88 | elif mode == "northeast": 89 | return np.dot(delta, np.ones(delta.shape)) > 0 90 | elif mode == "southwest-10perc": 91 | return np.dot(delta, np.ones(delta.shape)) < 0 and np_rng.random() < 0.1 92 | elif mode == "southwest-1perc": 93 | return np.dot(delta, np.ones(delta.shape)) < 0 and np_rng.random() < 0.01 94 | elif mode == "southwest-01perc": 95 | return np.dot(delta, np.ones(delta.shape)) < 0 and np_rng.random() < 0.001 96 | elif mode == "southwest-001perc": 97 | return np.dot(delta, np.ones(delta.shape)) < 0 and np_rng.random() < 0.0001 98 | else: 99 | raise NotImplementedError 100 | 101 | 102 | class D4RLDataset(Dataset): 103 | def __init__(self, env: gym.Env, clip_to_eps: bool = True, eps: float = 1e-5): 104 | dataset_dict = d4rl.qlearning_dataset(env) 105 | 106 | if clip_to_eps: 107 | lim = 1 - eps 108 | dataset_dict["actions"] = np.clip(dataset_dict["actions"], -lim, lim) 109 | 110 | dones = np.full_like(dataset_dict["rewards"], False, dtype=bool) 111 | 112 | for i in range(len(dones) - 1): 113 | if ( 114 | np.linalg.norm( 115 | dataset_dict["observations"][i + 1] 116 | - dataset_dict["next_observations"][i] 117 | ) 118 | > 1e-6 119 | or dataset_dict["terminals"][i] == 1.0 120 | ): 121 | dones[i] = True 122 | 123 | dones[-1] = True 124 | 125 | dataset_dict["masks"] = 1.0 - dataset_dict["terminals"] 126 | del dataset_dict["terminals"] 127 | 128 | for k, v in dataset_dict.items(): 129 | dataset_dict[k] = v.astype(np.float32) 130 | 131 | dataset_dict["dones"] = dones 132 | 133 | super().__init__(dataset_dict) 134 | -------------------------------------------------------------------------------- /rlpd/data/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/data/dataset.py 3 | 4 | Original lincense information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. 
Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | 28 | """ 29 | 30 | from functools import partial 31 | from random import sample 32 | from typing import Dict, Iterable, Optional, Tuple, Union 33 | 34 | import jax 35 | import jax.numpy as jnp 36 | import numpy as np 37 | from flax.core import frozen_dict 38 | from gym.utils import seeding 39 | 40 | from rlpd.types import DataType 41 | 42 | DatasetDict = Dict[str, DataType] 43 | 44 | 45 | def _check_lengths(dataset_dict: DatasetDict, dataset_len: Optional[int] = None) -> int: 46 | for v in dataset_dict.values(): 47 | if isinstance(v, dict): 48 | dataset_len = dataset_len or _check_lengths(v, dataset_len) 49 | elif isinstance(v, np.ndarray): 50 | item_len = len(v) 51 | dataset_len = dataset_len or item_len 52 | assert dataset_len == item_len, "Inconsistent item lengths in the dataset." 
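# dataset_len is pinned by the first array leaf encountered; every other
# leaf, including those nested inside sub-dicts, must have the same
# leading length.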
53 | else: 54 | raise TypeError("Unsupported type.") 55 | return dataset_len 56 | 57 | 58 | def _subselect(dataset_dict: DatasetDict, index: np.ndarray) -> DatasetDict: 59 | new_dataset_dict = {} 60 | for k, v in dataset_dict.items(): 61 | if isinstance(v, dict): 62 | new_v = _subselect(v, index) 63 | elif isinstance(v, np.ndarray): 64 | new_v = v[index] 65 | else: 66 | raise TypeError("Unsupported type.") 67 | new_dataset_dict[k] = new_v 68 | return new_dataset_dict 69 | 70 | 71 | def _sample( 72 | dataset_dict: Union[np.ndarray, DatasetDict], indx: np.ndarray 73 | ) -> DatasetDict: 74 | if isinstance(dataset_dict, np.ndarray): 75 | return dataset_dict[indx] 76 | elif isinstance(dataset_dict, dict): 77 | batch = {} 78 | for k, v in dataset_dict.items(): 79 | batch[k] = _sample(v, indx) 80 | else: 81 | raise TypeError("Unsupported type.") 82 | return batch 83 | 84 | 85 | class Dataset(object): 86 | def __init__(self, dataset_dict: DatasetDict, seed: Optional[int] = None): 87 | self.dataset_dict = dataset_dict 88 | self.dataset_len = _check_lengths(dataset_dict) 89 | 90 | # Seeding similar to OpenAI Gym: 91 | # https://github.com/openai/gym/blob/master/gym/spaces/space.py#L46 92 | self._np_random = None 93 | self._seed = None 94 | if seed is not None: 95 | self.seed(seed) 96 | 97 | @property 98 | def np_random(self) -> np.random.RandomState: 99 | if self._np_random is None: 100 | self.seed() 101 | return self._np_random 102 | 103 | def seed(self, seed: Optional[int] = None) -> list: 104 | self._np_random, self._seed = seeding.np_random(seed) 105 | return [self._seed] 106 | 107 | def __len__(self) -> int: 108 | return self.dataset_len 109 | 110 | def get_iter(self, batch_size): 111 | for i in range(len(self) // batch_size): 112 | indx = np.arange(i * batch_size, (i + 1) * batch_size) 113 | indx = np.clip(indx, a_min=0, a_max=len(self) - 1) 114 | batch = dict() 115 | keys = self.dataset_dict.keys() 116 | 117 | for k in keys: 118 | if isinstance(self.dataset_dict[k], dict): 119 | batch[k] = _sample(self.dataset_dict[k], indx) 120 | else: 121 | batch[k] = self.dataset_dict[k][indx] 122 | 123 | yield frozen_dict.freeze(batch) 124 | 125 | def sample( 126 | self, 127 | batch_size: int, 128 | keys: Optional[Iterable[str]] = None, 129 | indx: Optional[np.ndarray] = None, 130 | ) -> frozen_dict.FrozenDict: 131 | if indx is None: 132 | if hasattr(self.np_random, "integers"): 133 | indx = self.np_random.integers(len(self), size=batch_size) 134 | else: 135 | indx = self.np_random.randint(len(self), size=batch_size) 136 | 137 | batch = dict() 138 | 139 | if keys is None: 140 | keys = self.dataset_dict.keys() 141 | 142 | for k in keys: 143 | if isinstance(self.dataset_dict[k], dict): 144 | batch[k] = _sample(self.dataset_dict[k], indx) 145 | else: 146 | batch[k] = self.dataset_dict[k][indx] 147 | 148 | return frozen_dict.freeze(batch) 149 | 150 | def sample_jax(self, batch_size: int, keys: Optional[Iterable[str]] = None): 151 | if not hasattr(self, "rng"): 152 | self.rng = jax.random.PRNGKey(self._seed or 42) 153 | 154 | if keys is None: 155 | keys = self.dataset_dict.keys() 156 | 157 | jax_dataset_dict = {k: self.dataset_dict[k] for k in keys} 158 | jax_dataset_dict = jax.device_put(jax_dataset_dict) 159 | 160 | @jax.jit 161 | def _sample_jax(rng): 162 | key, rng = jax.random.split(rng) 163 | indx = jax.random.randint( 164 | key, (batch_size,), minval=0, maxval=len(self) 165 | ) 166 | return rng, jax.tree_map( 167 | lambda d: jnp.take(d, indx, axis=0), jax_dataset_dict 168 | ) 169 | 170 | 
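# The device_put and jit above run only on the first call; the compiled
# sampler is cached on `self` below, so later calls just thread the PRNG
# key through the cached closure instead of re-staging the dataset.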
self._sample_jax = _sample_jax 171 | 172 | self.rng, sample = self._sample_jax(self.rng) 173 | return sample 174 | 175 | def split(self, ratio: float) -> Tuple["Dataset", "Dataset"]: 176 | assert 0 < ratio and ratio < 1 177 | train_index = np.index_exp[: int(self.dataset_len * ratio)] 178 | test_index = np.index_exp[int(self.dataset_len * ratio) :] 179 | 180 | index = np.arange(len(self), dtype=np.int32) 181 | self.np_random.shuffle(index) 182 | train_index = index[: int(self.dataset_len * ratio)] 183 | test_index = index[int(self.dataset_len * ratio) :] 184 | 185 | train_dataset_dict = _subselect(self.dataset_dict, train_index) 186 | test_dataset_dict = _subselect(self.dataset_dict, test_index) 187 | return Dataset(train_dataset_dict), Dataset(test_dataset_dict) 188 | 189 | def _trajectory_boundaries_and_returns(self) -> Tuple[list, list, list]: 190 | episode_starts = [0] 191 | episode_ends = [] 192 | 193 | episode_return = 0 194 | episode_returns = [] 195 | 196 | for i in range(len(self)): 197 | episode_return += self.dataset_dict["rewards"][i] 198 | 199 | if self.dataset_dict["dones"][i]: 200 | episode_returns.append(episode_return) 201 | episode_ends.append(i + 1) 202 | if i + 1 < len(self): 203 | episode_starts.append(i + 1) 204 | episode_return = 0.0 205 | 206 | return episode_starts, episode_ends, episode_returns 207 | 208 | def filter_by_fn(self, fn): 209 | bool_indx = np.full((len(self),), False, dtype=bool) 210 | for i in range(len(self)): 211 | tran = {k: v[i] for k, v in self.dataset_dict.items()} 212 | bool_indx[i] = fn(tran) 213 | 214 | self.dataset_dict = _subselect(self.dataset_dict, bool_indx) 215 | self.dataset_len = _check_lengths(self.dataset_dict) 216 | 217 | def filter( 218 | self, take_top: Optional[float] = None, threshold: Optional[float] = None 219 | ): 220 | assert (take_top is None and threshold is not None) or ( 221 | take_top is not None and threshold is None 222 | ) 223 | 224 | ( 225 | episode_starts, 226 | episode_ends, 227 | episode_returns, 228 | ) = self._trajectory_boundaries_and_returns() 229 | 230 | if take_top is not None: 231 | threshold = np.percentile(episode_returns, 100 - take_top) 232 | 233 | bool_indx = np.full((len(self),), False, dtype=bool) 234 | 235 | for i in range(len(episode_returns)): 236 | if episode_returns[i] >= threshold: 237 | bool_indx[episode_starts[i] : episode_ends[i]] = True 238 | 239 | self.dataset_dict = _subselect(self.dataset_dict, bool_indx) 240 | 241 | self.dataset_len = _check_lengths(self.dataset_dict) 242 | 243 | def normalize_returns(self, scaling: float = 1000): 244 | (_, _, episode_returns) = self._trajectory_boundaries_and_returns() 245 | self.dataset_dict["rewards"] /= np.max(episode_returns) - np.min( 246 | episode_returns 247 | ) 248 | self.dataset_dict["rewards"] *= scaling 249 | -------------------------------------------------------------------------------- /rlpd/data/memory_efficient_replay_buffer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/data/memory_efficient_replay_buffer.py 3 | 4 | Original lincense information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. 
Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | 28 | """ 29 | 30 | 31 | import copy 32 | from typing import Iterable, Optional, Tuple 33 | 34 | import gym 35 | import numpy as np 36 | from flax.core import frozen_dict 37 | from gym.spaces import Box 38 | 39 | from rlpd.data.dataset import DatasetDict, _sample 40 | from rlpd.data.replay_buffer import ReplayBuffer 41 | 42 | 43 | class MemoryEfficientReplayBuffer(ReplayBuffer): 44 | def __init__( 45 | self, 46 | observation_space: gym.Space, 47 | action_space: gym.Space, 48 | capacity: int, 49 | pixel_keys: Tuple[str, ...] = ("pixels",), 50 | ): 51 | self.pixel_keys = pixel_keys 52 | 53 | observation_space = copy.deepcopy(observation_space) 54 | self._num_stack = None 55 | for pixel_key in self.pixel_keys: 56 | pixel_obs_space = observation_space.spaces[pixel_key] 57 | if self._num_stack is None: 58 | self._num_stack = pixel_obs_space.shape[-1] 59 | else: 60 | assert self._num_stack == pixel_obs_space.shape[-1] 61 | self._unstacked_dim_size = pixel_obs_space.shape[-2] 62 | low = pixel_obs_space.low[..., 0] 63 | high = pixel_obs_space.high[..., 0] 64 | unstacked_pixel_obs_space = Box( 65 | low=low, high=high, dtype=pixel_obs_space.dtype 66 | ) 67 | observation_space.spaces[pixel_key] = unstacked_pixel_obs_space 68 | 69 | next_observation_space_dict = copy.deepcopy(observation_space.spaces) 70 | for pixel_key in self.pixel_keys: 71 | next_observation_space_dict.pop(pixel_key) 72 | next_observation_space = gym.spaces.Dict(next_observation_space_dict) 73 | 74 | self._first = True 75 | self._is_correct_index = np.full(capacity, False, dtype=bool) 76 | 77 | super().__init__( 78 | observation_space, 79 | action_space, 80 | capacity, 81 | next_observation_space=next_observation_space, 82 | ) 83 | 84 | def insert(self, data_dict: DatasetDict): 85 | if self._insert_index == 0 and self._capacity == len(self) and not self._first: 86 | indxs = np.arange(len(self) - self._num_stack, len(self)) 87 | for indx in indxs: 88 | element = super().sample(1, indx=indx) 89 | self._is_correct_index[self._insert_index] = False 90 | super().insert(element) 91 | 92 | data_dict = data_dict.copy() 93 | data_dict["observations"] = data_dict["observations"].copy() 94 | data_dict["next_observations"] = data_dict["next_observations"].copy() 95 | 96 | obs_pixels = {} 97 | next_obs_pixels = {} 98 | for pixel_key in self.pixel_keys: 99 | obs_pixels[pixel_key] = data_dict["observations"].pop(pixel_key) 100 | 
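# Pixels are popped off the transition so the buffer stores one unstacked
# frame per step; `sample` rebuilds frame stacks with a sliding window and
# `_is_correct_index` marks which indices admit a valid stack.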
next_obs_pixels[pixel_key] = data_dict["next_observations"].pop(pixel_key) 101 | 102 | if self._first: 103 | for i in range(self._num_stack): 104 | for pixel_key in self.pixel_keys: 105 | data_dict["observations"][pixel_key] = obs_pixels[pixel_key][..., i] 106 | 107 | self._is_correct_index[self._insert_index] = False 108 | super().insert(data_dict) 109 | 110 | for pixel_key in self.pixel_keys: 111 | data_dict["observations"][pixel_key] = next_obs_pixels[pixel_key][..., -1] 112 | 113 | self._first = data_dict["dones"] 114 | 115 | self._is_correct_index[self._insert_index] = True 116 | super().insert(data_dict) 117 | 118 | for i in range(self._num_stack): 119 | indx = (self._insert_index + i) % len(self) 120 | self._is_correct_index[indx] = False 121 | 122 | def sample( 123 | self, 124 | batch_size: int, 125 | keys: Optional[Iterable[str]] = None, 126 | indx: Optional[np.ndarray] = None, 127 | pack_obs_and_next_obs: bool = False, 128 | ) -> frozen_dict.FrozenDict: 129 | """Samples from the replay buffer. 130 | 131 | Args: 132 | batch_size: Minibatch size. 133 | keys: Keys to sample. 134 | indx: Take indices instead of sampling. 135 | pack_obs_and_next_obs: whether to pack img and next_img into one image. 136 | It's useful when they have overlapping frames. 137 | 138 | Returns: 139 | A frozen dictionary. 140 | """ 141 | 142 | if indx is None: 143 | if hasattr(self.np_random, "integers"): 144 | indx = self.np_random.integers(len(self), size=batch_size) 145 | else: 146 | indx = self.np_random.randint(len(self), size=batch_size) 147 | 148 | for i in range(batch_size): 149 | while not self._is_correct_index[indx[i]]: 150 | if hasattr(self.np_random, "integers"): 151 | indx[i] = self.np_random.integers(len(self)) 152 | else: 153 | indx[i] = self.np_random.randint(len(self)) 154 | else: 155 | pass 156 | 157 | if keys is None: 158 | keys = self.dataset_dict.keys() 159 | else: 160 | assert "observations" in keys 161 | 162 | keys = list(keys) 163 | keys.remove("observations") 164 | 165 | batch = super().sample(batch_size, keys, indx) 166 | batch = batch.unfreeze() 167 | 168 | obs_keys = self.dataset_dict["observations"].keys() 169 | obs_keys = list(obs_keys) 170 | for pixel_key in self.pixel_keys: 171 | obs_keys.remove(pixel_key) 172 | 173 | batch["observations"] = {} 174 | for k in obs_keys: 175 | batch["observations"][k] = _sample( 176 | self.dataset_dict["observations"][k], indx 177 | ) 178 | 179 | for pixel_key in self.pixel_keys: 180 | obs_pixels = self.dataset_dict["observations"][pixel_key] 181 | obs_pixels = np.lib.stride_tricks.sliding_window_view( 182 | obs_pixels, self._num_stack + 1, axis=0 183 | ) 184 | obs_pixels = obs_pixels[indx - self._num_stack] 185 | 186 | if pack_obs_and_next_obs: 187 | batch["observations"][pixel_key] = obs_pixels 188 | else: 189 | batch["observations"][pixel_key] = obs_pixels[..., :-1] 190 | if "next_observations" in keys: 191 | batch["next_observations"][pixel_key] = obs_pixels[..., 1:] 192 | 193 | return frozen_dict.freeze(batch) 194 | -------------------------------------------------------------------------------- /rlpd/data/replay_buffer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/data/replay_buffer.py 3 | 4 | Original lincense information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. 
Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | 28 | """ 29 | 30 | 31 | import collections 32 | from typing import Optional, Union 33 | 34 | import gym 35 | import gym.spaces 36 | import jax 37 | import numpy as np 38 | 39 | from rlpd.data.dataset import Dataset, DatasetDict 40 | 41 | 42 | def _init_replay_dict( 43 | obs_space: gym.Space, capacity: int 44 | ) -> Union[np.ndarray, DatasetDict]: 45 | if isinstance(obs_space, gym.spaces.Box): 46 | return np.empty((capacity, *obs_space.shape), dtype=obs_space.dtype) 47 | elif isinstance(obs_space, gym.spaces.Dict): 48 | data_dict = {} 49 | for k, v in obs_space.spaces.items(): 50 | data_dict[k] = _init_replay_dict(v, capacity) 51 | return data_dict 52 | else: 53 | raise TypeError() 54 | 55 | 56 | def _insert_recursively( 57 | dataset_dict: DatasetDict, data_dict: DatasetDict, insert_index: int 58 | ): 59 | if isinstance(dataset_dict, np.ndarray): 60 | dataset_dict[insert_index] = data_dict 61 | elif isinstance(dataset_dict, dict): 62 | assert dataset_dict.keys() == data_dict.keys() 63 | for k in dataset_dict.keys(): 64 | _insert_recursively(dataset_dict[k], data_dict[k], insert_index) 65 | else: 66 | raise TypeError() 67 | 68 | 69 | def _insert_recursively_batch( 70 | dataset_dict: DatasetDict, data_dict: DatasetDict, insert_index: int, size: int 71 | ): 72 | if isinstance(dataset_dict, np.ndarray): 73 | dataset_dict[insert_index : insert_index + size] = data_dict 74 | elif isinstance(dataset_dict, dict): 75 | assert dataset_dict.keys() == data_dict.keys() 76 | for k in dataset_dict.keys(): 77 | _insert_recursively_batch(dataset_dict[k], data_dict[k], insert_index, size) 78 | else: 79 | raise TypeError() 80 | 81 | 82 | class ReplayBuffer(Dataset): 83 | def __init__( 84 | self, 85 | observation_space: gym.Space, 86 | action_space: gym.Space, 87 | capacity: int, 88 | next_observation_space: Optional[gym.Space] = None, 89 | ): 90 | if next_observation_space is None: 91 | next_observation_space = observation_space 92 | 93 | observation_data = _init_replay_dict(observation_space, capacity) 94 | next_observation_data = _init_replay_dict(next_observation_space, capacity) 95 | dataset_dict = dict( 96 | observations=observation_data, 97 | next_observations=next_observation_data, 98 | actions=np.empty((capacity, *action_space.shape), dtype=action_space.dtype), 99 | rewards=np.empty((capacity,), dtype=np.float32), 100 | masks=np.empty((capacity,), dtype=np.float32), 101 | 
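# masks = 1 - terminal (bootstrap flag); dones marks trajectory ends,
# including time-limit truncations that should still bootstrap.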
dones=np.empty((capacity,), dtype=np.float32), 102 | ) 103 | 104 | super().__init__(dataset_dict) 105 | 106 | self._size = 0 107 | self._capacity = capacity 108 | self._insert_index = 0 109 | 110 | def __len__(self) -> int: 111 | return self._size 112 | 113 | def insert(self, data_dict: DatasetDict): 114 | _insert_recursively(self.dataset_dict, data_dict, self._insert_index) 115 | 116 | self._insert_index = (self._insert_index + 1) % self._capacity 117 | self._size = min(self._size + 1, self._capacity) 118 | 119 | def insert_batch(self, data_dict: DatasetDict): 120 | first_key = list(data_dict.keys())[0] 121 | batch_size = data_dict[first_key].shape[0] 122 | 123 | if self._insert_index + batch_size > self._capacity: 124 | self._insert_index = 0 125 | self._size = max(self._size, self._insert_index + batch_size) 126 | _insert_recursively_batch( 127 | self.dataset_dict, data_dict, self._insert_index, batch_size 128 | ) 129 | 130 | def get_iterator(self, queue_size: int = 2, sample_args: dict = {}): 131 | # See https://flax.readthedocs.io/en/latest/_modules/flax/jax_utils.html#prefetch_to_device 132 | # queue_size = 2 should be ok for one GPU. 133 | 134 | queue = collections.deque() 135 | 136 | def enqueue(n): 137 | for _ in range(n): 138 | data = self.sample(**sample_args) 139 | queue.append(jax.device_put(data)) 140 | 141 | enqueue(queue_size) 142 | while queue: 143 | yield queue.popleft() 144 | enqueue(1) 145 | -------------------------------------------------------------------------------- /rlpd/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/distributions/__init__.py 3 | 4 | Original lincense information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 
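A minimal sketch of the ReplayBuffer above; the spaces and shapes are illustrative assumptions:

import gym
import numpy as np

from rlpd.data import ReplayBuffer

obs_space = gym.spaces.Box(-np.inf, np.inf, (11,), np.float32)
act_space = gym.spaces.Box(-1.0, 1.0, (3,), np.float32)
buffer = ReplayBuffer(obs_space, act_space, capacity=100_000)

buffer.insert(dict(
    observations=obs_space.sample(),
    next_observations=obs_space.sample(),
    actions=act_space.sample(),
    rewards=0.0,
    masks=1.0,   # 1.0 unless the episode ended in a true terminal state
    dones=0.0,
))
batch = buffer.sample(256)  # FrozenDict of arrays with leading dim 256
it = buffer.get_iterator(sample_args={"batch_size": 256})  # prefetches batches to device
batch = next(it)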
27 | 28 | """ 29 | 30 | 31 | from rlpd.distributions.tanh_deterministic import TanhDeterministic 32 | from rlpd.distributions.tanh_normal import Normal, TanhNormal 33 | -------------------------------------------------------------------------------- /rlpd/distributions/tanh_deterministic.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/distributions/tanh_deterministic.py 3 | 4 | Original lincense information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | 28 | """ 29 | 30 | 31 | from typing import Type 32 | 33 | import flax.linen as nn 34 | import jax.numpy as jnp 35 | 36 | from rlpd.networks import default_init 37 | 38 | 39 | class TanhDeterministic(nn.Module): 40 | base_cls: Type[nn.Module] 41 | action_dim: int 42 | 43 | @nn.compact 44 | def __call__(self, inputs, *args, **kwargs) -> jnp.ndarray: 45 | x = self.base_cls()(inputs, *args, **kwargs) 46 | 47 | means = nn.Dense( 48 | self.action_dim, kernel_init=default_init(), name="OutputDenseMean" 49 | )(x) 50 | 51 | means = nn.tanh(means) 52 | 53 | return means 54 | -------------------------------------------------------------------------------- /rlpd/distributions/tanh_normal.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/distributions/tanh_normal.py 3 | 4 | Original lincense information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 
19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | 28 | """ 29 | 30 | 31 | import functools 32 | from typing import Optional, Type 33 | 34 | import tensorflow_probability 35 | 36 | from rlpd.distributions.tanh_transformed import TanhTransformedDistribution 37 | 38 | tfp = tensorflow_probability.substrates.jax 39 | tfd = tfp.distributions 40 | 41 | import flax.linen as nn 42 | import jax.numpy as jnp 43 | 44 | from rlpd.networks import default_init 45 | 46 | 47 | class Normal(nn.Module): 48 | base_cls: Type[nn.Module] 49 | action_dim: int 50 | log_std_min: Optional[float] = -20 51 | log_std_max: Optional[float] = 2 52 | state_dependent_std: bool = True 53 | squash_tanh: bool = False 54 | 55 | @nn.compact 56 | def __call__(self, inputs, *args, **kwargs) -> tfd.Distribution: 57 | x = self.base_cls()(inputs, *args, **kwargs) 58 | 59 | means = nn.Dense( 60 | self.action_dim, kernel_init=default_init(), name="OutputDenseMean" 61 | )(x) 62 | if self.state_dependent_std: 63 | log_stds = nn.Dense( 64 | self.action_dim, kernel_init=default_init(), name="OutputDenseLogStd" 65 | )(x) 66 | else: 67 | log_stds = self.param( 68 | "OutpuLogStd", nn.initializers.zeros, (self.action_dim,), jnp.float32 69 | ) 70 | 71 | log_stds = jnp.clip(log_stds, self.log_std_min, self.log_std_max) 72 | 73 | distribution = tfd.MultivariateNormalDiag( 74 | loc=means, scale_diag=jnp.exp(log_stds) 75 | ) 76 | 77 | if self.squash_tanh: 78 | return TanhTransformedDistribution(distribution) 79 | else: 80 | return distribution 81 | 82 | 83 | TanhNormal = functools.partial(Normal, squash_tanh=True) 84 | -------------------------------------------------------------------------------- /rlpd/distributions/tanh_transformed.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/distributions/tanh_transformed.py 3 | 4 | Original lincense information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | 28 | """ 29 | 30 | 31 | from typing import Any, Optional 32 | 33 | import tensorflow_probability 34 | 35 | tfp = tensorflow_probability.substrates.jax 36 | tfd = tfp.distributions 37 | tfb = tfp.bijectors 38 | 39 | import jax 40 | import jax.numpy as jnp 41 | 42 | # Inspired by 43 | # https://github.com/deepmind/acme/blob/300c780ffeb88661a41540b99d3e25714e2efd20/acme/jax/networks/distributional.py#L163 44 | # but modified to only compute a mode. 45 | 46 | 47 | class TanhTransformedDistribution(tfd.TransformedDistribution): 48 | def __init__(self, distribution: tfd.Distribution, validate_args: bool = False): 49 | super().__init__( 50 | distribution=distribution, bijector=tfb.Tanh(), validate_args=validate_args 51 | ) 52 | 53 | def mode(self) -> jnp.ndarray: 54 | return self.bijector.forward(self.distribution.mode()) 55 | 56 | @classmethod 57 | def _parameter_properties(cls, dtype: Optional[Any], num_classes=None): 58 | td_properties = super()._parameter_properties(dtype, num_classes=num_classes) 59 | del td_properties["bijector"] 60 | return td_properties 61 | -------------------------------------------------------------------------------- /rlpd/evaluation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/evaluation.py 3 | 4 | Original lincense information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 
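A minimal sketch of the squashed-Gaussian policy head built from the two distribution files above; the MLP sizes and the observation/action dimensions are illustrative assumptions:

from functools import partial

import jax
import jax.numpy as jnp

from rlpd.distributions import TanhNormal
from rlpd.networks import MLP

base_cls = partial(MLP, hidden_dims=(256, 256), activate_final=True)
actor_def = TanhNormal(base_cls=base_cls, action_dim=6)

obs = jnp.zeros((1, 17))
params = actor_def.init(jax.random.PRNGKey(0), obs)["params"]
dist = actor_def.apply({"params": params}, obs)    # a TanhTransformedDistribution

actions = dist.sample(seed=jax.random.PRNGKey(1))  # squashed into (-1, 1)
log_prob = dist.log_prob(actions)                  # includes the tanh log-det term
eval_actions = dist.mode()                         # tanh of the base mean, for evaluation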
27 | 28 | """ 29 | 30 | from typing import Dict 31 | 32 | import gym 33 | import numpy as np 34 | 35 | 36 | def evaluate(agent, env: gym.Env, num_episodes: int) -> Dict[str, float]: 37 | 38 | trajs = [] 39 | cum_returns = [] 40 | cum_lengths = [] 41 | for i in range(num_episodes): 42 | observation, done = env.reset(), False 43 | traj = [observation] 44 | cum_return = 0 45 | cum_length = 0 46 | while not done: 47 | action = agent.eval_actions(observation) 48 | observation, reward, done, _ = env.step(action) 49 | cum_return += reward 50 | cum_length += 1 51 | traj.append(observation) 52 | cum_returns.append(cum_return) 53 | cum_lengths.append(cum_length) 54 | trajs.append({"observation": np.stack(traj, axis=0)}) 55 | return {"return": np.mean(cum_returns), "length": np.mean(cum_lengths)}, trajs 56 | -------------------------------------------------------------------------------- /rlpd/gc_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from the official ICVF codebase: https://github.com/dibyaghosh/icvf_release/blob/main/src/gc_dataset.py 3 | 4 | Original lincense information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2023 Dibya Ghosh 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 
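A small usage sketch for evaluate() above; `agent` and `eval_env` are assumed to exist (any agent exposing eval_actions, plus a gym env). Note that the function returns a (stats, trajs) pair, although its annotation only mentions the stats dict:

from rlpd.evaluation import evaluate

stats, trajs = evaluate(agent, eval_env, num_episodes=10)
print(stats["return"], stats["length"])   # mean return / mean episode length
first = trajs[0]["observation"]           # (T + 1, obs_dim) stacked observations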
27 | """ 28 | 29 | import dataclasses 30 | import numpy as np 31 | import jax 32 | import ml_collections 33 | from flax.core import frozen_dict 34 | 35 | 36 | @dataclasses.dataclass 37 | class GCDataset: 38 | dataset: dict # Not actually a dict, but I hate typing 39 | p_randomgoal: float 40 | p_trajgoal: float 41 | p_currgoal: float 42 | terminal_key: str = "dones" 43 | reward_scale: float = 1.0 44 | reward_shift: float = -1.0 45 | terminal: bool = True 46 | max_distance: int = None 47 | curr_goal_shift: int = 0 48 | 49 | @staticmethod 50 | def get_default_config(): 51 | return ml_collections.ConfigDict( 52 | { 53 | "p_randomgoal": 0.3, 54 | "p_trajgoal": 0.5, 55 | "p_currgoal": 0.2, 56 | "reward_scale": 1.0, 57 | "reward_shift": -1.0, 58 | "terminal": True, 59 | "max_distance": ml_collections.config_dict.placeholder(int), 60 | "curr_goal_shift": 0, 61 | } 62 | ) 63 | 64 | def __post_init__(self): 65 | (self.terminal_locs,) = np.nonzero( 66 | self.dataset.dataset_dict[self.terminal_key] > 0 67 | ) 68 | self.terminal_locs = np.concatenate( 69 | [self.terminal_locs, [len(self.dataset) - 1]], axis=0 70 | ) 71 | print(f"Number of terminal states: {len(self.terminal_locs)}") 72 | if len(self.terminal_locs) == 0: 73 | print("No terminal states found in dataset") 74 | self.terminal_locs = np.arange(100, len(self.dataset) + 100, 100) 75 | print("Manually setting terminal states to every 100th state") 76 | assert np.isclose(self.p_randomgoal + self.p_trajgoal + self.p_currgoal, 1.0) 77 | 78 | def sample_goals(self, indx, p_randomgoal=None, p_trajgoal=None, p_currgoal=None): 79 | if p_randomgoal is None: 80 | p_randomgoal = self.p_randomgoal 81 | if p_trajgoal is None: 82 | p_trajgoal = self.p_trajgoal 83 | if p_currgoal is None: 84 | p_currgoal = self.p_currgoal 85 | 86 | batch_size = len(indx) 87 | # Random goals 88 | goal_indx = np.random.randint( 89 | len(self.dataset) - self.curr_goal_shift, size=batch_size 90 | ) 91 | 92 | # Goals from the same trajectory 93 | final_state_indx = self.terminal_locs[np.searchsorted(self.terminal_locs, indx)] 94 | if self.max_distance is not None: 95 | final_state_indx = np.clip(final_state_indx, 0, indx + self.max_distance) 96 | 97 | distance = np.random.rand(batch_size) 98 | middle_goal_indx = np.round( 99 | ((indx) * distance + final_state_indx * (1 - distance)) 100 | ).astype(int) 101 | 102 | goal_indx = np.where( 103 | np.random.rand(batch_size) < p_trajgoal / (1.0 - p_currgoal), 104 | middle_goal_indx, 105 | goal_indx, 106 | ) 107 | 108 | # Goals at the current state 109 | goal_indx = np.where(np.random.rand(batch_size) < p_currgoal, indx, goal_indx) 110 | return goal_indx 111 | 112 | def sample(self, batch_size: int, indx=None): 113 | if indx is None: 114 | indx = np.random.randint(len(self.dataset) - 1, size=batch_size) 115 | 116 | batch = self.dataset.sample(batch_size, indx) 117 | goal_indx = self.sample_goals(indx) 118 | 119 | success = indx == goal_indx 120 | batch["rewards"] = success.astype(float) * self.reward_scale + self.reward_shift 121 | if self.terminal: 122 | batch["masks"] = 1.0 - success.astype(float) 123 | else: 124 | batch["masks"] = np.ones(batch_size) 125 | batch["goals"] = jax.tree_map( 126 | lambda arr: arr[goal_indx + self.curr_goal_shift], 127 | self.dataset["observations"], 128 | ) 129 | 130 | return batch 131 | 132 | 133 | @dataclasses.dataclass 134 | class GCSDataset(GCDataset): 135 | p_samegoal: float = 0.5 136 | intent_sametraj: bool = False 137 | 138 | @staticmethod 139 | def get_default_config(): 140 | return 
ml_collections.ConfigDict( 141 | { 142 | "p_randomgoal": 0.3, 143 | "p_trajgoal": 0.5, 144 | "p_currgoal": 0.2, 145 | "reward_scale": 1.0, 146 | "reward_shift": -1.0, 147 | "terminal": True, 148 | "p_samegoal": 0.5, 149 | "intent_sametraj": False, 150 | "max_distance": ml_collections.config_dict.placeholder(int), 151 | "curr_goal_shift": 0, 152 | } 153 | ) 154 | 155 | def sample(self, batch_size: int, indx=None): 156 | if indx is None: 157 | indx = np.random.randint(len(self.dataset) - 100, size=batch_size) 158 | 159 | batch = frozen_dict.unfreeze(self.dataset.sample(batch_size, indx=indx)) 160 | 161 | if self.intent_sametraj: 162 | desired_goal_indx = self.sample_goals( 163 | indx, 164 | p_randomgoal=0.0, 165 | p_trajgoal=1.0 - self.p_currgoal, 166 | p_currgoal=self.p_currgoal, 167 | ) 168 | else: 169 | desired_goal_indx = self.sample_goals(indx) 170 | 171 | goal_indx = self.sample_goals(indx) 172 | goal_indx = np.where( 173 | np.random.rand(batch_size) < self.p_samegoal, desired_goal_indx, goal_indx 174 | ) 175 | 176 | success = indx == goal_indx 177 | desired_success = indx == desired_goal_indx 178 | 179 | batch["rewards"] = success.astype(float) * self.reward_scale + self.reward_shift 180 | batch["desired_rewards"] = ( 181 | desired_success.astype(float) * self.reward_scale + self.reward_shift 182 | ) 183 | 184 | if self.terminal: 185 | batch["masks"] = 1.0 - success.astype(float) 186 | batch["desired_masks"] = 1.0 - desired_success.astype(float) 187 | 188 | else: 189 | batch["masks"] = np.ones(batch_size) 190 | batch["desired_masks"] = np.ones(batch_size) 191 | 192 | goal_indx = np.clip(goal_indx + self.curr_goal_shift, 0, len(self.dataset) - 1) 193 | desired_goal_indx = np.clip( 194 | desired_goal_indx + self.curr_goal_shift, 0, len(self.dataset) - 1 195 | ) 196 | batch["goals"] = self.dataset.sample(batch_size, indx=goal_indx)["observations"] 197 | batch["desired_goals"] = self.dataset.sample( 198 | batch_size, indx=desired_goal_indx 199 | )["observations"] 200 | 201 | return batch 202 | -------------------------------------------------------------------------------- /rlpd/networks/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/networks/__init__.py 3 | 4 | Original lincense information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
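A sketch of goal-conditioned sampling with the GCSDataset above; `offline_dataset` is an assumed rlpd Dataset whose dict contains a "dones" key, and the probabilities are the get_default_config defaults:

from rlpd.gc_dataset import GCSDataset

gcs = GCSDataset(
    dataset=offline_dataset,   # assumed rlpd.data.Dataset instance
    p_randomgoal=0.3,          # goal sampled anywhere in the dataset
    p_trajgoal=0.5,            # goal sampled later in the same trajectory
    p_currgoal=0.2,            # goal == current state (a "success" sample)
)
batch = gcs.sample(256)        # adds goals / desired_goals; rewards are -1 until success, 0 at it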
IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | 28 | """ 29 | 30 | 31 | from rlpd.networks.ensemble import Ensemble, subsample_ensemble 32 | from rlpd.networks.mlp import MLP, default_init 33 | from rlpd.networks.mlp_resnet import MLPResNetV2 34 | from rlpd.networks.pixel_multiplexer import PixelMultiplexer 35 | from rlpd.networks.state_action_value import ( 36 | StateActionValue, 37 | StateActionFeature, 38 | StateValue, 39 | StateFeature, 40 | ) 41 | -------------------------------------------------------------------------------- /rlpd/networks/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/networks/encoders/__init__.py 3 | 4 | Original lincense information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | 28 | """ 29 | 30 | from rlpd.networks.encoders.d4pg_encoder import D4PGEncoder 31 | -------------------------------------------------------------------------------- /rlpd/networks/encoders/d4pg_encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/networks/encoders/d4pg_encoder.py 3 | 4 | Original lincense information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 
19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | 28 | """ 29 | 30 | 31 | from typing import Sequence 32 | 33 | import flax.linen as nn 34 | import jax.numpy as jnp 35 | 36 | from rlpd.networks import default_init 37 | 38 | 39 | class D4PGEncoder(nn.Module): 40 | features: Sequence[int] = (32, 32, 32, 32) 41 | filters: Sequence[int] = (2, 1, 1, 1) 42 | strides: Sequence[int] = (2, 1, 1, 1) 43 | padding: str = "VALID" 44 | 45 | @nn.compact 46 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 47 | assert len(self.features) == len(self.strides) 48 | 49 | for features, filter_, stride in zip(self.features, self.filters, self.strides): 50 | x = nn.Conv( 51 | features, 52 | kernel_size=(filter_, filter_), 53 | strides=(stride, stride), 54 | kernel_init=default_init(), 55 | padding=self.padding, 56 | )(x) 57 | x = nn.relu(x) 58 | 59 | return x.reshape((*x.shape[:-3], -1)) 60 | -------------------------------------------------------------------------------- /rlpd/networks/ensemble.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/networks/ensemble.py 3 | 4 | Original lincense information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 
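[Usage sketch for rlpd/networks/encoders/d4pg_encoder.py above; the input shape is illustrative and this is not code from the repo. With the default 2x2/stride-2 first conv, VALID padding, and 1x1 follow-up layers, an 84x84 input is halved spatially before the final flatten.]

import jax
import jax.numpy as jnp
from rlpd.networks.encoders import D4PGEncoder

enc = D4PGEncoder()  # defaults: features (32,)*4, one 2x2 stride-2 conv, then 1x1 convs
x = jnp.zeros((1, 84, 84, 9))  # e.g. 3 stacked RGB frames merged into the channel dim
variables = enc.init(jax.random.PRNGKey(0), x)
z = enc.apply(variables, x)
print(z.shape)  # (1, 42 * 42 * 32) = (1, 56448)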
27 | 28 | """ 29 | 30 | from typing import Type 31 | 32 | import flax.linen as nn 33 | import jax 34 | import jax.numpy as jnp 35 | from flax.core import FrozenDict 36 | 37 | 38 | class Ensemble(nn.Module): 39 | net_cls: Type[nn.Module] 40 | num: int = 2 41 | 42 | @nn.compact 43 | def __call__(self, *args): 44 | ensemble = nn.vmap( 45 | self.net_cls, 46 | variable_axes={"params": 0}, 47 | split_rngs={"params": True, "dropout": True}, 48 | in_axes=None, 49 | out_axes=0, 50 | axis_size=self.num, 51 | ) 52 | return ensemble()(*args) 53 | 54 | 55 | def subsample_ensemble(key: jax.random.PRNGKey, params, num_sample: int, num_qs: int): 56 | params = FrozenDict(params) 57 | if num_sample is not None: 58 | all_indx = jnp.arange(0, num_qs) 59 | indx = jax.random.choice(key, a=all_indx, shape=(num_sample,), replace=False) 60 | 61 | if "Ensemble_0" in params: 62 | ens_params = jax.tree_util.tree_map( 63 | lambda param: param[indx], params["Ensemble_0"] 64 | ) 65 | params = params.copy(add_or_replace={"Ensemble_0": ens_params}) 66 | else: 67 | params = jax.tree_util.tree_map(lambda param: param[indx], params) 68 | return params 69 | -------------------------------------------------------------------------------- /rlpd/networks/mlp.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/networks/mlp.py 3 | 4 | Original lincense information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 
27 | 28 | """ 29 | 30 | 31 | from typing import Callable, Optional, Sequence 32 | 33 | import flax.linen as nn 34 | import jax.numpy as jnp 35 | 36 | default_init = nn.initializers.xavier_uniform 37 | 38 | 39 | class MLP(nn.Module): 40 | hidden_dims: Sequence[int] 41 | activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu 42 | activate_final: bool = False 43 | use_layer_norm: bool = False 44 | scale_final: Optional[float] = None 45 | dropout_rate: Optional[float] = None 46 | use_pnorm: bool = False 47 | 48 | @nn.compact 49 | def __call__(self, x: jnp.ndarray, training: bool = False) -> jnp.ndarray: 50 | 51 | for i, size in enumerate(self.hidden_dims): 52 | if i + 1 == len(self.hidden_dims) and self.scale_final is not None: 53 | x = nn.Dense(size, kernel_init=default_init(self.scale_final))(x) 54 | else: 55 | x = nn.Dense(size, kernel_init=default_init())(x) 56 | 57 | if i + 1 < len(self.hidden_dims) or self.activate_final: 58 | if self.dropout_rate is not None and self.dropout_rate > 0: 59 | x = nn.Dropout(rate=self.dropout_rate)( 60 | x, deterministic=not training 61 | ) 62 | if self.use_layer_norm: 63 | x = nn.LayerNorm()(x) 64 | x = self.activations(x) 65 | if self.use_pnorm: 66 | x /= jnp.linalg.norm(x, axis=-1, keepdims=True).clip(1e-10) 67 | return x 68 | -------------------------------------------------------------------------------- /rlpd/networks/mlp_resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/networks/mlp_resnet.py 3 | 4 | Original lincense information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 
27 | 28 | """ 29 | 30 | 31 | from typing import Any, Callable, Sequence 32 | 33 | import flax.linen as nn 34 | import jax.numpy as jnp 35 | from flax import linen as nn 36 | 37 | ModuleDef = Any 38 | 39 | 40 | class MLPResNetV2Block(nn.Module): 41 | """MLPResNet block.""" 42 | 43 | features: int 44 | act: Callable 45 | 46 | @nn.compact 47 | def __call__(self, x): 48 | residual = x 49 | y = nn.LayerNorm()(x) 50 | y = self.act(y) 51 | y = nn.Dense(self.features)(y) 52 | y = nn.LayerNorm()(y) 53 | y = self.act(y) 54 | y = nn.Dense(self.features)(y) 55 | 56 | if residual.shape != y.shape: 57 | residual = nn.Dense(self.features)(residual) 58 | 59 | return residual + y 60 | 61 | 62 | class MLPResNetV2(nn.Module): 63 | """MLPResNetV2.""" 64 | 65 | num_blocks: int 66 | features: int = 256 67 | dtype: Any = jnp.float32 68 | act: Callable = nn.relu 69 | 70 | @nn.compact 71 | def __call__(self, x, training=False): 72 | x = nn.Dense(self.features)(x) 73 | for _ in range(self.num_blocks): 74 | x = MLPResNetV2Block(self.features, act=self.act)(x) 75 | x = nn.LayerNorm()(x) 76 | x = self.act(x) 77 | return x 78 | -------------------------------------------------------------------------------- /rlpd/networks/pixel_multiplexer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/networks/pixel_multiplexer.py 3 | 4 | Original lincense information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | 28 | """ 29 | 30 | 31 | from typing import Dict, Optional, Tuple, Type, Union 32 | 33 | import flax.linen as nn 34 | import jax 35 | import jax.numpy as jnp 36 | from flax.core.frozen_dict import FrozenDict 37 | 38 | from rlpd.networks import default_init 39 | 40 | 41 | class PixelMultiplexer(nn.Module): 42 | encoder_cls: Type[nn.Module] 43 | network_cls: Type[nn.Module] 44 | latent_dim: int 45 | stop_gradient: bool = False 46 | pixel_keys: Tuple[str, ...] = ("pixels",) 47 | depth_keys: Tuple[str, ...] 
= () 48 | 49 | @nn.compact 50 | def __call__( 51 | self, 52 | observations: Union[FrozenDict, Dict], 53 | actions: Optional[jnp.ndarray] = None, 54 | training: bool = False, 55 | ) -> jnp.ndarray: 56 | observations = FrozenDict(observations) 57 | if len(self.depth_keys) == 0: 58 | depth_keys = [None] * len(self.pixel_keys) 59 | else: 60 | depth_keys = self.depth_keys 61 | 62 | xs = [] 63 | for i, (pixel_key, depth_key) in enumerate(zip(self.pixel_keys, depth_keys)): 64 | x = observations[pixel_key].astype(jnp.float32) / 255.0 65 | if depth_key is not None: 66 | # The last dim is always for stacking, even if it's 1. 67 | x = jnp.concatenate([x, observations[depth_key]], axis=-2) 68 | 69 | x = jnp.reshape(x, (*x.shape[:-2], -1)) 70 | 71 | x = self.encoder_cls(name=f"encoder_{i}")(x) 72 | 73 | if self.stop_gradient: 74 | # We do not update conv layers with policy gradients. 75 | x = jax.lax.stop_gradient(x) 76 | 77 | x = nn.Dense(self.latent_dim, kernel_init=default_init())(x) 78 | x = nn.LayerNorm()(x) 79 | x = nn.tanh(x) 80 | xs.append(x) 81 | 82 | x = jnp.concatenate(xs, axis=-1) 83 | 84 | if "state" in observations: 85 | y = nn.Dense(self.latent_dim, kernel_init=default_init())( 86 | observations["state"] 87 | ) 88 | y = nn.LayerNorm()(y) 89 | y = nn.tanh(y) 90 | 91 | x = jnp.concatenate([x, y], axis=-1) 92 | 93 | if actions is None: 94 | return self.network_cls()(x, training) 95 | else: 96 | return self.network_cls()(x, actions, training) 97 | -------------------------------------------------------------------------------- /rlpd/networks/state_action_value.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/networks/state_action_value.py 3 | 4 | Original lincense information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 
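[Usage sketch for rlpd/networks/pixel_multiplexer.py above; shapes, latent size, and the MLP head are illustrative choices, not the repo's configuration. Pixel entries are expected as (..., H, W, C, stack); they are normalized to [0, 1], encoded, projected, and fused with any low-dimensional "state" entry.]

import jax
import jax.numpy as jnp
from functools import partial
from rlpd.networks import MLP, PixelMultiplexer
from rlpd.networks.encoders import D4PGEncoder

model = PixelMultiplexer(
    encoder_cls=partial(D4PGEncoder),
    network_cls=partial(MLP, hidden_dims=(256,)),
    latent_dim=50,
)
obs = {
    "pixels": jnp.zeros((1, 84, 84, 3, 2), dtype=jnp.uint8),  # (..., H, W, C, stack)
    "state": jnp.zeros((1, 7)),
}
variables = model.init(jax.random.PRNGKey(0), obs)
z = model.apply(variables, obs)  # pixel latent and state latent are concatenated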
27 | 28 | """ 29 | 30 | 31 | import flax.linen as nn 32 | import jax.numpy as jnp 33 | 34 | from rlpd.networks import default_init 35 | 36 | 37 | class StateActionValue(nn.Module): 38 | base_cls: nn.Module 39 | 40 | @nn.compact 41 | def __call__( 42 | self, observations: jnp.ndarray, actions: jnp.ndarray, *args, **kwargs 43 | ) -> jnp.ndarray: 44 | inputs = jnp.concatenate([observations, actions], axis=-1) 45 | outputs = self.base_cls()(inputs, *args, **kwargs) 46 | 47 | value = nn.Dense(1, kernel_init=default_init())(outputs) 48 | 49 | return jnp.squeeze(value, -1) 50 | 51 | 52 | class StateActionFeature(nn.Module): 53 | base_cls: nn.Module 54 | feature_dim: int 55 | 56 | @nn.compact 57 | def __call__( 58 | self, observations: jnp.ndarray, actions: jnp.ndarray, *args, **kwargs 59 | ) -> jnp.ndarray: 60 | inputs = jnp.concatenate([observations, actions], axis=-1) 61 | outputs = self.base_cls()(inputs, *args, **kwargs) 62 | feature = nn.Dense(self.feature_dim, kernel_init=default_init())(outputs) 63 | return feature 64 | 65 | 66 | class StateValue(nn.Module): 67 | base_cls: nn.Module 68 | 69 | @nn.compact 70 | def __call__(self, observations: jnp.ndarray, *args, **kwargs) -> jnp.ndarray: 71 | inputs = observations 72 | outputs = self.base_cls()(inputs, *args, **kwargs) 73 | 74 | value = nn.Dense(1, kernel_init=default_init())(outputs) 75 | 76 | return jnp.squeeze(value, -1) 77 | 78 | 79 | class StateFeature(nn.Module): 80 | base_cls: nn.Module 81 | feature_dim: int 82 | 83 | @nn.compact 84 | def __call__(self, observations: jnp.ndarray, *args, **kwargs) -> jnp.ndarray: 85 | inputs = observations 86 | outputs = self.base_cls()(inputs, *args, **kwargs) 87 | feature = nn.Dense(self.feature_dim, kernel_init=default_init())(outputs) 88 | return feature 89 | -------------------------------------------------------------------------------- /rlpd/types.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/types.py 3 | 4 | Original lincense information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 
27 | 28 | """ 29 | 30 | from typing import Any, Dict, Union 31 | 32 | import flax 33 | import numpy as np 34 | 35 | DataType = Union[np.ndarray, Dict[str, "DataType"]] 36 | PRNGKey = Any 37 | Params = flax.core.FrozenDict[str, Any] 38 | -------------------------------------------------------------------------------- /rlpd/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/wrappers/__init__.py 3 | 4 | Original lincense information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | 28 | """ 29 | 30 | import gym 31 | from gym.wrappers.flatten_observation import FlattenObservation 32 | 33 | from rlpd.wrappers.pixels import wrap_pixels 34 | from rlpd.wrappers.single_precision import SinglePrecision 35 | from rlpd.wrappers.universal_seed import UniversalSeed 36 | 37 | 38 | def wrap_gym(env: gym.Env, rescale_actions: bool = True) -> gym.Env: 39 | env = SinglePrecision(env) 40 | env = UniversalSeed(env) 41 | if rescale_actions: 42 | env = gym.wrappers.RescaleAction(env, -1, 1) 43 | 44 | if isinstance(env.observation_space, gym.spaces.Dict): 45 | env = FlattenObservation(env) 46 | 47 | env = gym.wrappers.ClipAction(env) 48 | 49 | return env 50 | -------------------------------------------------------------------------------- /rlpd/wrappers/frame_stack.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/wrappers/frame_stack.py 3 | 4 | Original lincense information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 
19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | 28 | """ 29 | 30 | import collections 31 | 32 | import gym 33 | import numpy as np 34 | from gym.spaces import Box 35 | 36 | 37 | class FrameStack(gym.Wrapper): 38 | def __init__(self, env, num_stack: int, stacking_key: str = "pixels"): 39 | super().__init__(env) 40 | self._num_stack = num_stack 41 | self._stacking_key = stacking_key 42 | 43 | assert stacking_key in self.observation_space.spaces 44 | pixel_obs_spaces = self.observation_space.spaces[stacking_key] 45 | 46 | self._env_dim = pixel_obs_spaces.shape[-1] 47 | 48 | low = np.repeat(pixel_obs_spaces.low[..., np.newaxis], num_stack, axis=-1) 49 | high = np.repeat(pixel_obs_spaces.high[..., np.newaxis], num_stack, axis=-1) 50 | new_pixel_obs_spaces = Box(low=low, high=high, dtype=pixel_obs_spaces.dtype) 51 | self.observation_space.spaces[stacking_key] = new_pixel_obs_spaces 52 | 53 | self._frames = collections.deque(maxlen=num_stack) 54 | 55 | def reset(self): 56 | obs = self.env.reset() 57 | for i in range(self._num_stack): 58 | self._frames.append(obs[self._stacking_key]) 59 | obs[self._stacking_key] = self.frames 60 | return obs 61 | 62 | @property 63 | def frames(self): 64 | return np.stack(self._frames, axis=-1) 65 | 66 | def step(self, action): 67 | obs, reward, done, info = self.env.step(action) 68 | self._frames.append(obs[self._stacking_key]) 69 | obs[self._stacking_key] = self.frames 70 | return obs, reward, done, info 71 | -------------------------------------------------------------------------------- /rlpd/wrappers/pixels.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/wrappers/pixels.py 3 | 4 | Original lincense information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 
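[Shape note for rlpd/wrappers/frame_stack.py above, as an illustrative check rather than repo code: frames are stacked along a new trailing axis, so an (H, W, C) pixel space becomes (H, W, C, num_stack), matching the np.stack(..., axis=-1) in FrameStack.frames.]

import numpy as np

frames = [np.zeros((84, 84, 3), dtype=np.uint8) for _ in range(3)]
print(np.stack(frames, axis=-1).shape)  # (84, 84, 3, 3)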
27 | 28 | """ 29 | 30 | from typing import Optional, Tuple 31 | 32 | import gym 33 | from gym.wrappers.pixel_observation import PixelObservationWrapper 34 | 35 | from rlpd.wrappers.frame_stack import FrameStack 36 | from rlpd.wrappers.repeat_action import RepeatAction 37 | from rlpd.wrappers.universal_seed import UniversalSeed 38 | 39 | 40 | def wrap_pixels( 41 | env: gym.Env, 42 | action_repeat: int, 43 | image_size: int = 84, 44 | num_stack: Optional[int] = 3, 45 | camera_id: int = 0, 46 | pixel_keys: Tuple[str, ...] = ("pixels",), 47 | ) -> gym.Env: 48 | if action_repeat > 1: 49 | env = RepeatAction(env, action_repeat) 50 | 51 | env = UniversalSeed(env) 52 | env = gym.wrappers.RescaleAction(env, -1, 1) 53 | 54 | env = PixelObservationWrapper( 55 | env, 56 | pixels_only=True, 57 | render_kwargs={ 58 | "pixels": { 59 | "height": image_size, 60 | "width": image_size, 61 | "camera_id": camera_id, 62 | } 63 | }, 64 | pixel_keys=pixel_keys, 65 | ) 66 | 67 | if num_stack is not None: 68 | env = FrameStack(env, num_stack=num_stack) 69 | 70 | env = gym.wrappers.ClipAction(env) 71 | 72 | return env, pixel_keys 73 | -------------------------------------------------------------------------------- /rlpd/wrappers/repeat_action.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/wrappers/repeat_action.py 3 | 4 | Original lincense information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 
27 | 28 | """ 29 | 30 | import gym 31 | import numpy as np 32 | 33 | 34 | class RepeatAction(gym.Wrapper): 35 | def __init__(self, env, action_repeat=4): 36 | super().__init__(env) 37 | self._action_repeat = action_repeat 38 | 39 | def step(self, action: np.ndarray): 40 | total_reward = 0.0 41 | done = None 42 | combined_info = {} 43 | 44 | for _ in range(self._action_repeat): 45 | obs, reward, done, info = self.env.step(action) 46 | total_reward += reward 47 | combined_info.update(info) 48 | if done: 49 | break 50 | 51 | return obs, total_reward, done, combined_info 52 | -------------------------------------------------------------------------------- /rlpd/wrappers/single_precision.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/wrappers/single_precision.py 3 | 4 | Original lincense information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 
27 | 28 | """ 29 | 30 | import copy 31 | 32 | import gym 33 | import numpy as np 34 | from gym.spaces import Box, Dict 35 | 36 | 37 | def _convert_space(obs_space): 38 | if isinstance(obs_space, Box): 39 | obs_space = Box(obs_space.low, obs_space.high, obs_space.shape) 40 | elif isinstance(obs_space, Dict): 41 | for k, v in obs_space.spaces.items(): 42 | obs_space.spaces[k] = _convert_space(v) 43 | obs_space = Dict(obs_space.spaces) 44 | else: 45 | raise NotImplementedError 46 | return obs_space 47 | 48 | 49 | def _convert_obs(obs): 50 | if isinstance(obs, np.ndarray): 51 | if obs.dtype == np.float64: 52 | return obs.astype(np.float32) 53 | else: 54 | return obs 55 | elif isinstance(obs, dict): 56 | obs = copy.copy(obs) 57 | for k, v in obs.items(): 58 | obs[k] = _convert_obs(v) 59 | return obs 60 | 61 | 62 | class SinglePrecision(gym.ObservationWrapper): 63 | def __init__(self, env): 64 | super().__init__(env) 65 | 66 | obs_space = copy.deepcopy(self.env.observation_space) 67 | self.observation_space = _convert_space(obs_space) 68 | 69 | def observation(self, observation): 70 | return _convert_obs(observation) 71 | -------------------------------------------------------------------------------- /rlpd/wrappers/universal_seed.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/ikostrikov/rlpd/blob/main/rlpd/wrappers/universal_seed.py 3 | 4 | Original lincense information: 5 | 6 | MIT License 7 | 8 | Copyright (c) 2022 Ilya Kostrikov, Philip J. Ball, Laura Smith 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | 28 | """ 29 | 30 | import gym 31 | 32 | 33 | class UniversalSeed(gym.Wrapper): 34 | def seed(self, seed: int): 35 | seeds = self.env.seed(seed) 36 | self.env.observation_space.seed(seed) 37 | self.env.action_space.seed(seed) 38 | return seeds 39 | -------------------------------------------------------------------------------- /run_antmaze.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | python train_finetuning.py --eval_episodes=10 --checkpoint_buffer=True --exp_prefix=exp_data/ours --env_name=antmaze-medium-diverse-v2 --max_steps=300000 --config.backup_entropy=False --config.num_min_qs=1 --project_name=release-explore-antmaze-save --offline_relabel_type=pred --seed=0 --use_rnd_offline=True --use_rnd_online=False 8 | python train_finetuning.py --eval_episodes=10 --checkpoint_buffer=True --exp_prefix=exp_data/naive --env_name=antmaze-medium-diverse-v2 --max_steps=300000 --config.backup_entropy=False --config.num_min_qs=1 --project_name=release-explore-antmaze-save --offline_relabel_type=pred --seed=0 --use_rnd_offline=False --use_rnd_online=False 9 | python train_finetuning.py --eval_episodes=10 --checkpoint_buffer=True --exp_prefix=exp_data/online_rnd --env_name=antmaze-medium-diverse-v2 --max_steps=300000 --config.backup_entropy=False --config.num_min_qs=1 --project_name=release-explore-antmaze-save --offline_ratio=0 --seed=0 --use_rnd_offline=False --use_rnd_online=True 10 | python train_finetuning.py --eval_episodes=10 --checkpoint_buffer=True --exp_prefix=exp_data/online --env_name=antmaze-medium-diverse-v2 --max_steps=300000 --config.backup_entropy=False --config.num_min_qs=1 --project_name=release-explore-antmaze-save --offline_ratio=0 --seed=0 --use_rnd_offline=False --use_rnd_online=False 11 | -------------------------------------------------------------------------------- /submit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import os 9 | import argparse 10 | 11 | parser = argparse.ArgumentParser(description="Process some integers.") 12 | parser.add_argument("-j", default=4, type=int) 13 | parser.add_argument("--partition", type=str) 14 | parser.add_argument("--name", type=str) 15 | parser.add_argument("--conda_env_name", type=str) 16 | parser.add_argument("--constraint", type=str) 17 | 18 | args, unknown = parser.parse_known_args() 19 | 20 | partition = args.partition 21 | name = args.name 22 | conda_env_name = args.conda_env_name 23 | constraint = args.constraint 24 | 25 | print(unknown) 26 | 27 | 28 | def parse(args): 29 | prefix = "" 30 | for index in range(len(args)): 31 | prefix += " " 32 | arg = args[index] 33 | i = arg.find("=") 34 | if i == -1: 35 | content = arg 36 | else: 37 | prefix += arg[: i + 1] 38 | content = arg[i + 1 :] 39 | 40 | if "," in content: 41 | elements = content.split(",") 42 | for r in parse(args[index + 1 :]): 43 | for element in elements: 44 | yield prefix + element + r 45 | return 46 | else: 47 | prefix += content 48 | yield prefix 49 | 50 | 51 | python_command_list = list(parse(unknown)) 52 | 53 | num_jobs = len(python_command_list) 54 | 55 | num_arr = (num_jobs - 1) // args.j + 1 56 | 57 | print("\n".join(python_command_list)) 58 | 59 | path = os.getcwd() 60 | 61 | d_str = "\n ".join( 62 | [ 63 | "[{}]='{}'".format(i + 1, command[1:]) 64 | for i, command in enumerate(python_command_list) 65 | ] 66 | ) 67 | 68 | sbatch_str = f"""#!/bin/bash 69 | #SBATCH --job-name=explore 70 | #SBATCH --open-mode=append 71 | #SBATCH --output=logs/out/%x_%j.txt 72 | #SBATCH --error=logs/err/%x_%j.txt 73 | #SBATCH --time=48:00:00 74 | #SBATCH --array=1-{num_arr} 75 | 76 | #SBATCH --partition={partition} 77 | #SBATCH --ntasks-per-node=1 78 | #SBATCH --cpus-per-task=10 79 | #SBATCH 
--gpus-per-node=1 80 | #SBATCH --constraint={constraint} 81 | 82 | TASK_ID=$((SLURM_ARRAY_TASK_ID-1)) 83 | PARALLEL_N={args.j} 84 | JOB_N={num_jobs} 85 | 86 | COM_ID_S=$((TASK_ID * PARALLEL_N + 1)) 87 | 88 | source ~/.bashrc 89 | 90 | conda activate {conda_env_name} 91 | 92 | declare -a commands=( 93 | {d_str} 94 | ) 95 | 96 | cd {path} 97 | 98 | parallel --delay 20 --linebuffer -j {args.j} {{1}} ::: \"${{commands[@]:$COM_ID_S:$PARALLEL_N}}\" 99 | """ 100 | 101 | with open(f"sbatch/{name}.sh", "w") as f: 102 | f.write(sbatch_str) 103 | -------------------------------------------------------------------------------- /submit_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | mkdir -p logs/out/ 10 | mkdir -p logs/err/ 11 | 12 | sbatch sbatch/antmaze-main.sh 13 | sbatch sbatch/antmaze-bc-jsrl.sh 14 | sbatch sbatch/antmaze-bc.sh 15 | sbatch sbatch/antmaze-online.sh 16 | sbatch sbatch/antmaze-oracle.sh 17 | sbatch sbatch/antmaze-min.sh 18 | 19 | sbatch sbatch/adroit-main.sh 20 | sbatch sbatch/adroit-main-relocate-reset.sh 21 | sbatch sbatch/adroit-bc-jsrl.sh 22 | sbatch sbatch/adroit-bc.sh 23 | sbatch sbatch/adroit-min.sh 24 | sbatch sbatch/adroit-online.sh 25 | sbatch sbatch/adroit-oracle.sh 26 | 27 | sbatch sbatch/cog-bc-jsrl.sh 28 | sbatch sbatch/cog-bc.sh 29 | sbatch sbatch/cog-main-ours.sh 30 | sbatch sbatch/cog-main-oracle.sh 31 | sbatch sbatch/cog-min.sh 32 | sbatch sbatch/cog-online.sh 33 | --------------------------------------------------------------------------------
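[Usage sketch for the sweep expansion in submit.py above. Hedged: importing submit.py runs its CLI and file-writing code at module level, so in practice the parse generator would be copied or refactored into a module before use; the flag values below are illustrative. parse() fans comma-separated flag values out into a cartesian product of command suffixes.]

# Assumes parse() from submit.py is available in scope.
args = ["--env_name=antmaze-medium-diverse-v2", "--seed=0,1", "--use_rnd_offline=True,False"]
for command in parse(args):
    print(command)
# Prints 4 command suffixes, one per (seed, use_rnd_offline) combination,
# each keeping the fixed env_name; submit.py wraps them into a SLURM array job.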