├── .gitignore ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── VRL3.png ├── docker └── Dockerfile ├── plot_utils ├── drq_converter.py ├── plot_config.py ├── rrl_result_converter.py ├── vrl3_plot_example.py ├── vrl3_plot_helper.py ├── vrl3_plot_mover.py └── vrl3_plot_runner.py └── src ├── .gitignore ├── adroit.py ├── cfgs_adroit ├── config.yaml └── task │ ├── door.yaml │ ├── hammer.yaml │ ├── pen.yaml │ └── relocate.yaml ├── dmc.py ├── logger.py ├── replay_buffer.py ├── rrl_local ├── rrl_encoder.py ├── rrl_multicam.py └── rrl_utils.py ├── stage1_models.py ├── testing ├── computation_time_test.py └── zzz.py ├── train_adroit.py ├── train_stage1.py ├── transfer_util.py ├── utils.py ├── video.py └── vrl3_agent.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # pytype static type analyzer
135 | .pytype/
136 |
137 | # Cython debug symbols
138 | cython_debug/
139 |
140 | .idea/
141 |
142 | # experiment files
143 | exp_local/
144 | data/
145 |
146 | images/
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Visual Deep Reinforcement Learning in 3 Stages (VRL3)
2 |
3 | Official code for the paper VRL3: A Data-Driven Framework for Visual Deep Reinforcement Learning. Summary site: https://sites.google.com/nyu.edu/vrl3.
4 |
5 | ![VRL3 overview](VRL3.png)
6 |
7 | The code has just been released, and the entire codebase has been rewritten to make it cleaner and more readable. If you run into any problems, please do not hesitate to open an issue.
8 |
9 | We are also doing some further clean-up of the code now. This repo will be updated.
10 |
11 |
12 |
13 | ### Table of Contents
14 |
15 | - [Repo structure](#repo-structure)
16 | - [Environment setup](#environment-setup)
17 | - [Docker setup](#docker)
18 | - [Run experiments](#run-exp)
19 | - [Singularity setup](#singularity)
20 | - [Plotting example](#plotting)
21 | - [Technical details](#hyper)
22 | - [Computation time](#computation)
23 | - [Known issues](#known-issues)
24 | - [Acknowledgement](#acknowledgement)
25 | - [Citation](#citation)
26 | - [Contributing](#contributing)
27 |
28 |
29 | ### Updates:
30 |
31 | 03/30/2023: added an example plot function and a quick tutorial.
32 |
33 |
34 |
35 |
36 | ## Repo structure and important files:
37 |
38 | ```
39 | VRL3 # this repo
40 | │ README.md # read this file first!
41 | └───docker # dockerfile with all dependencies
42 | └───plot_utils # code for plotting, still working on it now...
43 | └───src
44 | │ train_adroit.py # setup and main training loop
45 | │ vrl3_agent.py # agent class, code for stage 2, 3
46 | │ train_stage1.py # code for stage 1 pretraining on imagenet
47 | │ stage1_models.py # the encoder classes pretrained in stage 1
48 | └───cfgs_adroit # configuration files with all hyperparameters
49 |
50 | # download these folders from the google drive link
51 | vrl3data
52 | └───demonstrations # adroit demos
53 | └───trained_models # pretrained stage 1 models
54 | vrl3examplelogs
55 | └───rrl # rrl training logs
56 | └───vrl3 # vrl3 with default hyperparams logs
57 | ```
58 |
59 | To get started, download this repo, then download the adroit demos, pretrained models, and example logs from the following link:
60 | https://drive.google.com/drive/folders/1j-7BKlYmknVCBfzO3MnLLD62JBngoMkn?usp=sharing
61 |
62 |
63 |
64 | ## Environment setup
65 |
66 | The recommended way is to use the provided dockerfile and follow the tutorial here. You can also inspect the dockerfile to see the exact dependencies, or modify it to build a new image.
67 |
68 |
69 |
70 | ### Setup with docker
71 |
72 | If you have a local machine with a GPU, or your cluster allows docker (i.e., you have sudo), you can simply pull my docker image and run the code there. (The newest version is 1.5, which fixes the slow mujoco GPU rendering issue.)
73 | ```
74 | docker pull cwatcherw/vrl3:1.5
75 | ```
76 |
77 | Now, `cd` into a directory that contains the `VRL3` folder (this repo) and the `vrl3data` folder that you downloaded from my google drive link.
78 | Then, mount `VRL3/src` to `/code`, and mount `vrl3data` to `/vrl3data` (you can also mount to other places, but then you will need to adjust some commands or paths in the config files):
79 | ```
80 | docker run -it --rm --gpus all -v "$(pwd)"/VRL3/src:/code -v "$(pwd)"/vrl3data:/vrl3data cwatcherw/vrl3:1.5
81 | ```
82 | Now you should be inside the docker container. Refer to the "Run experiments" section next.
83 |
84 |
85 |
86 | ### Run experiments
87 |
88 | Once you are inside the container (either docker or singularity), first run the following commands so the paths are set correctly. This is especially important on singularity, since its automounting can mess up the paths. (The newest version of the code sets these via `os.environ`, so you can also skip this step.)
89 | ```
90 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/workspace/.mujoco/mujoco210/bin
91 | export MUJOCO_PY_MUJOCO_PATH=/workspace/.mujoco/mujoco210/
92 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/nvidia/lib
93 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/workspace/.mujoco/mujoco210/bin
94 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/nvidia
95 | export MUJOCO_GL=egl
96 | ```
97 |
98 | Go to the VRL3 code directory that you mounted.
99 | ```
100 | cd /code
101 | ```
102 |
103 | First, quickly check that mujoco is correctly using your GPU for rendering. If everything is set up properly, the program will print the computation time for 1,000 renders (if this is the first time mujoco is imported, there will also be mujoco build messages, which take a few minutes). Rendering 1,000 times should take < 0.5 seconds.
104 | ```
105 | python testing/computation_time_test.py
106 | ```
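For the curious, below is a minimal sketch of the kind of check that script performs. This is not the actual `testing/computation_time_test.py`; the env construction and the `sim` attribute access are assumptions based on standard gym / mujoco-py usage, so adapt as needed.
```
# Minimal sketch of a rendering speed check (assumed API, not the repo's script).
# If mujoco silently falls back to CPU rendering, this loop is far slower.
import time
import gym
import mj_envs  # noqa: F401 -- importing registers the Adroit tasks with gym

env = gym.make('door-v0')
env.reset()
start = time.time()
for _ in range(1000):
    # offscreen render at 84x84, the input resolution VRL3 uses
    env.unwrapped.sim.render(84, 84)
print('1000 renders took %.3f seconds' % (time.time() - start))
```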
107 |
108 |
109 | Then you can start running VRL3:
110 | ```
111 | cd /code
112 | python train_adroit.py task=door
113 | ```
114 |
115 | For a first-time setup, use `debug=1` to do a quick test run and check that the code is working. This reduces the number of training epochs and changes many other hyperparameters so that you get a full run in a few minutes.
116 | ```
117 | python train_adroit.py task=door debug=1
118 | ```
119 |
120 | You can also run with different hyperparameters; see `config.yaml` for the full list. For example:
121 | ```
122 | python train_adroit.py task=door stage2_n_update=5000 agent.encoder_lr_scale=0.1
123 | ```
124 |
125 |
126 |
127 | ### Setup with singularity
128 |
129 | If your cluster does not allow sudo (for example, NYU's slurm HPC), you can use singularity, which is similar to docker. You might need to modify some of the commands depending on how your cluster is managed. Here is an example setup on the NYU Greene HPC.
130 |
131 | Set up the singularity container (this will make a folder called `sing` in your scratch directory, and then build a singularity sandbox container called `vrl3sing`, using the `cwatcherw/vrl3:1.5` docker image from my docker hub):
132 | ```
133 | mkdir /scratch/$USER/sing/
134 | cd /scratch/$USER/sing/
135 | singularity build --sandbox vrl3sing docker://cwatcherw/vrl3:1.5
136 | ```
137 |
138 | For example, on the NYU HPC, start an interactive session (if your school has a different hpc system, consult your hpc admin):
139 | ```
140 | srun --pty --gres=gpu:1 --cpus-per-task=4 --mem 12000 -t 0-06:00 bash
141 | ```
142 | Here, by default VRL3 uses 4 workers for the dataloader, so we request 4 cpus. Once the job is allocated to you, go to the `sing` folder where you have your container, then run it:
143 | ```
144 | cd /scratch/$USER/sing
145 | singularity exec --nv -B /scratch/$USER/sing/VRL3/src:/code -B /scratch/$USER/sing/vrl3sing/opt/conda/lib/python3.8/site-packages/mujoco_py/:/opt/conda/lib/python3.8/site-packages/mujoco_py/ -B /scratch/$USER/sing/vrl3data:/vrl3data /scratch/$USER/sing/vrl3sing bash
146 | ```
147 | We mount the `mujoco_py` package folder because singularity files are read-only by default, and the older version of mujoco_py wants to modify files, which can be problematic. (Unfortunately, the Adroit env relies on the older version of mujoco.)
148 |
149 | After the singularity container starts running, refer to the "Run experiments" section.
150 |
151 |
152 |
153 | ## Plotting example
154 |
155 | If you'd like to use the plotting functions we used, you will need `matplotlib`, `seaborn`, and some other basic packages to run the plotting programs. You can also use your own plotting functions.
156 |
157 | An example is given in `plot_utils/vrl3_plot_example.py`. To use it:
158 | 1. make sure you downloaded the `vrl3examplelogs` folder from the drive link and unzipped it.
159 | 2. in `plot_utils/vrl3_plot_example.py`, change the `base_dir` path to where the `vrl3examplelogs` folder is on your computer (see the snippet after this list).
160 | 3. similarly, change the `base_save_dir` path to where you want the figures to be generated.
161 | 4. run `plot_utils/vrl3_plot_example.py`; this will generate a few figures comparing the success rates of RRL and VRL3 and save them to the specified path.
162 |
163 | (All ablation experiment logs generated during the VRL3 research are in the folder `vrl3logs` from the drive link. `plot_utils/vrl3_plot_runner.py` was used to generate the figures in the paper. It currently still needs further clean-up.)
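For reference, the two paths in steps 2 and 3 are plain module-level variables at the top of the example script; the paths below are placeholders, so substitute your own:
```
# In plot_utils/vrl3_plot_example.py -- point these at your local folders.
base_dir = '/path/to/vrl3examplelogs'   # where the example logs were unzipped
base_save_dir = '/path/to/vrl3figures'  # where generated figures will be saved
```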
164 |
165 |
166 |
167 | ## Technical details
168 |
169 | - BC loss: the config files now disable all BC losses by default, since our ablations show they do not really help.
170 | - under `src/cfgs_adroit/task/relocate.yaml` you will see that relocate has `encoder_lr_scale: 0.01`; as shown in the paper, relocate requires a smaller encoder learning rate. You can set task-specific default parameters in each task's separate config file.
171 | - in the paper, most experiments used `frame_stack=3`; however, I later found we can reduce it to 1 and still get the same performance. It might be beneficial to set it to 1 so training runs faster and takes less memory. If you set this to 1, then convolutional channel expansion will only be applied for the relocate env, where the input is a stack of 3 camera images.
172 | - all values in Table 2 in Appendix A.2 of the paper are set as the default values in the config files.
173 | - to apply VRL3 to other environments, please consult the hyperparameter sensitivity study in Appendix A, which identifies robust and sensitive hyperparameters.
174 |
175 |
176 |
177 | ### Computation time
178 |
179 | This table compares computation time estimates for the open-source code with default hyperparameters (tested on NYU Greene with an RTX 8000 and 4 cpus). On your machine it might be slightly faster or slower, but it should not be too different. These results are slightly faster than what we reported in the paper (which was tested on Azure P100 GPU machines); the improved computation speed is mainly because we now use a smaller default `frame_stack` for Adroit.
180 |
181 | | Task | Stage 2 (30K updates) | Stage 3 (4M frames) | Total | Total (paper) |
182 | |------------------|-----------------------|---------------------|---------|------------|
183 | | Door/Pen/Hammer | ~0.5 hrs | ~13 hrs | ~14 hrs | ~16 hrs |
184 | | Relocate | ~0.5 hrs | ~16 hrs | ~17 hrs | ~24 hrs |
185 |
186 | Note that VRL3's performance has largely converged by 1M frames for Door, Hammer, and Relocate. So depending on what you want to achieve in your work, you may or may not need to run the full 4M frames. In the paper we ran to 4M to be consistent with prior work and to show that VRL3 can outperform the previous SOTA in both short-term and long-term performance.
187 |
188 |
189 |
190 | ### Known issues:
191 |
192 | - Some users might encounter a problem where mujoco crashes at an arbitrary point during training. I have not seen this issue myself, but I was told that re-initializing `self.train_env` between stage 2 and stage 3 can fix it (see the sketch after this list).
193 | - If you are not using the provided docker image and you run into slow rendering, it is possible that mujoco did not find your gpu and built a `CPUExtender` instead of a `GPUExtender`. You can follow the steps in the provided dockerfile, or force it to use the `GPUExtender` (see the code in `mujoco-py/mujoco_py/builder.py`). Thanks to ZheCheng Yuan for identifying the above 2 issues.
194 | - Newer versions of mujoco are easier to work with. We use an older version only because Adroit relies on it. (So you can try a newer mujoco if you want to test on other environments.)
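Below is a hedged sketch of the re-initialization workaround mentioned in the first issue. The `workspace` object and its `cfg.task_name` attribute are assumed names for illustration (adapt them to how `train_adroit.py` actually builds its env); `AdroitEnv` and its arguments are from `src/adroit.py`.
```
# Hypothetical sketch, not code from this repo: rebuild the training env after
# stage 2 finishes so any stale mujoco state is discarded before stage 3
# (online RL) starts interacting with the environment.
from adroit import AdroitEnv

def reinit_train_env(workspace):
    # 'workspace' and 'cfg.task_name' are assumed names for illustration
    workspace.train_env = AdroitEnv(workspace.cfg.task_name,
                                    num_repeats=2, num_frames=3)
```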
195 |
196 |
197 |
198 | ## Acknowledgement
199 |
200 | The VRL3 code is mainly built on top of the DrQv2 codebase (https://github.com/facebookresearch/drqv2). Some utility functions and the dockerfile are modified from the REDQ codebase (https://github.com/watchernyu/REDQ). The Adroit demo loading code is modified from the RRL codebase (https://github.com/facebookresearch/RRL).
201 |
202 |
203 |
204 | ## Citation
205 |
206 | If you use VRL3 in your research, please consider citing the paper as:
207 | ```
208 | @inproceedings{wang2022vrl3,
209 |   title={VRL3: A Data-Driven Framework for Visual Deep Reinforcement Learning},
210 |   author={Wang, Che and Luo, Xufang and Ross, Keith and Li, Dongsheng},
211 |   booktitle={Conference on Neural Information Processing Systems},
212 |   year={2022},
213 |   url={https://openreview.net/forum?id=NjKAm5wMbo2}
214 | }
215 | ```
216 |
217 |
218 |
219 | ## Contributing
220 |
221 |
222 | This project welcomes contributions and suggestions. Most contributions require you to agree to a
223 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
224 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
225 |
226 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide
227 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
228 | provided by the bot. You will only need to do this once across all repos using our CLA.
229 |
230 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
231 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
232 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
233 |
234 |
235 |
236 | ## Trademarks
237 |
238 |
239 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
240 | trademarks or logos is subject to and must follow
241 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
242 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
243 | Any use of third-party trademarks or logos is subject to those third parties' policies.
244 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 
6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. 7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 26 | -------------------------------------------------------------------------------- /VRL3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/VRL3/f22e4d197d953663c8e77173b0aa639ea99d17a8/VRL3.png -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # cudagl with miniconda and python 3.8, pytorch, mujoco and gym v2 envs, REDQ 2 | # current corresponding image on dockerhub: docker://cwatcherw/vrl3:1.5 3 | 4 | FROM nvidia/cudagl:11.0-base-ubuntu18.04 5 | WORKDIR /workspace 6 | ENV HOME=/workspace 7 | ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 8 | ENV PATH /opt/conda/bin:$PATH 9 | 10 | # idea: start with a nvidia docker with gl support (guess this one also has cuda?) 11 | # then install miniconda, borrowing docker command from miniconda's Dockerfile (https://hub.docker.com/r/continuumio/anaconda/dockerfile/) 12 | # need to make sure the miniconda python version is what we need (https://docs.conda.io/en/latest/miniconda.html for the right version) 13 | # then install other dependencies we need 14 | 15 | # nvidia GPG key alternative fix (https://forums.developer.nvidia.com/t/notice-cuda-linux-repository-key-rotation/212772) 16 | # sudo apt-key del 7fa2af80 17 | # wget https://developer.download.nvidia.com/compute/cuda/repos/$distro/$arch/cuda-keyring_1.0-1_all.deb 18 | # sudo dpkg -i cuda-keyring_1.0-1_all.deb 19 | 20 | RUN \ 21 | # Update nvidia GPG key 22 | rm /etc/apt/sources.list.d/cuda.list \ 23 | && rm /etc/apt/sources.list.d/nvidia-ml.list \ 24 | && apt-key del 7fa2af80 \ 25 | && apt-get update && apt-get install -y --no-install-recommends wget \ 26 | && wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-keyring_1.0-1_all.deb \ 27 | && dpkg -i cuda-keyring_1.0-1_all.deb \ 28 | && apt-get update 29 | 30 | RUN apt-get update --fix-missing && apt-get install -y wget bzip2 ca-certificates \ 31 | libglib2.0-0 libxext6 libsm6 libxrender1 \ 32 | git mercurial subversion 33 | 34 | RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-py38_4.12.0-Linux-x86_64.sh -O ~/miniconda.sh \ 35 | && /bin/bash ~/miniconda.sh -b -p /opt/conda \ 36 | && rm ~/miniconda.sh \ 37 | && ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh \ 38 | && echo ". 
/opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc \
39 | && echo "conda activate base" >> ~/.bashrc
40 |
41 | RUN apt-get install -y curl grep sed dpkg \
42 | # TINI_VERSION=`curl https://github.com/krallin/tini/releases/latest | grep -o "/v.*\"" | sed 's:^..\(.*\).$:\1:'` && \
43 | # curl -L "https://github.com/krallin/tini/releases/download/v${TINI_VERSION}/tini_${TINI_VERSION}.deb" > tini.deb && \
44 | # dpkg -i tini.deb && \
45 | # rm tini.deb && \
46 | && apt-get clean
47 |
48 | # Install some basic utilities
49 | RUN apt-get update && apt-get install -y software-properties-common \
50 | && add-apt-repository -y ppa:redislabs/redis && apt-get update \
51 | && apt-get install -y sudo ssh libx11-6 gcc iputils-ping \
52 | libxrender-dev graphviz tmux htop build-essential wget cmake libgl1-mesa-glx redis \
53 | && rm -rf /var/lib/apt/lists/*
54 |
55 | RUN apt-get update && DEBIAN_FRONTEND=noninteractive \
56 | apt-get install -y zlib1g zlib1g-dev libosmesa6-dev libgl1-mesa-glx libglfw3 libglew2.0
57 | # && ln -s /usr/lib/x86_64-linux-gnu/libGL.so.1 /usr/lib/x86_64-linux-gnu/libGL.so
58 | # ---------- now we should have all major dependencies ------------
59 |
60 | # --------- now we have cudagl + python38 ---------
61 | RUN pip install --no-cache-dir torch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
62 |
63 | RUN pip install --no-cache-dir scikit-learn pandas imageio
64 |
65 | # the adroit env does not work with newer versions of mujoco, so we have to use the old version...
66 | # setting MUJOCO_GL to egl is needed when running on a headless machine
67 | ENV MUJOCO_GL=egl
68 | ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/workspace/.mujoco/mujoco210/bin
69 | RUN mkdir -p /workspace/.mujoco \
70 | && wget https://github.com/deepmind/mujoco/releases/download/2.1.0/mujoco210-linux-x86_64.tar.gz -O mujoco210.tar.gz \
71 | && tar -xvzf mujoco210.tar.gz -C /workspace/.mujoco \
72 | && rm mujoco210.tar.gz
73 |
74 | # RL env: get mujoco and gym (v2 version environments)
75 | RUN pip install --no-cache-dir patchelf==0.17.2.0 mujoco-py==2.1.2.14 gym==0.21.0
76 |
77 | # the first import of mujoco_py takes some time to compile; do it here to avoid the wait later
78 | RUN echo "import mujoco_py" >> /workspace/import_mujoco.py \
79 | && python /workspace/import_mujoco.py
80 |
81 | # get dmc at a specific version
82 | RUN cd /workspace/ \
83 | && git clone https://github.com/deepmind/dm_control.git \
84 | && cd dm_control \
85 | && git checkout 644d9e0 \
86 | && pip install -e .
87 |
88 | # dependencies for DrQv2 and VRL3 (mainly from https://github.com/facebookresearch/drqv2/blob/main/conda_env.yml)
89 | RUN pip install --no-cache-dir absl-py==0.13.0 pyparsing==2.4.7 jupyterlab==3.0.14 scikit-image \
90 | termcolor==1.1.0 imageio-ffmpeg==0.4.4 hydra-core==1.1.0 hydra-submitit-launcher==1.1.5 ipdb==0.13.9 \
91 | yapf==0.31.0 opencv-python==4.5.3.56 psutil tb-nightly
92 |
93 | # get adroit (mainly following https://github.com/facebookresearch/RRL, with small changes to avoid conflicts with drqv2 dependencies)
94 | RUN cd /workspace/ \
95 | && git clone https://github.com/watchernyu/rrl-dependencies.git \
96 | && cd rrl-dependencies \
97 | && pip install -e mj_envs/. \
98 | && pip install -e mjrl/.
99 | 100 | # important to have this so mujoco can compile stuff related to gpu rendering 101 | RUN apt-get update && apt-get install -y libglew-dev \ 102 | && rm -rf /var/lib/apt/lists/* 103 | 104 | # if we don't have this path `/usr/lib/nvidia`, the old mujoco can render with cpu only and not even give an error message 105 | # makes everything very slow and drive people insane 106 | RUN mkdir /usr/lib/nvidia 107 | 108 | CMD [ "/bin/bash" ] 109 | 110 | # # build docker container: 111 | # docker build . -t name:tag 112 | 113 | # # example docker command to run interactive container, enable gpu, and remove it when shutdown: 114 | # docker run -it --rm --gpus all name:tag 115 | -------------------------------------------------------------------------------- /plot_utils/drq_converter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | # NOTE: plotting helper code is currently being cleaned up 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import os 9 | from numpy import genfromtxt 10 | 11 | source_path = 'C:\\z\\drq_author' 12 | for subdir, dirs, files in os.walk(source_path): 13 | for file in files: 14 | if '_' in file and '.csv' in file: 15 | full_path = os.path.join(subdir, file) 16 | d = pd.read_csv(full_path, index_col=0) 17 | task = file[4:-4] 18 | for seed in range(1, 11): 19 | print(seed) 20 | d_seed = d[d.seed == seed] 21 | # now basically save them separately 22 | folder_name = 'drqv2_author_%s_s%d' % (task, seed) 23 | folder_path = os.path.join(source_path, folder_name) 24 | os.mkdir(folder_path) 25 | save_path = os.path.join(folder_path, 'eval.csv') 26 | d_seed.to_csv(save_path, index=False) 27 | print('saved to', save_path) 28 | 29 | -------------------------------------------------------------------------------- /plot_utils/plot_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | # NOTE: plotting helper code is currently being cleaned up 5 | 6 | DEFAULT_AMLT_PATH = 'D:\\a' 7 | DEFAULT_BASE_DATA_PATH = '/home/watcher/Desktop/projects/vrl3logs' 8 | DEFAULT_BASE_SAVE_PATH = '/home/watcher/Desktop/projects/vrl3figures' -------------------------------------------------------------------------------- /plot_utils/rrl_result_converter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | # NOTE: plotting helper code is currently being cleaned up 5 | 6 | # for each result file in each rrl data folder 7 | # load them and convert into a format we can use easily... 8 | import os 9 | 10 | import numpy as np 11 | import pandas 12 | from numpy import genfromtxt 13 | 14 | base_path = 'C:\\z\\abl' 15 | 16 | for subdir, dirs, files in os.walk(base_path): 17 | if 'log.csv' in files: 18 | # folder name is here is typically also the variant name 19 | folder_name = os.path.basename(os.path.normpath(subdir)) # e.g. 
drq_nstep3_pt5_0.001_hopper_hop_s1 20 | original_log_path = os.path.join(subdir, 'log.csv') 21 | 22 | seed_string_index = folder_name.rfind('s') 23 | name_no_seed = folder_name[:seed_string_index - 1] 24 | seed_value = int(folder_name[seed_string_index + 1:]) 25 | 26 | original_data = genfromtxt(original_log_path, dtype=float, delimiter=',', names=True) 27 | iters = original_data['iteration'] 28 | sr = original_data['success_rate'] 29 | # original_data['frames'] = iters * 40000 30 | new_df = { 31 | 'iter': np.arange(len(iters)), 32 | 'frame': iters * 40000, 33 | 'success_rate': sr/100 34 | } 35 | new_df = pandas.DataFrame(new_df) 36 | save_path = os.path.join(subdir, 'eval.csv') 37 | new_df.to_csv(save_path, index=False) 38 | save_path = os.path.join(subdir, 'train.csv') 39 | new_df.to_csv(save_path, index=False) 40 | 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /plot_utils/vrl3_plot_example.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | import os 4 | from vrl3_plot_helper import plot_aggregate_0506 5 | 6 | ############################################################################## 7 | ########## change base_dir to where you downloaded the example logs ########## 8 | ############################################################################## 9 | base_dir = '/home/watcher/Desktop/projects/vrl3examplelogs' 10 | base_save_dir = '/home/watcher/Desktop/projects/vrl3figures' 11 | 12 | FRAMEWORK_NAME = 'VRL3' 13 | 14 | NUM_MILLION_ON_X_AXIS = 4 15 | DEFAULT_Y_MIN = -0.02 16 | DEFAULT_X_MIN = -50000 17 | DEFAULT_Y_MAX = 1.02 18 | DEFAULT_X_MAX = 4050000 19 | X_LONG = 12050000 20 | ADROIT_ALL_ENVS = ["door", "pen", "hammer","relocate"] 21 | adroit_success_rate_default_min_max_dict = {'ymin':DEFAULT_Y_MIN, 'ymax':DEFAULT_Y_MAX, 'xmin':DEFAULT_X_MIN, 'xmax':DEFAULT_X_MAX} 22 | adroit_other_value_default_min_max_dict = {'ymin':None, 'ymax':None, 'xmin':DEFAULT_X_MIN, 'xmax':DEFAULT_X_MAX} 23 | 24 | def decide_placeholder_prefix(envs, folder_name, plot_name): 25 | if isinstance(envs, list): 26 | if len(envs) > 1: 27 | return 'aggregate-' + folder_name, 'aggregate_' + plot_name 28 | else: 29 | return folder_name,plot_name +'_'+ envs[0] 30 | return folder_name, plot_name +'_'+ envs 31 | 32 | def plot_paper_main_more_aggregate(envs=ADROIT_ALL_ENVS, no_legend=False): 33 | save_folder_name, save_prefix = decide_placeholder_prefix(envs, 'fs', 'adroit_main') 34 | # below are data folders under the base logs folder, the program will try to find training logs under these folders 35 | list_of_data_dir = ['rrl', 'vrl3'] 36 | paths = [os.path.join(base_dir, data_dir) for data_dir in list_of_data_dir] 37 | save_folder = os.path.join(base_save_dir, save_folder_name) 38 | 39 | labels = [FRAMEWORK_NAME, 'RRL'] 40 | variants = [ 41 | ['vrl3_ar2_fs3_pixels_resnet6_32channel_pTrue_augTrue_lr0.0001_esauto_dbs64_pt5_0.001_s2eTrue_bc5000_cql30000_w1_s2std0.1_nr10_s3std0.01_bc0.001_0.5_s2d0_sa0_ra0_demo25_qt200_qf0.5_ef0_True_etTrue'], 42 | ['rrl'], 43 | ] 44 | colors = ['tab:red', 'tab:grey', 'tab:blue', 'tab:orange', 'tab:pink','tab:brown'] 45 | dashes = ['solid' for _ in range(6)] 46 | 47 | label2variants, label2colors, label2linestyle = {}, {}, {} 48 | for i, label in enumerate(labels): 49 | label2variants[label] = variants[i] 50 | label2colors[label] = colors[i] 51 | label2linestyle[label] = dashes[i] 52 | 53 | for label, variants in 
label2variants.items(): # add environment 54 | variants_with_envs = [] 55 | for variant in variants: 56 | for env in envs: 57 | variants_with_envs.append(variant + '_' + env) 58 | label2variants[label] = variants_with_envs 59 | 60 | for plot_y_value in ['success_rate']: 61 | save_parent_path = os.path.join(save_folder, plot_y_value) 62 | d = adroit_success_rate_default_min_max_dict if plot_y_value == 'success_rate' else adroit_other_value_default_min_max_dict 63 | save_name = '%s_%s' % (save_prefix, plot_y_value) 64 | print(save_parent_path, save_name) 65 | plot_aggregate_0506(paths, label2variants=label2variants, save_folder=save_parent_path, save_name=save_name, 66 | plot_y_value=plot_y_value, label2colors=label2colors, label2linestyle=label2linestyle, 67 | max_seed=11, no_legend=no_legend, **d) 68 | 69 | def plot_paper_main_more_per_task(): 70 | for env in ADROIT_ALL_ENVS: 71 | plot_paper_main_more_aggregate(envs=[env], no_legend=(env!='relocate')) 72 | 73 | plot_paper_main_more_per_task() 74 | plot_paper_main_more_aggregate() -------------------------------------------------------------------------------- /plot_utils/vrl3_plot_mover.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | # NOTE: plotting helper code is currently being cleaned up 5 | 6 | from vrl3_plot_helper import * 7 | 8 | # this maps a data folder name to several amlt exp folder names. for example, 9 | data_folder_to_amlt_dict = { 10 | 'distill': ['driven-bedbug', 'enabled-starfish', 'humorous-octopus', 'singular-treefrog', # distill with cap 1, num_c 32, change distill weight 11 | 'fair-corgi', 'present-honeybee', # low number of channels 12 | 'guiding-earwig', 'bright-piranha' # high cap ul encoder 13 | ], 14 | 'n_compress':['coherent-caribou'], # change number of layer in the compression layers 15 | 'double':['select-hamster', 'modest-guppy', 'accepted-dragon'], # double encoder design, performance not great 16 | 'intermediate':['precious-hagfish', 'pleased-mosquito', 'resolved-cattle', 'selected-jawfish'], # single encoder, but try use intermediate layer output, see which layer gives best result 17 | 'double-add':['sweet-glider'], # double resnet with additional conv output, see if additional output for double resnet can help 18 | 'test-pg':['well-cow', 'optimum-corgi', 'robust-termite'], # double standard enc 8 variants, single res 8 variants. 
These results will help decide whether our hypothesis on policy gradient is correct or not 19 | 'naivelarge':['liberal-macaque'], # naively do single encoder, rl update, deep encoder 20 | 'newbaseline3s':['peaceful-anchovy', 'poetic-airedale', 'native-fawn', 'light-peacock', 'firm-starfish'], 21 | 'newdouble3s':['precious-koi', 'central-louse', 22 | 'sound-asp', 'welcomed-doe', 'fancy-polecat', 23 | 'oriented-kid' # humanoid 24 | ], 25 | 'utd':['innocent-gar','trusting-warthog'], 26 | 'pt': ['major-snipe', 'oriented-fox', 'brief-barnacle', 'superb-shepherd', 'golden-bobcat', 27 | 'capital-alpaca', 'premium-ringtail', # stage 2 28 | ], 29 | 'drq-base':['enabling-vervet',], # drq baseline with drq original code 30 | 'tl': ['able-starling' ,# tl freeze on easy 6 31 | ], 32 | 'ssl': ['on-gazelle', 'sincere-walrus', # freeze 33 | 'light-sheep', 'settled-mastiff', 'cheerful-muskrat'], 34 | 's2-base':['needed-hedgehog'], 35 | 'ssl-s2':['verified-impala', 'expert-cockatoo'], 36 | 'adroit-firstgroup':['lasting-ferret', 'united-basilisk'], 37 | 'adroit-secondgroup': ['upright-starling','star-opossum','distinct-viper','humorous-dassie', 38 | 'genuine-lynx', # no data aug 39 | 'choice-monkey', 'new-mole', # smaller lr 40 | 'quiet-worm', 'cuddly-cat' # sanity check 41 | ], 42 | 'adroit-stable': ['regular-monkfish', 'unique-monster', 43 | ], 44 | 'adroit-stable-hyper': ['loved-mastodon', 'equipped-owl', 'proven-urchin', 'innocent-pangolin'], # 1-seed hyper search 45 | 'adroit-stable-hyper2': ['proven-urchin', 'innocent-pangolin'], 46 | 'adroit-policy-lr': ['exciting-toad'], 47 | 'adroit-s2cql':['composed-salmon','driven-reptile'], 48 | 'adroit-s2cql2': ['adequate-reptile', 'native-kodiak', 'summary-baboon', 'safe-mammal'], 49 | 'adroit-relocatehyper': ['dashing-grizzly'], 50 | 'adroit-relocatehyper2': ['vast-robin'], 51 | 'adroit-relocatehyper3': ['quality-gull', 'peaceful-elf'], 52 | 'adroit-relocatehyper4': ['giving-krill', 'key-lizard'], 53 | 'adroit-relocatehyper3-best-10seed': ['strong-impala'], 54 | 'res6-pen': ['handy-redbird'], 55 | 'relocate-abl1': ['easy-sloth', 'casual-sloth','engaged-reptile','national-killdeer', 56 | 'sweeping-lion','finer-pup','fine-civet','guiding-mule', 57 | 'sweet-cicada','enormous-malamute','special-louse'], # first set of ablation 58 | 'abl2': ['funny-eel','unified-moose','clear-crappie','blessed-redfish','full-weevil','credible-rat','active-javelin'], 59 | 'main':['dashing-platypus','magical-magpie','adapted-tuna','bursting-clam'], 60 | 'main2':['trusting-condor','cosmic-baboon','composed-foxhound','large-goshawk'], 61 | 'main3':['loved-anemone','e0113_main3.yaml','adapting-polecat','many-cheetah'], 62 | 'main4': ['living-boxer'], 63 | 'abl':[ 64 | 's1_pretrain', 65 | 's2_bc_update','s2_cql_random','s2_cql_std','s2_cql_update','cql_weight','s2_enc_update', 66 | 's3_bc_decay','s3_bc_weight','s3_std','s23_aug', 67 | 's23_lr','s23_demo_bs','s23_enc_lr','s23_lr','s23_pretanh_weight','s23_safe_q_factor', # first round of ablations 68 | 'fs_s2_bc', 'fs_s3_freeze_enc', 'fs_s3_buffer_new', 69 | 's1_main_seeds', # 12 more seeds for main results #TODO add after we get that result, mainseeds2 currently not used 70 | 's2_naive_rl', # stage 2 do naive RL updates, and shut down safe Q target 71 | 's3_conservative', # stage 3 keep using conservative loss 72 | 's23_q_threshold', # change Q threshold value #TODO maybe add to hyper sens? 73 | 's23_pretanh_bc_limit', # TODO maybe appendix focus study? 74 | 's23_pretanh_cql_limit', # TODO maybe appendix focus study? 
75 | 'special_drq', 'special_drq2', # drq baseline (4 envs) # TODO might need to add to main plots? 76 | 'special_sdrq', 'special_sdrq2', # drq baseline with safe q target technique 77 | 'fs_s3_freeze_extra', # use stage 1, freeze for stage 2,3, give intermediate level features # TODO what about this? 78 | 'special_ferm', 'special_ferm_abl', # ferm baseline (improved baseline) 79 | 'special_vrl3_ferm', # vrl3 + extra ferm updates in stage 2 80 | 'special_ferm_plus', # ferm, but with safe Q target technique 81 | 's23_bn', # bn finetune ablation # TODO focus study? 82 | 's23_mom_enc_elr', 's23_mom_enc_safeq', 's1_model', 'frame_stack', # post sub new results, TODO 'action_repeat', 83 | 'model_elr_new', 'q_min_new', # 'model_elr', 'q_min' # old jobs failed 84 | 'moral-mudfish', 'cuddly-chicken', # these two are stride experiments 85 | 'model_hs_elr3_new', 'model_hs_elr2_new', 'model_hs_elr1_new', 86 | 'fs_enc_s2only', 'fs_enc_s3only', 's2_naive_rl_safe', 'ferm_and_s1', 87 | ], 88 | # neurips 2022 new 89 | 'ablneurips':['f0728_bc_new_ablation_dhp', 'f0731_bc_new_ablation_r', 90 | 'f0728_rand_first_layer_dh', 'f0731_rand_first_layer_p', 91 | 'f0728_channel_mismatch_dh', 'f0731_channel_mismatch_p', 92 | 'f0731_channel_latent_flow_dh', 'f0731_channel_latent_flow_p', 93 | 'f0802_channel_latent_sub_dhp', 94 | # '0728_channel_mismatch_r', 'f0728_rand_first_layer_r', 'f0728_channel_latent_flow_r', 95 | 'f0802_channel_latent_sub_dhp', 96 | # 'f0806_rand_first_layer_fs1_r','f0806_channel_latent_sub_r','f0806_channel_mismatch_fs1_r','f0806_channel_latent_cat_r', 97 | # 'f0808_byol', 98 | 'f0808_bc_new_ablation_dhp', 'f0808_bc_new_ablation_r', 99 | 'f0808_channel_latent_cat_r', 'f0808_channel_latent_sub_r', 'f0808_channel_mismatch_fs1_r', 'f0808_rand_first_layer_fs1_r', 100 | 'f0817_rand_first_layer_dhp_new', 'f0817_channel_latent_cat_dhp', 'f0817_channel_latent_sub_dhp', 101 | 'f0817_channel_mismatch_dhp', 'f0817_bc_only_withs1_r', 'f0817_bc_only_nos1_dhp', 'f0817_bc_only_nos1_r', 102 | 'f0902_highcap', 'f0902_bc_new_ablation_nos1_r', 'f0902_byol' 103 | ], 104 | 'ablph':['ph_s1_pretrain', 'ph_fs_s3_freeze_enc', 'ph_s2_cql_update', 'ph_s2_enc_update', 'ph_s2_naive_rl', 'ph_s3_conservative', 'ph_s23_enc_lr', 105 | 'ph_s23_safe_q_factor', 'ph_s2_cql_random', 'ph_s2_cql_std', 'ph_s2_cql_weight', 'ph_s2_disable', 106 | 'ph_s3_std', 'ph_s23_lr', 'ph_s23_q_threshold', 'ph_s3_bc_decay', 'ph_s3_bc_weight', 'ph_s23_aug', 'ph_s23_demo_bs', 107 | 'ph_s23_pretanh_weight', 'ph_s2_bc_update', 108 | 'ph_s2_bc', 'ph_s23_bn', 'ph_s23_frame_stack', 'ph_s1_model', 109 | ], # these are for pen and hammer 110 | 'dmc':['dmc_vrl3_s3_medium', 'dmc_vrl3_s3_medium_pretrain_new', 'dmc_vrl3_s3_easy','dmc_vrl3_s3_easy_pretrain','dmc_vrl3_s3_hard_new','dmc_vrl3_s3_hard_pretrain_new', 111 | 'dmc_e25k_medium', 'dmc_e25k_easy_hard'], 112 | 'dmc_hyper':['dmc_e25k_medium_h1'], # hyper search # , 'dmc_e25k_hard_h1' dmc_e25k_medium_h2 113 | 'dmc_hyper2':['dmc_e25k_medium_h2'], 114 | 'dmc_hard_hyper1':['dmc_e25k_hard_h1'], 115 | 'dmc_newhyper_3seed': ['dmc_e25k_medium_nh', 'dmc_e25k_easy_hard_nh', 'ferm_dmc_e25k_all', 'dmc_all_drqv2fd'] 116 | 117 | # dmc_e25k_easy_hard_nh dmc_e25k_medium_nh # 3 seed full exp using new hyperparameters 118 | 119 | # tsne_relocate_nobc1 tsne_relocate_nobc2 # will just analyze on devnode 120 | 121 | ## this is command to download all 122 | # amlt results -I "*.csv" s1_pretrain s2_bc_update s2_cql_random s2_cql_std s2_cql_update cql_weight s2_enc_update s3_bc_decay s3_bc_weight s3_std s23_aug s23_lr s23_demo_bs 
s23_enc_lr s23_lr s23_pretanh_weight s23_safe_q_factor fs_s2_bc fs_s3_freeze_enc fs_s3_buffer_new s1_main_seeds 123 | 124 | # amlt results -I "*.csv" s2_naive_rl s3_conservative s23_q_threshold s23_pretanh_bc_limit s23_pretanh_cql_limit special_drq special_drq2 special_sdrq special_sdrq2 fs_s3_freeze_extra special_ferm special_ferm_abl special_vrl3_ferm special_ferm_plus s23_bn 125 | 126 | # amlt results -I "*.csv" s23_safe_q_factor fs_s2_bc fs_s3_freeze_enc fs_s3_buffer_new s1_main_seeds 127 | 128 | # TODO 0507: following is all new experiments we have run 129 | # amlt results -I "*.csv" ph_s1_pretrain ph_fs_s3_freeze_enc ph_s2_cql_update ph_s2_enc_update ph_s2_naive_rl ph_s3_conservative ph_s23_enc_lr 130 | # ph_s23_safe_q_factor ph_s2_cql_random ph_s2_cql_std ph_s2_cql_weight ph_s2_disable 131 | # ph_s3_std ph_s23_lr ph_s23_q_threshold 132 | # ph_s3_bc_decay ph_s3_bc_weight ph_s23_aug ph_s23_demo_bs ph_s23_pretanh_weight 133 | 134 | # ferm_dmc_e25k_all fs_enc_s2only fs_enc_s3only dmc_all_drqv2fd #### s2_naive_rl_safe ferm_and_s1 135 | } 136 | 137 | # 0525 new: ph_s2_bc ph_s23_bn ph_s23_frame_stack ph_s1_model 138 | 139 | move_data = True 140 | move_keys = ['double', 'intermediate', 'double-add', 'test-pg', 'naivelarge'] 141 | move_keys = [ 'newbaseline3s', 'newdouble3s'] 142 | move_keys = [ 'double', 'intermediate', 'double-add',] 143 | move_keys = [ 'pt'] 144 | move_keys = [ 'newbaseline3s', 'newdouble3s', 'pt'] 145 | move_keys = [ 'newbaseline3s', 'newdouble3s', 'pt', 'utd', 'drq-base', 'ssl', 'tl', 's2-base', 'ssl-s2'] 146 | move_keys = ['adroit-firstgroup', 'adroit-secondgroup', 'adroit-stable', 'adroit-stable-hyper', 'adroit-stable-hyper2', 'adroit-policy-lr'] 147 | move_keys = ['adroit-policy-lr'] 148 | move_keys = ['adroit-s2cql'] 149 | move_keys = ['adroit-relocatehyper2', 'adroit-relocatehyper3', 'res6-pen'] # hyper 3 is best 150 | move_keys = ['adroit-relocatehyper4'] 151 | move_keys = ['adroit-relocatehyper3-best-10seed', 'relocate-abl1', 'abl2', 'main', 'main2','main3', 'main4'] 152 | move_keys = ['abl', 'dmc', 'dmc_hyper', 'dmc_hyper2', 'dmc_hard_hyper1', 'dmc_newhyper_3seed', 'ablph'] 153 | move_keys = ['ablneurips'] 154 | if move_data: 155 | for data_folder in move_keys: 156 | list_of_amlt_folder = data_folder_to_amlt_dict[data_folder] 157 | for amlt_folder in list_of_amlt_folder: 158 | move_data_from_amlt(amlt_folder, data_folder) 159 | 160 | -------------------------------------------------------------------------------- /src/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .ipynb_checkpoints/ 3 | exp_local 4 | exp 5 | exp_fixed 6 | exp_hard 7 | nbs 8 | code_snapshots 9 | exp_drqv2_*.py 10 | dmc_benchmarks.py 11 | check_sweep.py 12 | cancel_sweep.py 13 | data/ 14 | 15 | 16 | # Byte-compiled / optimized / DLL files 17 | __pycache__/ 18 | *.py[cod] 19 | *$py.class 20 | 21 | # C extensions 22 | *.so 23 | 24 | # Distribution / packaging 25 | .Python 26 | build/ 27 | develop-eggs/ 28 | dist/ 29 | downloads/ 30 | eggs/ 31 | .eggs/ 32 | lib/ 33 | lib64/ 34 | parts/ 35 | sdist/ 36 | var/ 37 | wheels/ 38 | pip-wheel-metadata/ 39 | share/python-wheels/ 40 | *.egg-info/ 41 | .installed.cfg 42 | *.egg 43 | MANIFEST 44 | 45 | # PyInstaller 46 | # Usually these files are written by a python script from a template 47 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
48 | *.manifest 49 | *.spec 50 | 51 | # Installer logs 52 | pip-log.txt 53 | pip-delete-this-directory.txt 54 | 55 | # Unit test / coverage reports 56 | htmlcov/ 57 | .tox/ 58 | .nox/ 59 | .coverage 60 | .coverage.* 61 | .cache 62 | nosetests.xml 63 | coverage.xml 64 | *.cover 65 | *.py,cover 66 | .hypothesis/ 67 | .pytest_cache/ 68 | 69 | # Translations 70 | *.mo 71 | *.pot 72 | 73 | # Django stuff: 74 | *.log 75 | local_settings.py 76 | db.sqlite3 77 | db.sqlite3-journal 78 | 79 | # Flask stuff: 80 | instance/ 81 | .webassets-cache 82 | 83 | # Scrapy stuff: 84 | .scrapy 85 | 86 | # Sphinx documentation 87 | docs/_build/ 88 | 89 | # PyBuilder 90 | target/ 91 | 92 | # Jupyter Notebook 93 | .ipynb_checkpoints 94 | 95 | # IPython 96 | profile_default/ 97 | ipython_config.py 98 | 99 | # pyenv 100 | .python-version 101 | 102 | # pipenv 103 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 104 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 105 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 106 | # install all needed dependencies. 107 | #Pipfile.lock 108 | 109 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 110 | __pypackages__/ 111 | 112 | # Celery stuff 113 | celerybeat-schedule 114 | celerybeat.pid 115 | 116 | # SageMath parsed files 117 | *.sage.py 118 | 119 | # Environments 120 | .env 121 | .venv 122 | env/ 123 | venv/ 124 | ENV/ 125 | env.bak/ 126 | venv.bak/ 127 | 128 | # Spyder project settings 129 | .spyderproject 130 | .spyproject 131 | 132 | # Rope project settings 133 | .ropeproject 134 | 135 | # mkdocs documentation 136 | /site 137 | 138 | # mypy 139 | .mypy_cache/ 140 | .dmypy.json 141 | dmypy.json 142 | 143 | # Pyre type checker 144 | .pyre/ -------------------------------------------------------------------------------- /src/adroit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | # NOTE: adroit env code is currently being cleaned up 5 | 6 | 7 | from collections import deque 8 | from typing import Any, NamedTuple 9 | import warnings 10 | 11 | import dm_env 12 | import numpy as np 13 | from dm_env import StepType, specs 14 | from collections import OrderedDict 15 | import mj_envs 16 | # import adept_envs # TODO worry about this later 17 | import gym 18 | 19 | from mjrl.utils.gym_env import GymEnv 20 | # from rrl_local.rrl_utils import make_basic_env, make_dir 21 | from rrl_local.rrl_multicam import BasicAdroitEnv, BasicFrankaEnv 22 | 23 | # similar to dmc.py, we will have environment wrapper here... 
24 | 25 | class ExtendedTimeStep(NamedTuple): 26 | step_type: Any 27 | reward: Any 28 | discount: Any 29 | observation: Any 30 | action: Any 31 | 32 | def first(self): 33 | return self.step_type == StepType.FIRST 34 | 35 | def mid(self): 36 | return self.step_type == StepType.MID 37 | 38 | def last(self): 39 | return self.step_type == StepType.LAST 40 | 41 | def __getitem__(self, attr): 42 | return getattr(self, attr) 43 | 44 | class ExtendedTimeStepAdroit(NamedTuple): 45 | step_type: Any 46 | reward: Any 47 | discount: Any 48 | observation: Any 49 | observation_sensor: Any 50 | action: Any 51 | n_goal_achieved: Any 52 | time_limit_reached: Any 53 | 54 | def first(self): 55 | return self.step_type == StepType.FIRST 56 | 57 | def mid(self): 58 | return self.step_type == StepType.MID 59 | 60 | def last(self): 61 | return self.step_type == StepType.LAST 62 | 63 | def __getitem__(self, attr): 64 | return getattr(self, attr) 65 | 66 | 67 | class ActionRepeatWrapper(dm_env.Environment): 68 | def __init__(self, env, num_repeats): 69 | self._env = env 70 | self._num_repeats = num_repeats 71 | 72 | def step(self, action): 73 | reward = 0.0 74 | discount = 1.0 75 | for i in range(self._num_repeats): 76 | time_step = self._env.step(action) 77 | reward += (time_step.reward or 0.0) * discount 78 | discount *= time_step.discount 79 | if time_step.last(): 80 | break 81 | 82 | return time_step._replace(reward=reward, discount=discount) 83 | 84 | def observation_spec(self): 85 | return self._env.observation_spec() 86 | 87 | def action_spec(self): 88 | return self._env.action_spec() 89 | 90 | def reset(self): 91 | return self._env.reset() 92 | 93 | def __getattr__(self, name): 94 | return getattr(self._env, name) 95 | 96 | 97 | class FrameStackWrapper(dm_env.Environment): 98 | def __init__(self, env, num_frames, pixels_key='pixels'): 99 | self._env = env 100 | self._num_frames = num_frames 101 | self._frames = deque([], maxlen=num_frames) 102 | self._pixels_key = pixels_key 103 | 104 | wrapped_obs_spec = env.observation_spec() 105 | assert pixels_key in wrapped_obs_spec 106 | 107 | pixels_shape = wrapped_obs_spec[pixels_key].shape 108 | # remove batch dim 109 | if len(pixels_shape) == 4: 110 | pixels_shape = pixels_shape[1:] 111 | self._obs_spec = specs.BoundedArray(shape=np.concatenate( 112 | [[pixels_shape[2] * num_frames], pixels_shape[:2]], axis=0), 113 | dtype=np.uint8, 114 | minimum=0, 115 | maximum=255, 116 | name='observation') 117 | 118 | def _transform_observation(self, time_step): 119 | assert len(self._frames) == self._num_frames 120 | obs = np.concatenate(list(self._frames), axis=0) 121 | return time_step._replace(observation=obs) 122 | 123 | def _extract_pixels(self, time_step): 124 | pixels = time_step.observation[self._pixels_key] 125 | # remove batch dim 126 | if len(pixels.shape) == 4: 127 | pixels = pixels[0] 128 | return pixels.transpose(2, 0, 1).copy() 129 | 130 | def reset(self): 131 | time_step = self._env.reset() 132 | pixels = self._extract_pixels(time_step) 133 | for _ in range(self._num_frames): 134 | self._frames.append(pixels) 135 | return self._transform_observation(time_step) 136 | 137 | def step(self, action): 138 | time_step = self._env.step(action) 139 | pixels = self._extract_pixels(time_step) 140 | self._frames.append(pixels) 141 | return self._transform_observation(time_step) 142 | 143 | def observation_spec(self): 144 | return self._obs_spec 145 | 146 | def action_spec(self): 147 | return self._env.action_spec() 148 | 149 | def __getattr__(self, name): 150 | return 
getattr(self._env, name) 151 | 152 | 153 | class ActionDTypeWrapper(dm_env.Environment): 154 | def __init__(self, env, dtype): 155 | self._env = env 156 | wrapped_action_spec = env.action_spec() 157 | self._action_spec = specs.BoundedArray(wrapped_action_spec.shape, 158 | dtype, 159 | wrapped_action_spec.minimum, 160 | wrapped_action_spec.maximum, 161 | 'action') 162 | 163 | def step(self, action): 164 | action = action.astype(self._env.action_spec().dtype) 165 | return self._env.step(action) 166 | 167 | def observation_spec(self): 168 | return self._env.observation_spec() 169 | 170 | def action_spec(self): 171 | return self._action_spec 172 | 173 | def reset(self): 174 | return self._env.reset() 175 | 176 | def __getattr__(self, name): 177 | return getattr(self._env, name) 178 | 179 | 180 | class ExtendedTimeStepWrapper(dm_env.Environment): 181 | def __init__(self, env): 182 | self._env = env 183 | 184 | def reset(self): 185 | time_step = self._env.reset() 186 | return self._augment_time_step(time_step) 187 | 188 | def step(self, action): 189 | time_step = self._env.step(action) 190 | return self._augment_time_step(time_step, action) 191 | 192 | def _augment_time_step(self, time_step, action=None): 193 | if action is None: 194 | action_spec = self.action_spec() 195 | action = np.zeros(action_spec.shape, dtype=action_spec.dtype) 196 | return ExtendedTimeStep(observation=time_step.observation, 197 | step_type=time_step.step_type, 198 | action=action, 199 | reward=time_step.reward or 0.0, 200 | discount=time_step.discount or 1.0) 201 | 202 | def observation_spec(self): 203 | return self._env.observation_spec() 204 | 205 | def action_spec(self): 206 | return self._env.action_spec() 207 | 208 | def __getattr__(self, name): 209 | return getattr(self._env, name) 210 | 211 | 212 | def make_basic_env(env, cam_list=[], from_pixels=False, hybrid_state=None, test_image=False, channels_first=False, 213 | num_repeats=1, num_frames=1): 214 | e = GymEnv(env) 215 | env_kwargs = None 216 | if from_pixels : # TODO might want to improve this part 217 | height = 84 218 | width = 84 219 | latent_dim = height*width*len(cam_list)*3 220 | # RRL class instance is environment wrapper... 221 | e = BasicAdroitEnv(e, cameras=cam_list, 222 | height=height, width=width, latent_dim=latent_dim, hybrid_state=hybrid_state, 223 | test_image=test_image, channels_first=channels_first, num_repeats=num_repeats, num_frames=num_frames) 224 | env_kwargs = {'rrl_kwargs' : e.env_kwargs} 225 | # if not from pixels... 
then it's simpler 226 | return e, env_kwargs 227 | 228 | class AdroitEnv: 229 | # a wrapper class that will make Adroit env looks like a dmc env 230 | def __init__(self, env_name, test_image=False, cam_list=None, 231 | num_repeats=2, num_frames=3, env_feature_type='pixels', device=None, reward_rescale=False): 232 | default_env_to_cam_list = { 233 | 'hammer-v0': ['top'], 234 | 'door-v0': ['top'], 235 | 'pen-v0': ['vil_camera'], 236 | 'relocate-v0': ['cam1', 'cam2', 'cam3',], 237 | } 238 | if cam_list is None: 239 | cam_list = default_env_to_cam_list[env_name] 240 | self.env_name = env_name 241 | reward_rescale_dict = { 242 | 'hammer-v0': 1/100, 243 | 'door-v0': 1/20, 244 | 'pen-v0': 1/50, 245 | 'relocate-v0': 1/30, 246 | } 247 | if reward_rescale: 248 | self.reward_rescale_factor = reward_rescale_dict[env_name] 249 | else: 250 | self.reward_rescale_factor = 1 251 | 252 | # env, _ = make_basic_env(env_name, cam_list=cam_list, from_pixels=from_pixels, hybrid_state=True, 253 | # test_image=test_image, channels_first=True, num_repeats=num_repeats, num_frames=num_frames) 254 | env = GymEnv(env_name) 255 | if env_feature_type == 'state': 256 | raise NotImplementedError("state env not ready") 257 | elif env_feature_type == 'resnet18' or env_feature_type == 'resnet34' : 258 | # TODO maybe we will just throw everything into it.. 259 | height = 256 260 | width = 256 261 | latent_dim = 512 262 | env = BasicAdroitEnv(env, cameras=cam_list, 263 | height=height, width=width, latent_dim=latent_dim, hybrid_state=True, 264 | test_image=test_image, channels_first=False, num_repeats=num_repeats, num_frames=num_frames, encoder_type=env_feature_type, 265 | device=device 266 | ) 267 | elif env_feature_type == 'pixels': 268 | height = 84 269 | width = 84 270 | latent_dim = height*width*len(cam_list)*num_frames 271 | # RRL class instance is environment wrapper... 
272 | env = BasicAdroitEnv(env, cameras=cam_list, 273 | height=height, width=width, latent_dim=latent_dim, hybrid_state=True, 274 | test_image=test_image, channels_first=True, num_repeats=num_repeats, num_frames=num_frames, device=device) 275 | else: 276 | raise ValueError("env feature not supported") 277 | 278 | self._env = env 279 | self.obs_dim = env.spec.observation_dim 280 | self.obs_sensor_dim = 24 281 | self.act_dim = env.spec.action_dim 282 | self.horizon = env.spec.horizon 283 | number_channel = len(cam_list) * 3 * num_frames 284 | 285 | if env_feature_type == 'pixels': 286 | self._obs_spec = specs.BoundedArray(shape=(number_channel, 84, 84), dtype='uint8', name='observation', minimum=0, maximum=255) 287 | self._obs_sensor_spec = specs.Array(shape=(self.obs_sensor_dim,), dtype='float32', name='observation_sensor') 288 | elif env_feature_type == 'resnet18' or env_feature_type == 'resnet34' : 289 | self._obs_spec = specs.Array(shape=(512 * num_frames *len(cam_list) ,), dtype='float32', name='observation') # TODO fix magic number 290 | self._obs_sensor_spec = specs.Array(shape=(self.obs_sensor_dim,), dtype='float32', name='observation_sensor') 291 | self._action_spec = specs.BoundedArray(shape=(self.act_dim,), dtype='float32', name='action', minimum=-1.0, maximum=1.0) 292 | 293 | def reset(self): 294 | # pixels and sensor values 295 | obs_pixels, obs_sensor = self._env.reset() 296 | obs_sensor = obs_sensor.astype(np.float32) 297 | action_spec = self.action_spec() 298 | action = np.zeros(action_spec.shape, dtype=action_spec.dtype) 299 | 300 | time_step = ExtendedTimeStepAdroit(observation=obs_pixels, 301 | observation_sensor=obs_sensor, 302 | step_type=StepType.FIRST, 303 | action=action, 304 | reward=0.0, 305 | discount=1.0, 306 | n_goal_achieved=0, 307 | time_limit_reached=False) 308 | return time_step 309 | 310 | def get_current_obs_without_reset(self): 311 | # use this to obtain the first state in a demo 312 | obs_pixels, obs_sensor = self._env.get_obs_for_first_state_but_without_reset() 313 | obs_sensor = obs_sensor.astype(np.float32) 314 | action_spec = self.action_spec() 315 | action = np.zeros(action_spec.shape, dtype=action_spec.dtype) 316 | 317 | time_step = ExtendedTimeStepAdroit(observation=obs_pixels, 318 | observation_sensor=obs_sensor, 319 | step_type=StepType.FIRST, 320 | action=action, 321 | reward=0.0, 322 | discount=1.0, 323 | n_goal_achieved=0, 324 | time_limit_reached=False) 325 | return time_step 326 | 327 | def get_pixels_with_width_height(self, w, h): 328 | return self._env.get_pixels_with_width_height(w, h) 329 | 330 | def step(self, action, force_step_type=None, debug=False): 331 | obs_all, reward, done, env_info = self._env.step(action) 332 | obs_pixels, obs_sensor = obs_all 333 | obs_sensor = obs_sensor.astype(np.float32) 334 | 335 | discount = 1.0 336 | n_goal_achieved = env_info['n_goal_achieved'] 337 | time_limit_reached = env_info['TimeLimit.truncated'] if 'TimeLimit.truncated' in env_info else False 338 | if done: 339 | steptype = StepType.LAST 340 | else: 341 | steptype = StepType.MID 342 | 343 | if done and not time_limit_reached: 344 | discount = 0.0 345 | 346 | if force_step_type is not None: 347 | if force_step_type == 'mid': 348 | steptype = StepType.MID 349 | elif force_step_type == 'last': 350 | steptype = StepType.LAST 351 | else: 352 | steptype = StepType.FIRST 353 | 354 | reward = reward * self.reward_rescale_factor 355 | 356 | time_step = ExtendedTimeStepAdroit(observation=obs_pixels, 357 | observation_sensor=obs_sensor, 358 | 
step_type=steptype, 359 | action=action, 360 | reward=reward, 361 | discount=discount, 362 | n_goal_achieved=n_goal_achieved, 363 | time_limit_reached=time_limit_reached) 364 | 365 | if debug: 366 | return obs_all, reward, done, env_info 367 | return time_step 368 | 369 | def observation_spec(self): 370 | return self._obs_spec 371 | 372 | def observation_sensor_spec(self): 373 | return self._obs_sensor_spec 374 | 375 | def action_spec(self): 376 | return self._action_spec 377 | 378 | def set_env_state(self, state): 379 | self._env.set_env_state(state) 380 | # def __getattr__(self, name): 381 | # return getattr(self, name) 382 | 383 | 384 | -------------------------------------------------------------------------------- /src/cfgs_adroit/config.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | # modified from DrQv2 config file 5 | 6 | defaults: 7 | - _self_ 8 | - task@_global_: door 9 | - override hydra/launcher: submitit_local 10 | 11 | # task settings 12 | debug: 0 13 | # frame_stack: in paper actually used 3 as default for adroit, however 14 | # it can take a lot more memory especially for relocate, and later ablation shows 15 | # frame_stack does not affect performance significantly in Adroit, so here we set it to 1. 16 | frame_stack: 1 17 | action_repeat: 2 18 | discount: 0.99 19 | # eval 20 | eval_every_frames: 50000 21 | num_eval_episodes: 25 22 | stage2_eval_every_frames: 5000 23 | stage2_num_eval_episodes: 5 24 | # snapshot 25 | save_snapshot: false 26 | # replay buffer 27 | replay_buffer_size: 1000000 28 | replay_buffer_num_workers: 4 29 | nstep: 3 30 | batch_size: 256 31 | # misc 32 | seed: 1 33 | device: cuda 34 | save_video: false 35 | save_train_video: false 36 | use_tb: false 37 | save_models: false 38 | local_data_dir: '/vrl3data' 39 | show_computation_time_est: true # provide estimates on computation times 40 | show_time_est_interval: 2500 41 | # experiment 42 | experiment: exp 43 | # environment 44 | env_feature_type: 'pixels' 45 | use_sensor: true 46 | reward_rescale: true # TODO think about this... 47 | # agent 48 | lr: 1e-4 49 | feature_dim: 50 50 | 51 | # ====== stage 1 ====== 52 | stage1_model_name: 'resnet6_32channel' # model name example: resnet6_32channel, resnet6_64channel, resnet10_32channel 53 | stage1_use_pretrain: true # if false then we start from scratch 54 | 55 | # ====== stage 2 ====== 56 | load_demo: true 57 | num_demo: 25 58 | stage2_n_update: 30000 59 | 60 | # ====== stage 3 ====== 61 | num_seed_frames: 0 62 | 63 | agent: 64 | _target_: vrl3_agent.VRL3Agent 65 | obs_shape: ??? # to be specified later 66 | action_shape: ??? # to be specified later 67 | device: ${device} 68 | critic_target_tau: 0.01 69 | update_every_steps: 2 70 | use_tb: ${use_tb} 71 | lr: 1e-4 72 | hidden_dim: 1024 73 | feature_dim: 50 74 | 75 | # environment related 76 | use_sensor: ${use_sensor} 77 | 78 | # ====== stage 1 ====== 79 | stage1_model_name: ${stage1_model_name} 80 | 81 | # ====== stage 2, 3 ====== 82 | use_data_aug: true 83 | encoder_lr_scale: 1 84 | stddev_clip: 0.3 85 | # safe Q 86 | safe_q_target_factor: 0.5 # value 1 is not using safe factor, value 0 is hard cutoff. 
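  # sketch of one consistent reading of these two knobs (inferred from the
  # inline comments above and below, not stated elsewhere in this file):
  # target Q-values above safe_q_threshold are shrunk toward the threshold by
  # safe_q_target_factor, so 0 clamps targets at the threshold and 1 disables
  # the mechanism entirely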
87 | safe_q_threshold: 200 88 | # pretanh penalty 89 | pretanh_threshold: 5 90 | pretanh_penalty: 0.001 91 | 92 | # ====== stage 2 ====== 93 | stage2_update_encoder: true # decides whether encoder is frozen or finetune in stage 2 94 | cql_weight: 1 95 | cql_temp: 1 96 | cql_n_random: 10 97 | stage2_std: 0.1 98 | stage2_bc_weight: 0 # ablation shows additional BC does not help performance 99 | 100 | # ====== stage 3 ====== 101 | stage3_update_encoder: true 102 | num_expl_steps: 0 # number of random actions at start of stage 3 103 | # std decay 104 | std0: 0.01 105 | std1: 0.01 106 | std_n_decay: 500000 107 | # bc decay 108 | stage3_bc_lam0: 0 # ablation shows additional BC does not help performance 109 | stage3_bc_lam1: 0.95 110 | 111 | hydra: 112 | run: # this "dir" decides where the training logs are stored 113 | dir: /vrl3data/logs/exp_local/${now:%Y.%m.%d}/${now:%H%M%S}_${hydra.job.override_dirname} 114 | sweep: 115 | dir: ./exp/${now:%Y.%m.%d}/${now:%H%M}_${agent_cfg.experiment} 116 | subdir: ${hydra.job.num} 117 | launcher: 118 | timeout_min: 4300 119 | cpus_per_task: 10 120 | gpus_per_node: 1 121 | tasks_per_node: 1 122 | mem_gb: 160 123 | nodes: 1 124 | submitit_folder: ./exp/${now:%Y.%m.%d}/${now:%H%M%S}_${agent_cfg.experiment}/.slurm 125 | -------------------------------------------------------------------------------- /src/cfgs_adroit/task/door.yaml: -------------------------------------------------------------------------------- 1 | num_train_frames: 4100000 2 | task_name: door-v0 3 | agent: 4 | encoder_lr_scale: 1 -------------------------------------------------------------------------------- /src/cfgs_adroit/task/hammer.yaml: -------------------------------------------------------------------------------- 1 | num_train_frames: 4100000 2 | task_name: hammer-v0 3 | agent: 4 | encoder_lr_scale: 1 -------------------------------------------------------------------------------- /src/cfgs_adroit/task/pen.yaml: -------------------------------------------------------------------------------- 1 | num_train_frames: 4100000 2 | task_name: pen-v0 3 | agent: 4 | encoder_lr_scale: 1 -------------------------------------------------------------------------------- /src/cfgs_adroit/task/relocate.yaml: -------------------------------------------------------------------------------- 1 | num_train_frames: 4100000 2 | task_name: relocate-v0 3 | agent: 4 | encoder_lr_scale: 0.01 -------------------------------------------------------------------------------- /src/dmc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
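# Usage sketch (illustrative, matching the make() factory defined at the bottom
# of this file): task names follow the '<domain>_<task>' convention, e.g.
#
#   env = make('cheetah_run', frame_stack=3, action_repeat=2, seed=1)
#   time_step = env.reset()
#   # frames are stacked along the channel axis: 3 frames x 3 RGB channels
#   assert time_step.observation.shape == (9, 84, 84)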
5 | from collections import deque 6 | from typing import Any, NamedTuple 7 | 8 | import dm_env 9 | import numpy as np 10 | from dm_control import manipulation, suite 11 | from dm_control.suite.wrappers import action_scale, pixels 12 | from dm_env import StepType, specs 13 | 14 | class ExtendedTimeStep(NamedTuple): 15 | step_type: Any 16 | reward: Any 17 | discount: Any 18 | observation: Any 19 | action: Any 20 | 21 | def first(self): 22 | return self.step_type == StepType.FIRST 23 | 24 | def mid(self): 25 | return self.step_type == StepType.MID 26 | 27 | def last(self): 28 | return self.step_type == StepType.LAST 29 | 30 | def __getitem__(self, attr): 31 | return getattr(self, attr) 32 | 33 | 34 | class ActionRepeatWrapper(dm_env.Environment): 35 | def __init__(self, env, num_repeats): 36 | self._env = env 37 | self._num_repeats = num_repeats 38 | 39 | def step(self, action): 40 | reward = 0.0 41 | discount = 1.0 42 | for i in range(self._num_repeats): 43 | time_step = self._env.step(action) 44 | reward += (time_step.reward or 0.0) * discount 45 | discount *= time_step.discount 46 | if time_step.last(): 47 | break 48 | 49 | return time_step._replace(reward=reward, discount=discount) 50 | 51 | def observation_spec(self): 52 | return self._env.observation_spec() 53 | 54 | def action_spec(self): 55 | return self._env.action_spec() 56 | 57 | def reset(self): 58 | return self._env.reset() 59 | 60 | def __getattr__(self, name): 61 | return getattr(self._env, name) 62 | 63 | 64 | class FrameStackWrapper(dm_env.Environment): 65 | def __init__(self, env, num_frames, pixels_key='pixels'): 66 | self._env = env 67 | self._num_frames = num_frames 68 | self._frames = deque([], maxlen=num_frames) 69 | self._pixels_key = pixels_key 70 | 71 | wrapped_obs_spec = env.observation_spec() 72 | assert pixels_key in wrapped_obs_spec 73 | 74 | pixels_shape = wrapped_obs_spec[pixels_key].shape 75 | # remove batch dim 76 | if len(pixels_shape) == 4: 77 | pixels_shape = pixels_shape[1:] 78 | self._obs_spec = specs.BoundedArray(shape=np.concatenate( 79 | [[pixels_shape[2] * num_frames], pixels_shape[:2]], axis=0), 80 | dtype=np.uint8, 81 | minimum=0, 82 | maximum=255, 83 | name='observation') 84 | 85 | def _transform_observation(self, time_step): 86 | assert len(self._frames) == self._num_frames 87 | obs = np.concatenate(list(self._frames), axis=0) 88 | return time_step._replace(observation=obs) 89 | 90 | def _extract_pixels(self, time_step): 91 | pixels = time_step.observation[self._pixels_key] 92 | # remove batch dim 93 | if len(pixels.shape) == 4: 94 | pixels = pixels[0] 95 | 96 | # transpose: 84 x 84 x 3 -> 3 x 84 x 84 97 | return pixels.transpose(2, 0, 1).copy() 98 | 99 | def reset(self): 100 | time_step = self._env.reset() 101 | pixels = self._extract_pixels(time_step) 102 | for _ in range(self._num_frames): 103 | self._frames.append(pixels) 104 | return self._transform_observation(time_step) 105 | 106 | def step(self, action): 107 | time_step = self._env.step(action) 108 | pixels = self._extract_pixels(time_step) 109 | self._frames.append(pixels) 110 | return self._transform_observation(time_step) 111 | 112 | def observation_spec(self): 113 | return self._obs_spec 114 | 115 | def action_spec(self): 116 | return self._env.action_spec() 117 | 118 | def __getattr__(self, name): 119 | return getattr(self._env, name) 120 | 121 | 122 | class ActionDTypeWrapper(dm_env.Environment): 123 | def __init__(self, env, dtype): 124 | self._env = env 125 | wrapped_action_spec = env.action_spec() 126 | self._action_spec = 
specs.BoundedArray(wrapped_action_spec.shape, 127 | dtype, 128 | wrapped_action_spec.minimum, 129 | wrapped_action_spec.maximum, 130 | 'action') 131 | 132 | def step(self, action): 133 | action = action.astype(self._env.action_spec().dtype) 134 | return self._env.step(action) 135 | 136 | def observation_spec(self): 137 | return self._env.observation_spec() 138 | 139 | def action_spec(self): 140 | return self._action_spec 141 | 142 | def reset(self): 143 | return self._env.reset() 144 | 145 | def __getattr__(self, name): 146 | return getattr(self._env, name) 147 | 148 | 149 | class ExtendedTimeStepWrapper(dm_env.Environment): 150 | def __init__(self, env): 151 | self._env = env 152 | 153 | def reset(self): 154 | time_step = self._env.reset() 155 | return self._augment_time_step(time_step) 156 | 157 | def step(self, action): 158 | time_step = self._env.step(action) 159 | return self._augment_time_step(time_step, action) 160 | 161 | def _augment_time_step(self, time_step, action=None): 162 | if action is None: 163 | action_spec = self.action_spec() 164 | action = np.zeros(action_spec.shape, dtype=action_spec.dtype) 165 | return ExtendedTimeStep(observation=time_step.observation, 166 | step_type=time_step.step_type, 167 | action=action, 168 | reward=time_step.reward or 0.0, 169 | discount=time_step.discount or 1.0) 170 | 171 | def observation_spec(self): 172 | return self._env.observation_spec() 173 | 174 | def action_spec(self): 175 | return self._env.action_spec() 176 | 177 | def __getattr__(self, name): 178 | return getattr(self._env, name) 179 | 180 | 181 | def make(name, frame_stack, action_repeat, seed): 182 | domain, task = name.split('_', 1) 183 | # overwrite cup to ball_in_cup 184 | domain = dict(cup='ball_in_cup').get(domain, domain) 185 | # make sure reward is not visualized 186 | if (domain, task) in suite.ALL_TASKS: 187 | env = suite.load(domain, 188 | task, 189 | task_kwargs={'random': seed}, 190 | visualize_reward=False) 191 | pixels_key = 'pixels' 192 | else: 193 | name = f'{domain}_{task}_vision' 194 | env = manipulation.load(name, seed=seed) 195 | pixels_key = 'front_close' 196 | # add wrappers 197 | env = ActionDTypeWrapper(env, np.float32) 198 | env = ActionRepeatWrapper(env, action_repeat) 199 | env = action_scale.Wrapper(env, minimum=-1.0, maximum=+1.0) 200 | # add renderings for clasical tasks 201 | if (domain, task) in suite.ALL_TASKS: 202 | # zoom in camera for quadruped 203 | camera_id = dict(quadruped=2).get(domain, 0) 204 | render_kwargs = dict(height=84, width=84, camera_id=camera_id) 205 | env = pixels.Wrapper(env, 206 | pixels_only=True, 207 | render_kwargs=render_kwargs) 208 | # stack several frames 209 | env = FrameStackWrapper(env, frame_stack, pixels_key) 210 | env = ExtendedTimeStepWrapper(env) 211 | return env 212 | -------------------------------------------------------------------------------- /src/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
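# Usage sketch (illustrative): metrics logged under a 'train' or 'eval' prefix
# are averaged by MetersGroup until dump() flushes them to csv/console (and to
# TensorBoard when use_tb=True):
#
#   from pathlib import Path
#   logger = Logger(Path('/tmp/exp'), use_tb=False)
#   logger.log('train/episode_reward', 123.4, step=1000)
#   logger.dump(1000, ty='train')  # appends a row to train.csv and prints it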
5 | import csv
6 | import datetime
7 | from collections import defaultdict
8 | 
9 | import numpy as np
10 | import torch
11 | import torchvision
12 | from termcolor import colored
13 | from torch.utils.tensorboard import SummaryWriter
14 | 
15 | COMMON_TRAIN_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'),
16 |                        ('episode', 'E', 'int'), ('episode_length', 'L', 'int'),
17 |                        ('episode_reward', 'R', 'float'),
18 |                        ('buffer_size', 'BS', 'int'), ('fps', 'FPS', 'float'),
19 |                        ('total_time', 'T', 'time')]
20 | 
21 | COMMON_EVAL_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'),
22 |                       ('episode', 'E', 'int'), ('episode_length', 'L', 'int'),
23 |                       ('episode_reward', 'R', 'float'),
24 |                       ('total_time', 'T', 'time')]
25 | 
26 | 
27 | class AverageMeter(object):
28 |     def __init__(self):
29 |         self._sum = 0
30 |         self._count = 0
31 | 
32 |     def update(self, value, n=1):
33 |         self._sum += value
34 |         self._count += n
35 | 
36 |     def value(self):
37 |         return self._sum / max(1, self._count)
38 | 
39 | 
40 | class MetersGroup(object):
41 |     def __init__(self, csv_file_name, formating):
42 |         self._csv_file_name = csv_file_name
43 |         self._formating = formating
44 |         self._meters = defaultdict(AverageMeter)
45 |         self._csv_file = None
46 |         self._csv_writer = None
47 | 
48 |     def log(self, key, value, n=1):
49 |         self._meters[key].update(value, n)
50 | 
51 |     def _prime_meters(self):
52 |         data = dict()
53 |         for key, meter in self._meters.items():
54 |             if key.startswith('train'):
55 |                 key = key[len('train') + 1:]
56 |             else:
57 |                 key = key[len('eval') + 1:]
58 |             key = key.replace('/', '_')
59 |             data[key] = meter.value()
60 |         return data
61 | 
62 |     def _remove_old_entries(self, data):
63 |         rows = []
64 |         with self._csv_file_name.open('r') as f:
65 |             reader = csv.DictReader(f)
66 |             for row in reader:
67 |                 if float(row['episode']) >= data['episode']:
68 |                     break
69 |                 rows.append(row)
70 |         with self._csv_file_name.open('w') as f:
71 |             writer = csv.DictWriter(f,
72 |                                     fieldnames=sorted(data.keys()),
73 |                                     restval=0.0)
74 |             writer.writeheader()
75 |             for row in rows:
76 |                 writer.writerow(row)
77 | 
78 |     def _dump_to_csv(self, data):
79 |         if self._csv_writer is None:
80 |             should_write_header = True
81 |             if self._csv_file_name.exists():
82 |                 self._remove_old_entries(data)
83 |                 should_write_header = False
84 | 
85 |             self._csv_file = self._csv_file_name.open('a')
86 |             self._csv_writer = csv.DictWriter(self._csv_file,
87 |                                               fieldnames=sorted(data.keys()),
88 |                                               restval=0.0)
89 |             if should_write_header:
90 |                 self._csv_writer.writeheader()
91 | 
92 |         self._csv_writer.writerow(data)
93 |         self._csv_file.flush()
94 | 
95 |     def _format(self, key, value, ty):
96 |         if ty == 'int':
97 |             value = int(value)
98 |             return f'{key}: {value}'
99 |         elif ty == 'float':
100 |             return f'{key}: {value:.04f}'
101 |         elif ty == 'time':
102 |             value = str(datetime.timedelta(seconds=int(value)))
103 |             return f'{key}: {value}'
104 |         else:
105 |             raise ValueError(f'invalid format type: {ty}')
106 | 
107 |     def _dump_to_console(self, data, prefix):
108 |         prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green')
109 |         pieces = [f'| {prefix: <14}']
110 |         for key, disp_key, ty in self._formating:
111 |             value = data.get(key, 0)
112 |             pieces.append(self._format(disp_key, value, ty))
113 |         print(' | '.join(pieces))
114 | 
115 |     def dump(self, step, prefix):
116 |         if len(self._meters) == 0:
117 |             return
118 |         data = self._prime_meters()
119 |         data['frame'] = step
120 |         self._dump_to_csv(data)
121 |         self._dump_to_console(data, prefix)
122 |         self._meters.clear()
123 | 
124 | 
125 | class
Logger(object): 126 | def __init__(self, log_dir, use_tb, stage2_logger=False): 127 | self._log_dir = log_dir 128 | if not stage2_logger: 129 | self._train_mg = MetersGroup(log_dir / 'train.csv', 130 | formating=COMMON_TRAIN_FORMAT) 131 | self._eval_mg = MetersGroup(log_dir / 'eval.csv', 132 | formating=COMMON_EVAL_FORMAT) 133 | else: 134 | self._train_mg = MetersGroup(log_dir / 'train_stage2.csv', 135 | formating=COMMON_TRAIN_FORMAT) 136 | self._eval_mg = MetersGroup(log_dir / 'eval_stage2.csv', 137 | formating=COMMON_EVAL_FORMAT) 138 | if use_tb: 139 | self._sw = SummaryWriter(str(log_dir / 'tb')) 140 | else: 141 | self._sw = None 142 | 143 | def _try_sw_log(self, key, value, step): 144 | if self._sw is not None: 145 | self._sw.add_scalar(key, value, step) 146 | 147 | def log(self, key, value, step): 148 | assert key.startswith('train') or key.startswith('eval') 149 | if type(value) == torch.Tensor: 150 | value = value.item() 151 | self._try_sw_log(key, value, step) 152 | mg = self._train_mg if key.startswith('train') else self._eval_mg 153 | mg.log(key, value) 154 | 155 | def log_metrics(self, metrics, step, ty): 156 | for key, value in metrics.items(): 157 | self.log(f'{ty}/{key}', value, step) 158 | 159 | def dump(self, step, ty=None): 160 | if ty is None or ty == 'eval': 161 | self._eval_mg.dump(step, 'eval') 162 | if ty is None or ty == 'train': 163 | self._train_mg.dump(step, 'train') 164 | 165 | def log_and_dump_ctx(self, step, ty): 166 | return LogAndDumpCtx(self, step, ty) 167 | 168 | 169 | class LogAndDumpCtx: 170 | def __init__(self, logger, step, ty): 171 | self._logger = logger 172 | self._step = step 173 | self._ty = ty 174 | 175 | def __enter__(self): 176 | return self 177 | 178 | def __call__(self, key, value): 179 | self._logger.log(f'{self._ty}/{key}', value, self._step) 180 | 181 | def __exit__(self, *args): 182 | self._logger.dump(self._step, self._ty) 183 | -------------------------------------------------------------------------------- /src/replay_buffer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
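# Usage sketch (illustrative): ReplayBufferStorage writes one .npz file per
# episode into replay_dir, and make_replay_loader() (defined at the bottom of
# this file) wraps the ReplayBuffer dataset in a torch DataLoader. Given a
# directory of saved episodes:
#
#   loader = make_replay_loader(replay_dir, max_size=1000000, batch_size=256,
#                               num_workers=4, save_snapshot=False,
#                               nstep=3, discount=0.99, is_adroit=True)
#   obs, action, reward, discount, next_obs, obs_sensor, obs_sensor_next = \
#       next(iter(loader))
#
# reward and discount come back n-step accumulated; with is_adroit=True each
# batch also carries the proprioceptive sensor observations.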
5 | import datetime 6 | import io 7 | import random 8 | import traceback 9 | from collections import defaultdict 10 | 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | from torch.utils.data import IterableDataset 15 | 16 | 17 | def episode_len(episode): 18 | # subtract -1 because the dummy first transition 19 | return next(iter(episode.values())).shape[0] - 1 20 | 21 | 22 | def save_episode(episode, fn): 23 | with io.BytesIO() as bs: 24 | np.savez_compressed(bs, **episode) 25 | bs.seek(0) 26 | with fn.open('wb') as f: 27 | f.write(bs.read()) 28 | 29 | 30 | def load_episode(fn): 31 | with fn.open('rb') as f: 32 | episode = np.load(f) 33 | episode = {k: episode[k] for k in episode.keys()} 34 | return episode 35 | 36 | 37 | class ReplayBufferStorage: 38 | def __init__(self, data_specs, replay_dir): 39 | self._data_specs = data_specs 40 | self._replay_dir = replay_dir 41 | replay_dir.mkdir(exist_ok=True) 42 | self._current_episode = defaultdict(list) 43 | self._preload() 44 | 45 | def __len__(self): 46 | return self._num_transitions 47 | 48 | def add(self, time_step): 49 | for spec in self._data_specs: 50 | value = time_step[spec.name] 51 | if np.isscalar(value): 52 | value = np.full(spec.shape, value, spec.dtype) 53 | # print(spec.name, spec.shape, spec.dtype, value.shape, value.dtype) 54 | assert spec.shape == value.shape and spec.dtype == value.dtype 55 | self._current_episode[spec.name].append(value) 56 | if time_step.last(): 57 | episode = dict() 58 | for spec in self._data_specs: 59 | value = self._current_episode[spec.name] 60 | episode[spec.name] = np.array(value, spec.dtype) 61 | self._current_episode = defaultdict(list) 62 | self._store_episode(episode) 63 | 64 | def _preload(self): 65 | self._num_episodes = 0 66 | self._num_transitions = 0 67 | for fn in self._replay_dir.glob('*.npz'): 68 | _, _, eps_len = fn.stem.split('_') 69 | self._num_episodes += 1 70 | self._num_transitions += int(eps_len) 71 | 72 | def _store_episode(self, episode): 73 | eps_idx = self._num_episodes 74 | eps_len = episode_len(episode) 75 | self._num_episodes += 1 76 | self._num_transitions += eps_len 77 | ts = datetime.datetime.now().strftime('%Y%m%dT%H%M%S') 78 | eps_fn = f'{ts}_{eps_idx}_{eps_len}.npz' 79 | save_episode(episode, self._replay_dir / eps_fn) 80 | 81 | 82 | class ReplayBuffer(IterableDataset): 83 | def __init__(self, replay_dir, max_size, num_workers, nstep, discount, 84 | fetch_every, save_snapshot, is_adroit=False, return_next_action=False): 85 | self._replay_dir = replay_dir 86 | self._size = 0 87 | self._max_size = max_size 88 | self._num_workers = max(1, num_workers) 89 | self._episode_fns = [] 90 | self._episodes = dict() 91 | self._nstep = nstep 92 | self._discount = discount 93 | self._fetch_every = fetch_every 94 | self._samples_since_last_fetch = fetch_every 95 | self._save_snapshot = save_snapshot 96 | self._is_adroit = is_adroit 97 | self._return_next_action = return_next_action 98 | 99 | def set_nstep(self, nstep): 100 | self._nstep = nstep 101 | 102 | def _sample_episode(self): 103 | eps_fn = random.choice(self._episode_fns) 104 | return self._episodes[eps_fn] 105 | 106 | def _store_episode(self, eps_fn): 107 | try: 108 | episode = load_episode(eps_fn) 109 | except: 110 | return False 111 | eps_len = episode_len(episode) 112 | while eps_len + self._size > self._max_size: 113 | early_eps_fn = self._episode_fns.pop(0) 114 | early_eps = self._episodes.pop(early_eps_fn) 115 | self._size -= episode_len(early_eps) 116 | early_eps_fn.unlink(missing_ok=True) 117 | 
self._episode_fns.append(eps_fn) 118 | self._episode_fns.sort() 119 | self._episodes[eps_fn] = episode 120 | self._size += eps_len 121 | 122 | if not self._save_snapshot: 123 | eps_fn.unlink(missing_ok=True) 124 | return True 125 | 126 | def _try_fetch(self): 127 | if self._samples_since_last_fetch < self._fetch_every: 128 | return 129 | self._samples_since_last_fetch = 0 130 | try: 131 | worker_id = torch.utils.data.get_worker_info().id 132 | except: 133 | worker_id = 0 134 | eps_fns = sorted(self._replay_dir.glob('*.npz'), reverse=True) 135 | fetched_size = 0 136 | for eps_fn in eps_fns: 137 | eps_idx, eps_len = [int(x) for x in eps_fn.stem.split('_')[1:]] 138 | if eps_idx % self._num_workers != worker_id: 139 | continue 140 | if eps_fn in self._episodes.keys(): 141 | break 142 | if fetched_size + eps_len > self._max_size: 143 | break 144 | fetched_size += eps_len 145 | if not self._store_episode(eps_fn): 146 | break 147 | 148 | def _sample(self): 149 | try: 150 | self._try_fetch() 151 | except: 152 | traceback.print_exc() 153 | self._samples_since_last_fetch += 1 154 | episode = self._sample_episode() 155 | # add +1 for the first dummy transition 156 | idx = np.random.randint(0, episode_len(episode) - self._nstep + 1) + 1 157 | obs = episode['observation'][idx - 1] 158 | action = episode['action'][idx] 159 | next_obs = episode['observation'][idx + self._nstep - 1] 160 | reward = np.zeros_like(episode['reward'][idx]) 161 | discount = np.ones_like(episode['discount'][idx]) 162 | for i in range(self._nstep): 163 | step_reward = episode['reward'][idx + i] 164 | reward += discount * step_reward 165 | discount *= episode['discount'][idx + i] * self._discount 166 | 167 | if self._return_next_action: 168 | next_action = episode['action'][idx + self._nstep - 1] 169 | 170 | if not self._is_adroit: 171 | if self._return_next_action: 172 | return (obs, action, reward, discount, next_obs, next_action) 173 | else: 174 | return (obs, action, reward, discount, next_obs) 175 | else: 176 | obs_sensor = episode['observation_sensor'][idx - 1] 177 | obs_sensor_next = episode['observation_sensor'][idx + self._nstep - 1] 178 | if self._return_next_action: 179 | return (obs, action, reward, discount, next_obs, obs_sensor, obs_sensor_next, next_action) 180 | else: 181 | return (obs, action, reward, discount, next_obs, obs_sensor, obs_sensor_next) 182 | 183 | def __iter__(self): 184 | while True: 185 | yield self._sample() 186 | 187 | 188 | def _worker_init_fn(worker_id): 189 | seed = np.random.get_state()[1][0] + worker_id 190 | np.random.seed(seed) 191 | random.seed(seed) 192 | 193 | 194 | def make_replay_loader(replay_dir, max_size, batch_size, num_workers, 195 | save_snapshot, nstep, discount, fetch_every=1000, is_adroit=False, return_next_action=False): 196 | max_size_per_worker = max_size // max(1, num_workers) 197 | 198 | iterable = ReplayBuffer(replay_dir, 199 | max_size_per_worker, 200 | num_workers, 201 | nstep, 202 | discount, 203 | fetch_every=fetch_every, 204 | save_snapshot=save_snapshot, 205 | is_adroit=is_adroit, 206 | return_next_action=return_next_action) 207 | 208 | loader = torch.utils.data.DataLoader(iterable, 209 | batch_size=batch_size, 210 | num_workers=num_workers, 211 | pin_memory=True, 212 | worker_init_fn=_worker_init_fn) 213 | return loader 214 | 215 | def reinit_data_loader(data_loader, batch_size, num_workers): 216 | # reinit a data loader with a new batch size 217 | loader = torch.utils.data.DataLoader(data_loader.dataset, 218 | batch_size=batch_size, 219 | 
num_workers=num_workers, 220 | pin_memory=True, 221 | worker_init_fn=_worker_init_fn) 222 | return loader -------------------------------------------------------------------------------- /src/rrl_local/rrl_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Rutav Shah, Indian Institute of Technlogy Kharagpur 2 | # Copyright (c) Facebook, Inc. and its affiliates 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models, transforms 8 | from torchvision.models import resnet34, resnet18 9 | from PIL import Image 10 | 11 | _encoders = {'resnet34' : resnet34, 'resnet18' : resnet18, } 12 | _transforms = { 13 | 'resnet34' : 14 | transforms.Compose([ 15 | transforms.Resize(256), 16 | transforms.CenterCrop(224), 17 | transforms.ToTensor(), 18 | transforms.Normalize([0.485, 0.456, 0.406], 19 | [0.229, 0.224, 0.225]) 20 | ]), 21 | 'resnet18' : 22 | transforms.Compose([ 23 | transforms.Resize(256), 24 | transforms.CenterCrop(224), 25 | transforms.ToTensor(), 26 | transforms.Normalize([0.485, 0.456, 0.406], 27 | [0.229, 0.224, 0.225]) 28 | ]), 29 | } 30 | 31 | class Encoder(nn.Module): 32 | def __init__(self, encoder_type): 33 | super(Encoder, self).__init__() 34 | self.encoder_type = encoder_type 35 | if self.encoder_type in _encoders : 36 | self.model = _encoders[encoder_type](pretrained=True) 37 | else : 38 | print("Please enter a valid encoder type") 39 | raise Exception 40 | for param in self.model.parameters(): 41 | param.requires_grad = False 42 | if self.encoder_type in _encoders : 43 | num_ftrs = self.model.fc.in_features 44 | self.num_ftrs = num_ftrs 45 | self.model.fc = Identity() # fc layer is replaced with identity 46 | 47 | def forward(self, x): 48 | x = self.model(x) 49 | return x 50 | 51 | # the transform is resize - center crop - normalize (imagenet normalize) No data aug here 52 | def get_transform(self): 53 | return _transforms[self.encoder_type] 54 | 55 | def get_features(self, x): 56 | with torch.no_grad(): 57 | z = self.model(x) 58 | return z.cpu().data.numpy() ### Can't store everything in GPU :/ 59 | 60 | class IdentityEncoder(nn.Module): 61 | def __init__(self): 62 | super(IdentityEncoder, self).__init__() 63 | 64 | def forward(self, x): 65 | return x 66 | 67 | def get_transform(self): 68 | return transforms.Compose([ 69 | transforms.ToTensor(), 70 | ]) 71 | 72 | def get_features(self, x): 73 | return x.reshape(-1) 74 | 75 | class Identity(nn.Module): 76 | def __init__(self): 77 | super(Identity, self).__init__() 78 | 79 | def forward(self, x): 80 | return x 81 | -------------------------------------------------------------------------------- /src/rrl_local/rrl_multicam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Rutav Shah, Indian Institute of Technlogy Kharagpur 2 | # Copyright (c) Facebook, Inc. 
and its affiliates 3 | 4 | import gym 5 | # from abc import ABC 6 | import numpy as np 7 | from rrl_local.rrl_encoder import Encoder, IdentityEncoder 8 | from PIL import Image 9 | import torch 10 | from collections import deque 11 | 12 | _mj_envs = {'pen-v0', 'hammer-v0', 'door-v0', 'relocate-v0'} 13 | 14 | def make_encoder(encoder, encoder_type, device, is_eval=True) : 15 | if not encoder : 16 | if encoder_type == 'resnet34' or encoder_type == 'resnet18' : 17 | encoder = Encoder(encoder_type) 18 | elif encoder_type == 'identity' : 19 | encoder = IdentityEncoder() 20 | else : 21 | print("Please enter valid encoder_type.") 22 | raise Exception 23 | if is_eval: 24 | encoder.eval() 25 | encoder.to(device) 26 | return encoder 27 | 28 | class BasicAdroitEnv(gym.Env): # , ABC 29 | def __init__(self, env, cameras, latent_dim=512, hybrid_state=True, channels_first=False, 30 | height=84, width=84, test_image=False, num_repeats=1, num_frames=1, encoder_type=None, device=None): 31 | self._env = env 32 | self.env_id = env.env.unwrapped.spec.id 33 | self.device = device 34 | 35 | self._num_repeats = num_repeats 36 | self._num_frames = num_frames 37 | self._frames = deque([], maxlen=num_frames) 38 | 39 | self.encoder = None 40 | self.transforms = None 41 | self.encoder_type = encoder_type 42 | if encoder_type is not None: 43 | self.encoder = make_encoder(encoder=None, encoder_type=self.encoder_type, device=self.device, is_eval=True) 44 | self.transforms = self.encoder.get_transform() 45 | 46 | if test_image: 47 | print("======================adroit image test mode==============================") 48 | print("======================adroit image test mode==============================") 49 | print("======================adroit image test mode==============================") 50 | print("======================adroit image test mode==============================") 51 | self.test_image = test_image 52 | 53 | self.cameras = cameras 54 | self.latent_dim = latent_dim 55 | self.hybrid_state = hybrid_state 56 | self.channels_first = channels_first 57 | self.height = height 58 | self.width = width 59 | self.action_space = self._env.action_space 60 | self.env_kwargs = {'cameras' : cameras, 'latent_dim' : latent_dim, 'hybrid_state': hybrid_state, 61 | 'channels_first' : channels_first, 'height' : height, 'width' : width} 62 | 63 | shape = [3, self.width, self.height] 64 | self._observation_space = gym.spaces.Box( 65 | low=0, high=255, shape=shape, dtype=np.uint8 66 | ) 67 | self.sim = env.env.sim 68 | self._env.spec.observation_dim = latent_dim 69 | 70 | if hybrid_state : 71 | if self.env_id in _mj_envs: 72 | self._env.spec.observation_dim += 24 # Assuming 24 states for adroit hand. 
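                # e.g. in the default pixel setting latent_dim is
                # 84*84*len(cameras)*3 (see make_basic_env in rrl_utils.py), so
                # observation_dim is the image feature size plus these 24
                # proprioceptive dims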
73 | 74 | self.spec = self._env.spec 75 | self.observation_dim = self.spec.observation_dim 76 | self.horizon = self._env.env.spec.max_episode_steps 77 | 78 | def get_obs(self,): 79 | # for our case, let's output the image, and then also the sensor features 80 | if self.env_id in _mj_envs : 81 | env_state = self._env.env.get_env_state() 82 | qp = env_state['qpos'] 83 | 84 | if self.env_id == 'pen-v0': 85 | qp = qp[:-6] 86 | elif self.env_id == 'door-v0': 87 | qp = qp[4:-2] 88 | elif self.env_id == 'hammer-v0': 89 | qp = qp[2:-7] 90 | elif self.env_id == 'relocate-v0': 91 | qp = qp[6:-6] 92 | 93 | imgs = [] # number of image is number of camera 94 | 95 | if self.encoder is not None: 96 | for cam in self.cameras : 97 | img = self._env.env.sim.render(width=self.width, height=self.height, mode='offscreen', camera_name=cam, device_id=0) 98 | # img = env.env.sim.render(width=84, height=84, mode='offscreen') 99 | img = img[::-1, :, : ] # Image given has to be flipped 100 | if self.channels_first : 101 | img = img.transpose((2, 0, 1)) 102 | #img = img.astype(np.uint8) 103 | img = Image.fromarray(img) 104 | img = self.transforms(img) 105 | imgs.append(img) 106 | 107 | inp_img = torch.stack(imgs).to(self.device) # [num_cam, C, H, W] 108 | z = self.encoder.get_features(inp_img).reshape(-1) 109 | # assert z.shape[0] == self.latent_dim, "Encoded feature length : {}, Expected : {}".format(z.shape[0], self.latent_dim) 110 | pixels = z 111 | else: 112 | if not self.test_image: 113 | for cam in self.cameras : # for each camera, render once 114 | img = self._env.env.sim.render(width=self.width, height=self.height, mode='offscreen', camera_name=cam, device_id=0) # TODO device id will think later 115 | # img = img[::-1, :, : ] # Image given has to be flipped 116 | if self.channels_first : 117 | img = img.transpose((2, 0, 1)) # then it's 3 x width x height 118 | # we should do channels first... (not sure why by default it's not, maybe they did some transpose when using the encoder?) 119 | #img = img.astype(np.uint8) 120 | # img = Image.fromarray(img) # TODO is this necessary? 121 | imgs.append(img) 122 | else: 123 | img = (np.random.rand(1, 84, 84) * 255).astype(np.uint8) 124 | imgs.append(img) 125 | pixels = np.concatenate(imgs, axis=0) 126 | 127 | # TODO below are what we originally had... 128 | # if not self.test_image: 129 | # for cam in self.cameras : # for each camera, render once 130 | # img = self._env.env.sim.render(width=self.width, height=self.height, mode='offscreen', camera_name=cam, device_id=0) # TODO device id will think later 131 | # # img = img[::-1, :, : ] # Image given has to be flipped 132 | # if self.channels_first : 133 | # img = img.transpose((2, 0, 1)) # then it's 3 x width x height 134 | # # we should do channels first... (not sure why by default it's not, maybe they did some transpose when using the encoder?) 135 | # #img = img.astype(np.uint8) 136 | # # img = Image.fromarray(img) # TODO is this necessary? 137 | # imgs.append(img) 138 | # else: 139 | # img = (np.random.rand(1, 84, 84) * 255).astype(np.uint8) 140 | # imgs.append(img) 141 | # pixels = np.concatenate(imgs, axis=0) 142 | 143 | if not self.hybrid_state : # this defaults to True... 
so RRL uses hybrid state
144 |             qp = None
145 | 
146 |         sensor_info = qp
147 |         return pixels, sensor_info
148 | 
149 |     def get_env_infos(self):
150 |         return self._env.get_env_infos()
151 |     def set_seed(self, seed):
152 |         return self._env.set_seed(seed)
153 | 
154 |     def get_stacked_pixels(self): # TODO fix it
155 |         assert len(self._frames) == self._num_frames
156 |         stacked_pixels = np.concatenate(list(self._frames), axis=0)
157 |         return stacked_pixels
158 | 
159 |     def reset(self):
160 |         self._env.reset()
161 |         pixels, sensor_info = self.get_obs()
162 |         for _ in range(self._num_frames):
163 |             self._frames.append(pixels)
164 |         stacked_pixels = self.get_stacked_pixels()
165 |         return stacked_pixels, sensor_info
166 | 
167 |     def get_obs_for_first_state_but_without_reset(self):
168 |         pixels, sensor_info = self.get_obs()
169 |         for _ in range(self._num_frames):
170 |             self._frames.append(pixels)
171 |         stacked_pixels = self.get_stacked_pixels()
172 |         return stacked_pixels, sensor_info
173 | 
174 |     def step(self, action):
175 |         reward_sum = 0.0
176 |         discount_prod = 1.0 # TODO pen can terminate early
177 |         n_goal_achieved = 0
178 |         for i_action in range(self._num_repeats):
179 |             obs, reward, done, env_info = self._env.step(action)
180 |             reward_sum += reward
181 |             if env_info['goal_achieved']:
182 |                 n_goal_achieved += 1
183 |             if done:
184 |                 break
185 |         env_info['n_goal_achieved'] = n_goal_achieved
186 |         # now get stacked frames
187 |         pixels, sensor_info = self.get_obs()
188 |         self._frames.append(pixels)
189 |         stacked_pixels = self.get_stacked_pixels()
190 |         return [stacked_pixels, sensor_info], reward_sum, done, env_info
191 | 
192 |     def set_env_state(self, state):
193 |         return self._env.set_env_state(state)
194 |     def get_env_state(self):
195 |         return self._env.get_env_state()
196 | 
197 |     def evaluate_policy(self, policy,
198 |                         num_episodes=5,
199 |                         horizon=None,
200 |                         gamma=1,
201 |                         visual=False,
202 |                         percentile=[],
203 |                         get_full_dist=False,
204 |                         mean_action=False,
205 |                         init_env_state=None,
206 |                         terminate_at_done=True,
207 |                         seed=123):
208 |         # TODO this needs to be rewritten
209 | 
210 |         self.set_seed(seed)
211 |         horizon = self.horizon if horizon is None else horizon
212 |         mean_eval, std, min_eval, max_eval = 0.0, 0.0, -1e8, -1e8
213 |         ep_returns = np.zeros(num_episodes)
214 |         if self.encoder is not None: self.encoder.eval()  # encoder is None in raw-pixel mode
215 | 
216 |         for ep in range(num_episodes):
217 |             o = self.reset()
218 |             if init_env_state is not None:
219 |                 self.set_env_state(init_env_state)
220 |             t, done = 0, False
221 |             while t < horizon and (not done or not terminate_at_done):
222 |                 if visual: self.render()
223 |                 o = self.get_obs()  # get_obs takes no arguments and returns (pixels, sensor_info)
224 |                 a = policy.get_action(o)[1]['evaluation'] if mean_action else policy.get_action(o)[0]
225 |                 o, r, done, _ = self.step(a)
226 |                 ep_returns[ep] += (gamma ** t) * r
227 |                 t += 1
228 | 
229 |         mean_eval, std = np.mean(ep_returns), np.std(ep_returns)
230 |         min_eval, max_eval = np.amin(ep_returns), np.amax(ep_returns)
231 |         base_stats = [mean_eval, std, min_eval, max_eval]
232 | 
233 |         percentile_stats = []
234 |         for p in percentile:
235 |             percentile_stats.append(np.percentile(ep_returns, p))
236 | 
237 |         full_dist = ep_returns if get_full_dist else None
238 | 
239 |         return [base_stats, percentile_stats, full_dist]
240 | 
241 |     def get_pixels_with_width_height(self, w, h):
242 |         imgs = []  # one image per camera
243 | 
244 |         for cam in self.cameras: # for each camera, render once
245 |             img = self._env.env.sim.render(width=w, height=h,
mode='offscreen', camera_name=cam, device_id=0) # TODO device id will think later 246 | # img = img[::-1, :, : ] # Image given has to be flipped 247 | if self.channels_first : 248 | img = img.transpose((2, 0, 1)) # then it's 3 x width x height 249 | # we should do channels first... (not sure why by default it's not, maybe they did some transpose when using the encoder?) 250 | #img = img.astype(np.uint8) 251 | # img = Image.fromarray(img) # TODO is this necessary? 252 | imgs.append(img) 253 | 254 | pixels = np.concatenate(imgs, axis=0) 255 | return pixels 256 | 257 | 258 | class BasicFrankaEnv(gym.Env): 259 | def __init__(self, env, cameras, latent_dim=512, hybrid_state=True, channels_first=False, 260 | height=84, width=84, test_image=False, num_repeats=1, num_frames=1, encoder_type=None, device=None): 261 | # the parameter env is basically the kitchen env now 262 | self._env = env 263 | self.env_id = env.env.unwrapped.spec.id 264 | self.device = device 265 | 266 | self._num_repeats = num_repeats 267 | self._num_frames = num_frames 268 | self._frames = deque([], maxlen=num_frames) 269 | 270 | self.encoder = None 271 | self.transforms = None 272 | self.encoder_type = encoder_type 273 | if encoder_type is not None: 274 | self.encoder = make_encoder(encoder=None, encoder_type=self.encoder_type, device=self.device, is_eval=True) 275 | self.transforms = self.encoder.get_transform() 276 | 277 | if test_image: 278 | print("======================adroit image test mode==============================") 279 | print("======================adroit image test mode==============================") 280 | print("======================adroit image test mode==============================") 281 | print("======================adroit image test mode==============================") 282 | self.test_image = test_image 283 | 284 | self.cameras = cameras 285 | self.latent_dim = latent_dim 286 | self.hybrid_state = hybrid_state 287 | self.channels_first = channels_first 288 | self.height = height 289 | self.width = width 290 | self.action_space = self._env.action_space 291 | self.env_kwargs = {'cameras' : cameras, 'latent_dim' : latent_dim, 'hybrid_state': hybrid_state, 292 | 'channels_first' : channels_first, 'height' : height, 'width' : width} 293 | 294 | shape = [3, self.width, self.height] 295 | self._observation_space = gym.spaces.Box( 296 | low=0, high=255, shape=shape, dtype=np.uint8 297 | ) 298 | self.sim = env.env.sim 299 | self._env.spec.observation_dim = latent_dim 300 | self._env.spec.action_dim = 9 # TODO magic number 301 | self._env.spec.horizon = self._env.env.spec.max_episode_steps 302 | # print("==============") 303 | # print(dir(self._env)) 304 | # print("high: ", self._env.action_space) 305 | # quit() 306 | 307 | if hybrid_state : 308 | if self.env_id in _mj_envs: 309 | self._env.spec.observation_dim += 24 # Assuming 24 states for adroit hand. 
310 | 311 | self.spec = self._env.spec 312 | self.observation_dim = self.spec.observation_dim 313 | self.horizon = self._env.env.spec.max_episode_steps 314 | 315 | def get_obs(self,): 316 | # for our case, let's output the image, and then also the sensor features 317 | imgs = [] # number of image is number of camera 318 | 319 | if self.encoder is not None: 320 | for cam in self.cameras : 321 | # img = self._env.env.sim.render(width=self.width, height=self.height, mode='offscreen', camera_name=cam, device_id=0) 322 | img = self._env.env.sim.render(width=84, height=84) 323 | # img = env.env.sim.render(width=84, height=84, mode='offscreen') 324 | img = img[::-1, :, : ] # Image given has to be flipped 325 | if self.channels_first : 326 | img = img.transpose((2, 0, 1)) 327 | #img = img.astype(np.uint8) 328 | img = Image.fromarray(img) 329 | img = self.transforms(img) 330 | imgs.append(img) 331 | 332 | inp_img = torch.stack(imgs).to(self.device) # [num_cam, C, H, W] 333 | z = self.encoder.get_features(inp_img).reshape(-1) 334 | # assert z.shape[0] == self.latent_dim, "Encoded feature length : {}, Expected : {}".format(z.shape[0], self.latent_dim) 335 | pixels = z 336 | else: 337 | if not self.test_image: 338 | for cam in self.cameras : # for each camera, render once 339 | img = self._env.env.sim.render(width=84, height=84) 340 | # img = self._env.env.sim.render(width=self.width, height=self.height, mode='offscreen', camera_name=cam, device_id=0) # TODO device id will think later 341 | # img = img[::-1, :, : ] # Image given has to be flipped 342 | if self.channels_first : 343 | img = img.transpose((2, 0, 1)) # then it's 3 x width x height 344 | # we should do channels first... (not sure why by default it's not, maybe they did some transpose when using the encoder?) 345 | #img = img.astype(np.uint8) 346 | # img = Image.fromarray(img) # TODO is this necessary? 347 | imgs.append(img) 348 | else: 349 | img = (np.random.rand(1, 84, 84) * 255).astype(np.uint8) 350 | imgs.append(img) 351 | pixels = np.concatenate(imgs, axis=0) 352 | 353 | # TODO below are what we originally had... 354 | # if not self.test_image: 355 | # for cam in self.cameras : # for each camera, render once 356 | # img = self._env.env.sim.render(width=self.width, height=self.height, mode='offscreen', camera_name=cam, device_id=0) # TODO device id will think later 357 | # # img = img[::-1, :, : ] # Image given has to be flipped 358 | # if self.channels_first : 359 | # img = img.transpose((2, 0, 1)) # then it's 3 x width x height 360 | # # we should do channels first... (not sure why by default it's not, maybe they did some transpose when using the encoder?) 361 | # #img = img.astype(np.uint8) 362 | # # img = Image.fromarray(img) # TODO is this necessary? 363 | # imgs.append(img) 364 | # else: 365 | # img = (np.random.rand(1, 84, 84) * 255).astype(np.uint8) 366 | # imgs.append(img) 367 | # pixels = np.concatenate(imgs, axis=0) 368 | 369 | if not self.hybrid_state : # this defaults to True... 
so RRL uses hybrid state
370 |             qp = None
371 |         else: qp = None  # this wrapper never reads a proprioceptive state from the env, so sensor_info stays None
372 |         sensor_info = qp
373 |         return pixels, sensor_info
374 | 
375 |     def get_env_infos(self):
376 |         return self._env.get_env_infos()
377 |     def set_seed(self, seed):
378 |         return self._env.set_seed(seed)
379 | 
380 |     def get_stacked_pixels(self): # TODO fix it
381 |         assert len(self._frames) == self._num_frames
382 |         stacked_pixels = np.concatenate(list(self._frames), axis=0)
383 |         return stacked_pixels
384 | 
385 |     def reset(self):
386 |         self._env.reset()
387 |         pixels, sensor_info = self.get_obs()
388 |         for _ in range(self._num_frames):
389 |             self._frames.append(pixels)
390 |         stacked_pixels = self.get_stacked_pixels()
391 |         return stacked_pixels, sensor_info
392 | 
393 |     def get_obs_for_first_state_but_without_reset(self):
394 |         pixels, sensor_info = self.get_obs()
395 |         for _ in range(self._num_frames):
396 |             self._frames.append(pixels)
397 |         stacked_pixels = self.get_stacked_pixels()
398 |         return stacked_pixels, sensor_info
399 | 
400 |     def step(self, action):
401 |         reward_sum = 0.0
402 |         discount_prod = 1.0 # TODO pen can terminate early
403 |         n_goal_achieved = 0
404 |         for i_action in range(self._num_repeats):
405 |             obs, reward, done, env_info = self._env.step(action)
406 |             reward_sum += reward
407 |             if env_info['goal_achieved']:
408 |                 n_goal_achieved += 1
409 |             if done:
410 |                 break
411 |         env_info['n_goal_achieved'] = n_goal_achieved
412 |         # now get stacked frames
413 |         pixels, sensor_info = self.get_obs()
414 |         self._frames.append(pixels)
415 |         stacked_pixels = self.get_stacked_pixels()
416 |         return [stacked_pixels, sensor_info], reward_sum, done, env_info
417 | 
418 |     def set_env_state(self, state):
419 |         return self._env.set_env_state(state)
420 |     def get_env_state(self):
421 |         return self._env.get_env_state()
422 | 
423 |     def evaluate_policy(self, policy,
424 |                         num_episodes=5,
425 |                         horizon=None,
426 |                         gamma=1,
427 |                         visual=False,
428 |                         percentile=[],
429 |                         get_full_dist=False,
430 |                         mean_action=False,
431 |                         init_env_state=None,
432 |                         terminate_at_done=True,
433 |                         seed=123):
434 |         # TODO this needs to be rewritten
435 | 
436 |         self.set_seed(seed)
437 |         horizon = self.horizon if horizon is None else horizon
438 |         mean_eval, std, min_eval, max_eval = 0.0, 0.0, -1e8, -1e8
439 |         ep_returns = np.zeros(num_episodes)
440 |         if self.encoder is not None: self.encoder.eval()  # encoder is None in raw-pixel mode
441 | 
442 |         for ep in range(num_episodes):
443 |             o = self.reset()
444 |             if init_env_state is not None:
445 |                 self.set_env_state(init_env_state)
446 |             t, done = 0, False
447 |             while t < horizon and (not done or not terminate_at_done):
448 |                 if visual: self.render()
449 |                 o = self.get_obs()  # get_obs takes no arguments and returns (pixels, sensor_info)
450 |                 a = policy.get_action(o)[1]['evaluation'] if mean_action else policy.get_action(o)[0]
451 |                 o, r, done, _ = self.step(a)
452 |                 ep_returns[ep] += (gamma ** t) * r
453 |                 t += 1
454 | 
455 |         mean_eval, std = np.mean(ep_returns), np.std(ep_returns)
456 |         min_eval, max_eval = np.amin(ep_returns), np.amax(ep_returns)
457 |         base_stats = [mean_eval, std, min_eval, max_eval]
458 | 
459 |         percentile_stats = []
460 |         for p in percentile:
461 |             percentile_stats.append(np.percentile(ep_returns, p))
462 | 
463 |         full_dist = ep_returns if get_full_dist else None
464 | 
465 |         return [base_stats, percentile_stats, full_dist]
466 | 
467 |     def get_pixels_with_width_height(self, w, h):
468 |         imgs = []  # one image per camera
469 | 
470 |         for cam in self.cameras: # for each camera, render once
471 |             img = self._env.env.sim.render(width=w, height=h,
mode='offscreen', camera_name=cam, device_id=0) # TODO device id will think later 472 | # img = img[::-1, :, : ] # Image given has to be flipped 473 | if self.channels_first : 474 | img = img.transpose((2, 0, 1)) # then it's 3 x width x height 475 | # we should do channels first... (not sure why by default it's not, maybe they did some transpose when using the encoder?) 476 | #img = img.astype(np.uint8) 477 | # img = Image.fromarray(img) # TODO is this necessary? 478 | imgs.append(img) 479 | 480 | pixels = np.concatenate(imgs, axis=0) 481 | return pixels 482 | -------------------------------------------------------------------------------- /src/rrl_local/rrl_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Rutav Shah, Indian Institute of Technlogy Kharagpur 2 | # Copyright (c) Facebook, Inc. and its affiliates 3 | 4 | from mjrl.utils.gym_env import GymEnv 5 | from rrl_local.rrl_multicam import BasicAdroitEnv 6 | import os 7 | 8 | def make_basic_env(env, cam_list=[], from_pixels=False, hybrid_state=None, test_image=False, channels_first=False, 9 | num_repeats=1, num_frames=1): 10 | e = GymEnv(env) 11 | env_kwargs = None 12 | if from_pixels : # TODO here they first check if it's from pixels... if pixel and not resnet then use 84x84?? 13 | height = 84 14 | width = 84 15 | latent_dim = height*width*len(cam_list)*3 16 | # RRL class instance is environment wrapper... 17 | e = BasicAdroitEnv(e, cameras=cam_list, 18 | height=height, width=width, latent_dim=latent_dim, hybrid_state=hybrid_state, 19 | test_image=test_image, channels_first=channels_first, num_repeats=num_repeats, num_frames=num_frames) 20 | env_kwargs = {'rrl_kwargs' : e.env_kwargs} 21 | # if not from pixels... then it's simpler 22 | return e, env_kwargs 23 | 24 | def make_env(env, cam_list=[], from_pixels=False, encoder_type=None, hybrid_state=None,) : 25 | # TODO why is encoder type put into the environment? 26 | e = GymEnv(env) 27 | env_kwargs = None 28 | if from_pixels : # TODO here they first check if it's from pixels... if pixel and not resnet then use 84x84?? 29 | height = 84 30 | width = 84 31 | latent_dim = height*width*len(cam_list)*3 32 | 33 | if encoder_type and encoder_type == 'resnet34': 34 | assert from_pixels==True 35 | height = 256 36 | width = 256 37 | latent_dim = 512*len(cam_list) #TODO each camera provides an image? 38 | if from_pixels: 39 | # RRL class instance is environment wrapper... 
40 | e = BasicAdroitEnv(e, cameras=cam_list, encoder_type=encoder_type, 41 | height=height, width=width, latent_dim=latent_dim, hybrid_state=hybrid_state) 42 | env_kwargs = {'rrl_kwargs' : e.env_kwargs} 43 | return e, env_kwargs 44 | 45 | def make_dir(dir_path): 46 | if not os.path.exists(dir_path): 47 | os.makedirs(dir_path) 48 | 49 | 50 | def preprocess_args(args): 51 | job_data = {} 52 | job_data['seed'] = args.seed 53 | job_data['env'] = args.env 54 | job_data['output'] = args.output 55 | job_data['from_pixels'] = args.from_pixels 56 | job_data['hybrid_state'] = args.hybrid_state 57 | job_data['stack_frames'] = args.stack_frames 58 | job_data['encoder_type'] = args.encoder_type 59 | job_data['cam1'] = args.cam1 60 | job_data['cam2'] = args.cam2 61 | job_data['cam3'] = args.cam3 62 | job_data['algorithm'] = args.algorithm 63 | job_data['num_cpu'] = args.num_cpu 64 | job_data['save_freq'] = args.save_freq 65 | job_data['eval_rollouts'] = args.eval_rollouts 66 | job_data['demo_file'] = args.demo_file 67 | job_data['bc_batch_size'] = args.bc_batch_size 68 | job_data['bc_epochs'] = args.bc_epochs 69 | job_data['bc_learn_rate'] = args.bc_learn_rate 70 | #job_data['policy_size'] = args.policy_size 71 | job_data['policy_size'] = tuple(map(int, args.policy_size.split(', '))) 72 | job_data['vf_batch_size'] = args.vf_batch_size 73 | job_data['vf_epochs'] = args.vf_epochs 74 | job_data['vf_learn_rate'] = args.vf_learn_rate 75 | job_data['rl_step_size'] = args.rl_step_size 76 | job_data['rl_gamma'] = args.rl_gamma 77 | job_data['rl_gae'] = args.rl_gae 78 | job_data['rl_num_traj'] = args.rl_num_traj 79 | job_data['rl_num_iter'] = args.rl_num_iter 80 | job_data['lam_0'] = args.lam_0 81 | job_data['lam_1'] = args.lam_1 82 | print(job_data) 83 | return job_data 84 | 85 | 86 | -------------------------------------------------------------------------------- /src/stage1_models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
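# Shape sketch (illustrative): with the 84x84 inputs used in this repo, conv3x3
# below preserves spatial size at stride=1 (padding equals dilation), while
# conv1x1 projects channels and can downsample through its stride:
#
#   x = torch.zeros(1, 32, 84, 84)
#   conv3x3(32, 64)(x).shape            # (1, 64, 84, 84)
#   conv1x1(32, 64, stride=2)(x).shape  # (1, 64, 42, 42)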
3 | 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | 8 | """ 9 | most of the code here is modified from torchvision.models.resnet 10 | """ 11 | 12 | 13 | 14 | 15 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 16 | """3x3 convolution with padding""" 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 18 | padding=dilation, groups=groups, bias=False, dilation=dilation) 19 | 20 | def conv1x1(in_planes, out_planes, stride=1): 21 | """1x1 convolution""" 22 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 28 | base_width=64, dilation=1, norm_layer=None): 29 | super(BasicBlock, self).__init__() 30 | if norm_layer is None: 31 | norm_layer = nn.BatchNorm2d 32 | if groups != 1 or base_width != 64: 33 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 34 | if dilation > 1: 35 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 36 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 37 | self.conv1 = conv3x3(inplanes, planes, stride) 38 | self.bn1 = norm_layer(planes) 39 | self.relu = nn.ReLU(inplace=True) 40 | self.conv2 = conv3x3(planes, planes) 41 | self.bn2 = norm_layer(planes) 42 | self.downsample = downsample 43 | self.stride = stride 44 | 45 | def forward(self, x): 46 | identity = x 47 | 48 | out = self.conv1(x) 49 | out = self.bn1(out) 50 | out = self.relu(out) 51 | 52 | out = self.conv2(out) 53 | out = self.bn2(out) 54 | 55 | if self.downsample is not None: 56 | identity = self.downsample(x) 57 | 58 | out += identity 59 | out = self.relu(out) 60 | 61 | return out 62 | 63 | class OneLayerBlock(nn.Module): 64 | # similar to BasicBlock, but shallower 65 | expansion = 1 66 | 67 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 68 | base_width=64, dilation=1, norm_layer=None): 69 | super(OneLayerBlock, self).__init__() 70 | if norm_layer is None: 71 | norm_layer = nn.BatchNorm2d 72 | if groups != 1 or base_width != 64: 73 | raise ValueError('OneLayerBlock only supports groups=1 and base_width=64') 74 | if dilation > 1: 75 | raise NotImplementedError("Dilation > 1 not supported in OneLayerBlock") 76 | # self.conv1 downsamples the input when stride != 1 (this block has no downsample branch) 77 | self.conv1 = conv3x3(inplanes, planes, stride) 78 | self.bn1 = norm_layer(planes) 79 | self.relu = nn.ReLU(inplace=True) 80 | self.stride = stride 81 | 82 | def forward(self, x): 83 | out = self.conv1(x) 84 | out = self.bn1(out) 85 | out = self.relu(out) 86 | return out 87 | 88 | class Bottleneck(nn.Module): 89 | # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2), 90 | # while the original implementation places the stride at the first 1x1 convolution (self.conv1), 91 | # according to "Deep residual learning for image recognition" (https://arxiv.org/abs/1512.03385). 92 | # This variant is also known as ResNet V1.5 and improves accuracy according to 93 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
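# Editor's note: a hedged shape walk-through of this block (assuming inplanes=64,
# planes=64, stride=2, groups=1, base_width=64, so width=64 and planes*expansion=256):
#   conv1 1x1:           (B, 64, H, W)     -> (B, 64, H, W)
#   conv2 3x3, stride 2: (B, 64, H, W)     -> (B, 64, H/2, W/2)   <- the V1.5 stride placement
#   conv3 1x1:           (B, 64, H/2, W/2) -> (B, 256, H/2, W/2)
#   the downsample branch brings the identity to (B, 256, H/2, W/2) before the residual add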
94 | 95 | expansion = 4 96 | 97 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 98 | base_width=64, dilation=1, norm_layer=None): 99 | super(Bottleneck, self).__init__() 100 | if norm_layer is None: 101 | norm_layer = nn.BatchNorm2d 102 | width = int(planes * (base_width / 64.)) * groups 103 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 104 | self.conv1 = conv1x1(inplanes, width) 105 | self.bn1 = norm_layer(width) 106 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 107 | self.bn2 = norm_layer(width) 108 | self.conv3 = conv1x1(width, planes * self.expansion) 109 | self.bn3 = norm_layer(planes * self.expansion) 110 | self.relu = nn.ReLU(inplace=True) 111 | self.downsample = downsample 112 | self.stride = stride 113 | 114 | def forward(self, x): 115 | identity = x 116 | 117 | out = self.conv1(x) 118 | out = self.bn1(out) 119 | out = self.relu(out) 120 | 121 | out = self.conv2(out) 122 | out = self.bn2(out) 123 | out = self.relu(out) 124 | 125 | out = self.conv3(out) 126 | out = self.bn3(out) 127 | 128 | if self.downsample is not None: 129 | identity = self.downsample(x) 130 | 131 | out += identity 132 | out = self.relu(out) 133 | 134 | return out 135 | 136 | def drq_weight_init(m): 137 | # weight init scheme used in DrQv2 138 | if isinstance(m, nn.Linear): 139 | nn.init.orthogonal_(m.weight.data) 140 | if hasattr(m.bias, 'data'): 141 | m.bias.data.fill_(0.0) 142 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 143 | gain = nn.init.calculate_gain('relu') 144 | nn.init.orthogonal_(m.weight.data, gain) 145 | if hasattr(m.bias, 'data'): 146 | m.bias.data.fill_(0.0) 147 | 148 | class Stage3ShallowEncoder(nn.Module): 149 | """ 150 | this is the encoder architecture used in DrQv2 151 | """ 152 | def __init__(self, obs_shape, n_channel): 153 | super().__init__() 154 | assert len(obs_shape) == 3 155 | self.repr_dim = n_channel * 35 * 35 156 | self.conv1 = nn.Conv2d(obs_shape[0], n_channel, 3, stride=2) 157 | self.conv2 = nn.Conv2d(n_channel, n_channel, 3, stride=1) 158 | self.conv3 = nn.Conv2d(n_channel, n_channel, 3, stride=1) 159 | self.conv4 = nn.Conv2d(n_channel, n_channel, 3, stride=1) 160 | self.relu = nn.ReLU(inplace=True) 161 | self.apply(drq_weight_init) 162 | 163 | def _forward_impl(self, x): 164 | x = self.relu(self.conv1(x)) 165 | x = self.relu(self.conv2(x)) 166 | x = self.relu(self.conv3(x)) 167 | x = self.relu(self.conv4(x)) 168 | return x 169 | 170 | def forward(self, obs): 171 | o = obs 172 | h = self._forward_impl(o) 173 | h = h.view(h.shape[0], -1) 174 | return h 175 | 176 | class ResNet84(nn.Module): 177 | """ 178 | default stage 1 encoder used by VRL3, this is modified from the PyTorch standard ResNet class 179 | but is more lightweight and this is much faster with 84x84 input size 180 | use "layers" to specify how deep the network is 181 | use "start_num_channel" to control how wide it is 182 | """ 183 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 184 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 185 | norm_layer=None, start_num_channel=32): 186 | super(ResNet84, self).__init__() 187 | if norm_layer is None: 188 | norm_layer = nn.BatchNorm2d 189 | self._norm_layer = norm_layer 190 | 191 | self.start_num_channel = start_num_channel 192 | self.inplanes = start_num_channel 193 | self.dilation = 1 194 | if replace_stride_with_dilation is None: 195 | # each element in the tuple indicates if we should replace 196 | # the 2x2 
stride with a dilated convolution instead 197 | replace_stride_with_dilation = [False, False, False] 198 | if len(replace_stride_with_dilation) != 3: 199 | raise ValueError("replace_stride_with_dilation should be None " 200 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 201 | self.groups = groups 202 | self.base_width = width_per_group 203 | self.conv1 = nn.Conv2d(3, self.inplanes, 3, stride=2) 204 | self.bn1 = norm_layer(self.inplanes) 205 | self.relu = nn.ReLU(inplace=True) 206 | # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 207 | 208 | self.layer1 = self._make_layer(block, start_num_channel, layers[0]) 209 | self.layer2 = self._make_layer(block, start_num_channel * 2, layers[1], stride=2, 210 | dilate=replace_stride_with_dilation[0]) 211 | self.layer3 = self._make_layer(block, start_num_channel * 4, layers[2], stride=2, 212 | dilate=replace_stride_with_dilation[1]) 213 | self.layer4 = self._make_layer(block, start_num_channel * 8, layers[3], stride=2, 214 | dilate=replace_stride_with_dilation[2]) 215 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 216 | self.fc = nn.Linear(start_num_channel * 8 * block.expansion, num_classes) 217 | 218 | for m in self.modules(): 219 | if isinstance(m, nn.Conv2d): 220 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 221 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 222 | nn.init.constant_(m.weight, 1) 223 | nn.init.constant_(m.bias, 0) 224 | 225 | # Zero-initialize the last BN in each residual branch, 226 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 227 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 228 | if zero_init_residual: 229 | for m in self.modules(): 230 | if isinstance(m, Bottleneck): 231 | nn.init.constant_(m.bn3.weight, 0) 232 | elif isinstance(m, BasicBlock): 233 | nn.init.constant_(m.bn2.weight, 0) 234 | 235 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 236 | # vrl3: if block is 0, allows a smaller network size 237 | if blocks == 0: 238 | block = OneLayerBlock 239 | 240 | norm_layer = self._norm_layer 241 | downsample = None 242 | previous_dilation = self.dilation 243 | if dilate: 244 | self.dilation *= stride 245 | stride = 1 246 | if stride != 1 or self.inplanes != planes * block.expansion: 247 | downsample = nn.Sequential( 248 | conv1x1(self.inplanes, planes * block.expansion, stride), 249 | norm_layer(planes * block.expansion), 250 | ) 251 | 252 | layers = [] 253 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 254 | self.base_width, previous_dilation, norm_layer)) 255 | self.inplanes = planes * block.expansion 256 | for _ in range(1, blocks): 257 | layers.append(block(self.inplanes, planes, groups=self.groups, 258 | base_width=self.base_width, dilation=self.dilation, 259 | norm_layer=norm_layer)) 260 | 261 | return nn.Sequential(*layers) 262 | 263 | def _forward_impl(self, x): 264 | x = self.conv1(x) 265 | x = self.bn1(x) 266 | x = self.relu(x) 267 | 268 | x = self.layer1(x) 269 | x = self.layer2(x) 270 | x = self.layer3(x) 271 | x = self.layer4(x) 272 | x = self.avgpool(x) 273 | x = torch.flatten(x, 1) 274 | x = self.fc(x) 275 | 276 | return x 277 | 278 | def get_feature_size(self): 279 | assert self.start_num_channel % 32 == 0 280 | multiplier = self.start_num_channel // 32 281 | size = 256 * multiplier 282 | return size 283 | 284 | def forward(self, x): 285 | return self._forward_impl(x) 286 | 287 | def 
get_features(self, x): 288 | x = self.conv1(x) 289 | # print("0", x.shape) # 32 x 41 x 41 = 53792 290 | x = self.bn1(x) 291 | x = self.relu(x) 292 | 293 | x = self.layer1(x) 294 | # print("1", x.shape) # 32 x 41 x 41 = 53792 295 | 296 | x = self.layer2(x) 297 | # print("2", x.shape) # 64 x 21 x 21 = 28224 298 | 299 | x = self.layer3(x) 300 | # print("3", x.shape) # 128 x 11 x 11 = 15488 301 | 302 | x = self.layer4(x) 303 | # print("4", x.shape) # 256 x 6 x 6 = 9216 304 | 305 | x = self.avgpool(x) 306 | # print("pool", x.shape) # 256 x 1 x 1 307 | 308 | final_out = torch.flatten(x, 1) 309 | # print("flatten", x.shape) # 256 310 | return final_out 311 | 312 | class Identity(nn.Module): 313 | def __init__(self, input_placeholder=None): 314 | super(Identity, self).__init__() 315 | 316 | def forward(self, x): 317 | return x 318 | 319 | -------------------------------------------------------------------------------- /src/testing/computation_time_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import os 5 | # make sure mujoco and nvidia will be found 6 | os.environ['LD_LIBRARY_PATH'] = os.environ.get('LD_LIBRARY_PATH', '') + \ 7 | ':/workspace/.mujoco/mujoco210/bin:/usr/local/nvidia/lib:/usr/lib/nvidia' 8 | os.environ['MUJOCO_PY_MUJOCO_PATH'] = '/workspace/.mujoco/mujoco210/' 9 | os.environ['MUJOCO_GL'] = 'egl' 10 | # set to glfw if trying to render locally with a monitor 11 | # os.environ['MUJOCO_GL'] = 'glfw' 12 | 13 | from mjrl.utils.gym_env import GymEnv 14 | import mujoco_py 15 | import mj_envs 16 | import time 17 | 18 | e = GymEnv('hammer-v0') 19 | 20 | # when the program first initializes, the first few renders will be a bit slower, 21 | # so we first run 1000 renders without recording their time 22 | for i in range(1000): 23 | img = e.env.sim.render(width=84, height=84, mode='offscreen', camera_name='top') 24 | 25 | # now measure how much time 1000 renders take 26 | st = time.time() 27 | for i in range(1000): 28 | img = e.env.sim.render(width=84, height=84, mode='offscreen', camera_name='top') 29 | et = time.time() 30 | 31 | # when mujoco uses the gpu to render (i.e. when things work correctly), this should be fast, 32 | # taking < 0.5 seconds 33 | # if your GPU is not working correctly, it might be 10x slower 34 | # if it is slow, then something is wrong: either the program can't see your gpu, 35 | # or your gpu is available but the old mujoco failed to use it 36 | time_used = et-st 37 | if time_used < 0.5: 38 | print("time used: %.3f seconds, looks correct!" % time_used) 39 | else: 40 | print("time used: %.3f seconds." % time_used) 41 | print("WARNING!!!!! SLOWER THAN EXPECTED, YOUR MUJOCO RENDERING MIGHT NOT WORK CORRECTLY!") 42 | 43 | print("Depends on hardware but typically should be < 0.5 seconds.") 44 | 45 | 46 | -------------------------------------------------------------------------------- /src/testing/zzz.py: -------------------------------------------------------------------------------- 1 | txt = "For only {} {:.0f} {:.3f} dollars!" 2 | print(txt.format(1, 49.5555, 22)) -------------------------------------------------------------------------------- /src/train_adroit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License.
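# Editor's note: this script is driven by Hydra (see the @hydra.main decorator at
# the bottom: config_path='cfgs_adroit', config_name='config'). A hypothetical
# launch command, hedged because the exact override keys are defined in that yaml
# config rather than shown here, might look like:
#
#     python train_adroit.py task=door seed=0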
3 | 4 | import os, psutil 5 | # make sure mujoco and nvidia will be found 6 | os.environ['LD_LIBRARY_PATH'] = os.environ.get('LD_LIBRARY_PATH', '') + \ 7 | ':/workspace/.mujoco/mujoco210/bin:/usr/local/nvidia/lib:/usr/lib/nvidia' 8 | os.environ['MUJOCO_PY_MUJOCO_PATH'] = '/workspace/.mujoco/mujoco210/' 9 | import numpy as np 10 | import shutil 11 | import warnings 12 | warnings.filterwarnings('ignore', category=DeprecationWarning) 13 | os.environ['MKL_SERVICE_FORCE_INTEL'] = '1' 14 | import platform 15 | os.environ['MUJOCO_GL'] = 'egl' 16 | # set to glfw if trying to render locally with a monitor 17 | # os.environ['MUJOCO_GL'] = 'glfw' 18 | os.environ['EGL_DEVICE_ID'] = '0' 19 | from distutils.dir_util import copy_tree 20 | import torchvision.transforms as transforms 21 | from pathlib import Path 22 | import cv2 23 | import imageio 24 | 25 | import hydra 26 | import torch 27 | from dm_env import StepType, TimeStep, specs 28 | 29 | import utils 30 | from logger import Logger 31 | from replay_buffer import ReplayBufferStorage, make_replay_loader 32 | from video import TrainVideoRecorder, VideoRecorder 33 | import joblib 34 | import pickle 35 | import time 36 | 37 | torch.backends.cudnn.benchmark = True 38 | 39 | # TODO this part can be done during workspace setup 40 | ENV_TYPE = 'adroit' 41 | if ENV_TYPE == 'adroit': 42 | import mj_envs 43 | from mjrl.utils.gym_env import GymEnv 44 | from rrl_local.rrl_utils import make_basic_env, make_dir 45 | from adroit import AdroitEnv 46 | else: 47 | import dmc 48 | IS_ADROIT = True if ENV_TYPE == 'adroit' else False 49 | 50 | def get_ram_info(): 51 | # get info on RAM usage 52 | d = dict(psutil.virtual_memory()._asdict()) 53 | for key in d: 54 | if key != "percent": # convert to GB 55 | d[key] = int(d[key] / 1024**3 * 100)/ 100 56 | return d 57 | 58 | def make_agent(obs_spec, action_spec, cfg): 59 | cfg.obs_shape = obs_spec.shape 60 | cfg.action_shape = action_spec.shape 61 | return hydra.utils.instantiate(cfg) 62 | 63 | def print_stage2_time_est(time_used, curr_n_update, total_n_update): 64 | time_per_update = time_used / curr_n_update 65 | est_total_time = time_per_update * total_n_update 66 | est_time_remaining = est_total_time - time_used 67 | print("Stage 2 [{:.2f}%]. Frames:[{:.0f}/{:.0f}]K. Time:[{:.2f}/{:.2f}]hrs. Overall FPS: {}.".format( 68 | curr_n_update / total_n_update * 100, curr_n_update/1000, total_n_update/1000, 69 | time_used / 3600, est_total_time / 3600, int(curr_n_update / time_used))) 70 | 71 | def print_stage3_time_est(time_used, curr_n_frames, total_n_frames): 72 | time_per_update = time_used / curr_n_frames 73 | est_total_time = time_per_update * total_n_frames 74 | est_time_remaining = est_total_time - time_used 75 | print("Stage 3 [{:.2f}%]. Frames:[{:.0f}/{:.0f}]K. Time:[{:.2f}/{:.2f}]hrs. 
Overall FPS: {}.".format( 76 | curr_n_frames / total_n_frames * 100, curr_n_frames/1000, total_n_frames/1000, 77 | time_used / 3600, est_total_time / 3600, int(curr_n_frames / time_used))) 78 | 79 | class Workspace: 80 | def __init__(self, cfg): 81 | self.work_dir = Path.cwd() 82 | print("\n=== Training log stored to: ===") 83 | print(f'workspace: {self.work_dir}') 84 | self.direct_folder_name = os.path.basename(self.work_dir) 85 | 86 | self.cfg = cfg 87 | utils.set_seed_everywhere(cfg.seed) 88 | self.device = torch.device(cfg.device) 89 | self.replay_buffer_fetch_every = 1000 90 | if self.cfg.debug > 0: # if debug mode, then change hyperparameters for quick testing 91 | self.set_debug_hyperparameters() 92 | self.setup() 93 | 94 | self.agent = make_agent(self.train_env.observation_spec(), 95 | self.train_env.action_spec(), 96 | self.cfg.agent) 97 | self.timer = utils.Timer() 98 | self._global_step = 0 99 | self._global_episode = 0 100 | 101 | def set_debug_hyperparameters(self): 102 | self.cfg.num_seed_frames=1000 if self.cfg.num_seed_frames > 1000 else self.cfg.num_seed_frames 103 | self.cfg.agent.num_expl_steps=500 if self.cfg.agent.num_expl_steps > 1000 else self.cfg.agent.num_expl_steps 104 | if self.cfg.replay_buffer_num_workers > 1: 105 | self.cfg.replay_buffer_num_workers = 1 106 | self.cfg.num_eval_episodes = 1 107 | self.cfg.replay_buffer_size = 30000 108 | self.cfg.batch_size = 8 109 | self.cfg.feature_dim = 8 110 | self.cfg.num_train_frames = 5050 111 | self.replay_buffer_fetch_every = 30 112 | self.cfg.stage2_n_update = 100 113 | self.cfg.num_demo = 3 114 | self.cfg.eval_every_frames = 3000 115 | self.cfg.agent.hidden_dim = 8 116 | self.cfg.agent.num_expl_steps = 500 117 | self.cfg.stage2_eval_every_frames = 50 118 | 119 | def setup(self): 120 | warnings.filterwarnings('ignore', category=DeprecationWarning) 121 | 122 | if self.cfg.save_models: 123 | assert self.cfg.action_repeat % 2 == 0 124 | 125 | # create logger 126 | self.logger = Logger(self.work_dir, use_tb=self.cfg.use_tb) 127 | env_name = self.cfg.task_name 128 | env_type = 'adroit' if env_name in ('hammer-v0','door-v0','pen-v0','relocate-v0') else 'dmc' 129 | # assert env_name in ('hammer-v0','door-v0','pen-v0','relocate-v0',) 130 | 131 | if self.cfg.agent.encoder_lr_scale == 'auto': 132 | if env_name == 'relocate-v0': 133 | self.cfg.agent.encoder_lr_scale = 0.01 134 | else: 135 | self.cfg.agent.encoder_lr_scale = 1 136 | 137 | self.env_feature_type = self.cfg.env_feature_type 138 | if env_type == 'adroit': 139 | # reward rescale can either be added in the env or in the agent code when reward is used 140 | self.train_env = AdroitEnv(env_name, test_image=False, num_repeats=self.cfg.action_repeat, 141 | num_frames=self.cfg.frame_stack, env_feature_type=self.env_feature_type, 142 | device=self.device, reward_rescale=self.cfg.reward_rescale) 143 | self.eval_env = AdroitEnv(env_name, test_image=False, num_repeats=self.cfg.action_repeat, 144 | num_frames=self.cfg.frame_stack, env_feature_type=self.env_feature_type, 145 | device=self.device, reward_rescale=self.cfg.reward_rescale) 146 | 147 | data_specs = (self.train_env.observation_spec(), 148 | self.train_env.observation_sensor_spec(), 149 | self.train_env.action_spec(), 150 | specs.Array((1,), np.float32, 'reward'), 151 | specs.Array((1,), np.float32, 'discount'), 152 | specs.Array((1,), np.int8, 'n_goal_achieved'), 153 | specs.Array((1,), np.float32, 'time_limit_reached'), 154 | ) 155 | else: 156 | self.train_env = dmc.make(self.cfg.task_name, self.cfg.frame_stack, 157 | 
self.cfg.action_repeat, self.cfg.seed) 158 | self.eval_env = dmc.make(self.cfg.task_name, self.cfg.frame_stack, 159 | self.cfg.action_repeat, self.cfg.seed) 160 | 161 | data_specs = (self.train_env.observation_spec(), 162 | self.train_env.action_spec(), 163 | specs.Array((1,), np.float32, 'reward'), 164 | specs.Array((1,), np.float32, 'discount')) 165 | 166 | # create replay buffer 167 | self.replay_storage = ReplayBufferStorage(data_specs, self.work_dir / 'buffer') 168 | 169 | 170 | 171 | self.replay_loader = make_replay_loader( 172 | self.work_dir / 'buffer', self.cfg.replay_buffer_size, 173 | self.cfg.batch_size, self.cfg.replay_buffer_num_workers, 174 | self.cfg.save_snapshot, self.cfg.nstep, self.cfg.discount, self.replay_buffer_fetch_every, 175 | is_adroit=IS_ADROIT) 176 | self._replay_iter = None 177 | 178 | self.video_recorder = VideoRecorder( 179 | self.work_dir if self.cfg.save_video else None) 180 | self.train_video_recorder = TrainVideoRecorder( 181 | self.work_dir if self.cfg.save_train_video else None) 182 | 183 | def set_demo_buffer_nstep(self, nstep): 184 | self.replay_loader_demo.dataset._nstep = nstep 185 | 186 | @property 187 | def global_step(self): 188 | return self._global_step 189 | 190 | @property 191 | def global_episode(self): 192 | return self._global_episode 193 | 194 | @property 195 | def global_frame(self): 196 | return self.global_step * self.cfg.action_repeat 197 | 198 | @property 199 | def replay_iter(self): 200 | if self._replay_iter is None: 201 | self._replay_iter = iter(self.replay_loader) 202 | return self._replay_iter 203 | 204 | def eval_dmc(self): 205 | step, episode, total_reward = 0, 0, 0 206 | eval_until_episode = utils.Until(self.cfg.num_eval_episodes) 207 | 208 | while eval_until_episode(episode): 209 | time_step = self.eval_env.reset() 210 | self.video_recorder.init(self.eval_env, enabled=(episode == 0)) 211 | while not time_step.last(): 212 | with torch.no_grad(), utils.eval_mode(self.agent): 213 | action = self.agent.act(time_step.observation, 214 | self.global_step, 215 | eval_mode=True) 216 | time_step = self.eval_env.step(action) 217 | self.video_recorder.record(self.eval_env) 218 | total_reward += time_step.reward 219 | step += 1 220 | 221 | episode += 1 222 | self.video_recorder.save(f'{self.global_frame}.mp4') 223 | 224 | with self.logger.log_and_dump_ctx(self.global_frame, ty='eval') as log: 225 | log('episode_reward', total_reward / episode) 226 | log('episode_length', step * self.cfg.action_repeat / episode) 227 | log('episode', self.global_episode) 228 | log('step', self.global_step) 229 | 230 | def eval_adroit(self, force_number_episodes=None, do_log=True): 231 | if ENV_TYPE != 'adroit': 232 | return self.eval_dmc() 233 | 234 | step, episode, total_reward = 0, 0, 0 235 | n_eval_episode = force_number_episodes if force_number_episodes is not None else self.cfg.num_eval_episodes 236 | eval_until_episode = utils.Until(n_eval_episode) 237 | total_success = 0.0 238 | while eval_until_episode(episode): 239 | n_goal_achieved_total = 0 240 | time_step = self.eval_env.reset() 241 | self.video_recorder.init(self.eval_env, enabled=(episode == 0)) 242 | while not time_step.last(): 243 | with torch.no_grad(), utils.eval_mode(self.agent): 244 | observation = time_step.observation 245 | action = self.agent.act(observation, 246 | self.global_step, 247 | eval_mode=True, 248 | obs_sensor=time_step.observation_sensor) 249 | time_step = self.eval_env.step(action) 250 | n_goal_achieved_total += time_step.n_goal_achieved 251 | 
self.video_recorder.record(self.eval_env) 252 | total_reward += time_step.reward 253 | step += 1 254 | 255 | # here check if success for Adroit tasks. The threshold values come from the mj_envs code 256 | # e.g. https://github.com/ShahRutav/mj_envs/blob/5ee75c6e294dda47983eb4c60b6dd8f23a3f9aec/mj_envs/hand_manipulation_suite/pen_v0.py 257 | # can also use the evaluate_success function from Adroit envs, but can be more complicated 258 | if self.cfg.task_name == 'pen-v0': 259 | threshold = 20 260 | else: 261 | threshold = 25 262 | if n_goal_achieved_total > threshold: 263 | total_success += 1 264 | 265 | episode += 1 266 | self.video_recorder.save(f'{self.global_frame}.mp4') 267 | success_rate_standard = total_success / n_eval_episode 268 | episode_reward_standard = total_reward / episode 269 | episode_length_standard = step * self.cfg.action_repeat / episode 270 | 271 | if do_log: 272 | with self.logger.log_and_dump_ctx(self.global_frame, ty='eval') as log: 273 | log('episode_reward', episode_reward_standard) 274 | log('success_rate', success_rate_standard) 275 | log('episode_length', episode_length_standard) 276 | log('episode', self.global_episode) 277 | log('step', self.global_step) 278 | 279 | return episode_reward_standard, success_rate_standard 280 | 281 | def get_data_folder_path(self): 282 | amlt_data_path = os.getenv('AMLT_DATA_DIR', '') 283 | if amlt_data_path != '': # if on cluster 284 | return amlt_data_path 285 | else: # if not on cluster, return the data folder path specified in cfg 286 | return self.cfg.local_data_dir 287 | 288 | def get_demo_path(self, env_name): 289 | # given a environment name (with the -v0 part), return the path to its demo file 290 | data_folder_path = self.get_data_folder_path() 291 | demo_folder_path = os.path.join(data_folder_path, "demonstrations") 292 | demo_path = os.path.join(demo_folder_path, env_name + "_demos.pickle") 293 | return demo_path 294 | 295 | def load_demo(self, replay_storage, env_name, verbose=True): 296 | # will load demo data and put them into a replay storage 297 | demo_path = self.get_demo_path(env_name) 298 | if verbose: 299 | print("Trying to get demo data from:", demo_path) 300 | 301 | # get the raw state demo data, a list of length 25 302 | demo_data = pickle.load(open(demo_path, 'rb')) 303 | if self.cfg.num_demo >= 0: 304 | demo_data = demo_data[:self.cfg.num_demo] 305 | 306 | """ 307 | the adroit demo data is in raw state, so we need to convert them into image data 308 | we init an env and run episodes with the stored actions in the demo data 309 | then put the image and sensor data into the replay buffer, also need to clip actions here 310 | this part is basically following the RRL code 311 | """ 312 | demo_env = AdroitEnv(env_name, test_image=False, num_repeats=1, num_frames=self.cfg.frame_stack, 313 | env_feature_type=self.env_feature_type, device=self.device, reward_rescale=self.cfg.reward_rescale) 314 | demo_env.reset() 315 | 316 | total_data_count = 0 317 | for i_path in range(len(demo_data)): 318 | path = demo_data[i_path] 319 | demo_env.reset() 320 | demo_env.set_env_state(path['init_state_dict']) 321 | time_step = demo_env.get_current_obs_without_reset() 322 | replay_storage.add(time_step) 323 | 324 | ep_reward = 0 325 | ep_n_goal = 0 326 | for i_act in range(len(path['actions'])): 327 | total_data_count += 1 328 | action = path['actions'][i_act] 329 | action = action.astype(np.float32) 330 | # when action is put into the environment, they will be clipped. 
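# (the two lines below are equivalent to: action = np.clip(action, -1.0, 1.0))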
331 | action[action > 1] = 1 332 | action[action < -1] = -1 333 | 334 | # when they collect the demo data, they actually did not use a timelimit... 335 | if i_act == len(path['actions']) - 1: 336 | force_step_type = 'last' 337 | else: 338 | force_step_type = 'mid' 339 | 340 | time_step = demo_env.step(action, force_step_type=force_step_type) 341 | replay_storage.add(time_step) 342 | 343 | reward = time_step.reward 344 | ep_reward += reward 345 | 346 | goal_achieved = time_step.n_goal_achieved 347 | ep_n_goal += goal_achieved 348 | if verbose: 349 | print('demo trajectory %d, len: %d, return: %.2f, goal achieved steps: %d' % 350 | (i_path, len(path['actions']), ep_reward, ep_n_goal)) 351 | if verbose: 352 | print("Demo data load finished, total data count:", total_data_count) 353 | 354 | def get_pretrained_model_path(self, stage1_model_name): 355 | # given a stage1 model name, return the path to the pretrained model 356 | data_folder_path = self.get_data_folder_path() 357 | model_folder_path = os.path.join(data_folder_path, "trained_models") 358 | model_path = os.path.join(model_folder_path, stage1_model_name + '_checkpoint.pth.tar') 359 | return model_path 360 | 361 | def train(self): 362 | train_start_time = time.time() 363 | print("\n=== Training started! ===") 364 | """=================================== LOAD PRETRAINED MODEL ===================================""" 365 | if self.cfg.stage1_use_pretrain: 366 | self.agent.load_pretrained_encoder(self.get_pretrained_model_path(self.cfg.stage1_model_name)) 367 | self.agent.switch_to_RL_stages() 368 | 369 | """========================================= LOAD DATA =========================================""" 370 | if self.cfg.load_demo: 371 | self.load_demo(self.replay_storage, self.cfg.task_name) 372 | print("Model and data loading finished in %.2f hours." % ((time.time()-train_start_time) / 3600)) 373 | 374 | """========================================== STAGE 2 ==========================================""" 375 | print("\n=== Stage 2 started ===") 376 | stage2_start_time = time.time() 377 | stage2_n_update = self.cfg.stage2_n_update 378 | if stage2_n_update > 0: 379 | for i_stage2 in range(stage2_n_update): 380 | metrics = self.agent.update(self.replay_iter, i_stage2, stage=2, use_sensor=IS_ADROIT) 381 | if i_stage2 % self.cfg.stage2_eval_every_frames == 0: 382 | average_score, succ_rate = self.eval_adroit(force_number_episodes=self.cfg.stage2_num_eval_episodes, 383 | do_log=False) 384 | print('Stage 2 step %d, Q(s,a): %.2f, Q loss: %.2f, score: %.2f, succ rate: %.2f' % 385 | (i_stage2, metrics['critic_q1'], metrics['critic_loss'], average_score, succ_rate)) 386 | if self.cfg.show_computation_time_est and i_stage2 > 0 and i_stage2 % self.cfg.show_time_est_interval == 0: 387 | print_stage2_time_est(time.time()-stage2_start_time, i_stage2+1, stage2_n_update) 388 | print("Stage 2 finished in %.2f hours." 
% ((time.time()-stage2_start_time) / 3600)) 389 | 390 | """========================================== STAGE 3 ==========================================""" 391 | print("\n=== Stage 3 started ===") 392 | stage3_start_time = time.time() 393 | # predicates 394 | train_until_step = utils.Until(self.cfg.num_train_frames, self.cfg.action_repeat) 395 | seed_until_step = utils.Until(self.cfg.num_seed_frames, self.cfg.action_repeat) 396 | eval_every_step = utils.Every(self.cfg.eval_every_frames, self.cfg.action_repeat) 397 | 398 | episode_step, episode_reward = 0, 0 399 | time_step = self.train_env.reset() 400 | self.replay_storage.add(time_step) 401 | self.train_video_recorder.init(time_step.observation) 402 | metrics = None 403 | 404 | episode_step_since_log, episode_reward_list, episode_frame_list = 0, [0], [0] 405 | self.timer.reset() 406 | while train_until_step(self.global_step): 407 | # if 1000 steps passed, do some logging 408 | if self.global_step % 1000 == 0 and metrics is not None: 409 | elapsed_time, total_time = self.timer.reset() 410 | episode_frame_since_log = episode_step_since_log * self.cfg.action_repeat 411 | with self.logger.log_and_dump_ctx(self.global_frame, ty='train') as log: 412 | log('fps', episode_frame_since_log / elapsed_time) 413 | log('total_time', total_time) 414 | log('episode_reward', np.mean(episode_reward_list)) 415 | log('episode_length', np.mean(episode_frame_list)) 416 | log('episode', self.global_episode) 417 | log('buffer_size', len(self.replay_storage)) 418 | log('step', self.global_step) 419 | episode_step_since_log, episode_reward_list, episode_frame_list = 0, [0], [0] 420 | if self.cfg.show_computation_time_est and self.global_step > 0 and self.global_step % self.cfg.show_time_est_interval == 0: 421 | print_stage3_time_est(time.time() - stage3_start_time, self.global_frame + 1, self.cfg.num_train_frames) 422 | 423 | # if reached end of episode 424 | if time_step.last(): 425 | self._global_episode += 1 426 | self.train_video_recorder.save(f'{self.global_frame}.mp4') 427 | # wait until all the metrics schema is populated 428 | if metrics is not None: 429 | # log stats 430 | episode_step_since_log += episode_step 431 | episode_reward_list.append(episode_reward) 432 | episode_frame = episode_step * self.cfg.action_repeat 433 | episode_frame_list.append(episode_frame) 434 | 435 | # reset env 436 | time_step = self.train_env.reset() 437 | self.replay_storage.add(time_step) 438 | self.train_video_recorder.init(time_step.observation) 439 | # try to save snapshot 440 | if self.cfg.save_snapshot: 441 | self.save_snapshot() 442 | episode_step, episode_reward = 0, 0 443 | 444 | # try to evaluate 445 | if eval_every_step(self.global_step): 446 | self.logger.log('eval_total_time', self.timer.total_time(), self.global_frame) 447 | self.eval_adroit() 448 | 449 | # sample action 450 | if IS_ADROIT: 451 | obs_sensor = time_step.observation_sensor 452 | else: 453 | obs_sensor = None 454 | with torch.no_grad(), utils.eval_mode(self.agent): 455 | action = self.agent.act(time_step.observation, 456 | self.global_step, 457 | eval_mode=False, 458 | obs_sensor=obs_sensor) 459 | 460 | # update the agent 461 | if not seed_until_step(self.global_step): 462 | metrics = self.agent.update(self.replay_iter, self.global_step, stage=3, use_sensor=IS_ADROIT) 463 | self.logger.log_metrics(metrics, self.global_frame, ty='train') 464 | 465 | # take env step 466 | time_step = self.train_env.step(action) 467 | episode_reward += time_step.reward 468 | self.replay_storage.add(time_step) 469 | 
self.train_video_recorder.record(time_step.observation) 470 | episode_step += 1 471 | self._global_step += 1 472 | """here we move the experiment files to azure blob""" 473 | if (self.global_step==1) or (self.global_step == 10000) or (self.global_step % 100000 == 0): 474 | try: 475 | self.copy_to_azure() 476 | except Exception as e: 477 | print(e) 478 | 479 | """here save model for later""" 480 | if self.cfg.save_models: 481 | if self.global_frame in (2, 100000, 500000, 1000000, 2000000, 4000000): 482 | self.save_snapshot(suffix=str(self.global_frame)) 483 | 484 | try: 485 | self.copy_to_azure() 486 | except Exception as e: 487 | print(e) 488 | print("Stage 3 finished in %.2f hours." % ((time.time()-stage3_start_time) / 3600)) 489 | print("All stages finished in %.2f hrs. Work dir:" % ((time.time()-train_start_time)/3600)) 490 | print(self.work_dir) 491 | 492 | def save_snapshot(self, suffix=None): 493 | if suffix is None: 494 | save_name = 'snapshot.pt' 495 | else: 496 | save_name = 'snapshot' + suffix + '.pt' 497 | snapshot = self.work_dir / save_name 498 | keys_to_save = ['agent', 'timer', '_global_step', '_global_episode'] 499 | payload = {k: self.__dict__[k] for k in keys_to_save} 500 | with snapshot.open('wb') as f: 501 | torch.save(payload, f) 502 | print("snapshot saved to:", str(snapshot)) 503 | 504 | def load_snapshot(self): 505 | snapshot = self.work_dir / 'snapshot.pt' 506 | with snapshot.open('rb') as f: 507 | payload = torch.load(f) 508 | for k, v in payload.items(): 509 | self.__dict__[k] = v 510 | 511 | def copy_to_azure(self): 512 | amlt_path = os.getenv('AMLT_OUTPUT_DIR', '') 513 | if amlt_path != '': # if on cluster 514 | container_log_path = self.work_dir 515 | amlt_path_to = os.path.join(amlt_path, self.direct_folder_name) 516 | copy_tree(str(container_log_path), amlt_path_to, update=1) 517 | # copytree(str(container_log_path), amlt_path_to, dirs_exist_ok=True, ignore=ignore_patterns('*.npy')) 518 | print("Data copied to:", amlt_path_to) 519 | # else: # if at local 520 | # container_log_path = self.work_dir 521 | # amlt_path_to = '/vrl3data/logs' 522 | # copy_tree(str(container_log_path), amlt_path_to, update=1) 523 | 524 | @hydra.main(config_path='cfgs_adroit', config_name='config') 525 | def main(cfg): 526 | # TODO potentially check the task name and decide which libs to load here? 527 | W = Workspace 528 | root_dir = Path.cwd() 529 | workspace = W(cfg) 530 | snapshot = root_dir / 'snapshot.pt' 531 | if snapshot.exists(): 532 | print(f'resuming: {snapshot}') 533 | workspace.load_snapshot() 534 | workspace.train() 535 | 536 | 537 | if __name__ == '__main__': 538 | main() -------------------------------------------------------------------------------- /src/train_stage1.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
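# Editor's note: a hypothetical launch command (hedged; the flags are the ones
# defined by the argparse setup below, but the dataset path is a placeholder):
#
#     python train_stage1.py /path/to/imagenet -a resnet10_32channel -b 256 --epochs 90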
3 | 4 | # this file is modified from the pytorch official tutorial 5 | # NOTE: stage 1 training code is currently being cleaned up 6 | 7 | import argparse 8 | import os 9 | import random 10 | import shutil 11 | import time 12 | import warnings 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.parallel 17 | import torch.backends.cudnn as cudnn 18 | import torch.distributed as dist 19 | import torch.optim 20 | import torch.multiprocessing as mp 21 | import torch.utils.data 22 | import torch.utils.data.distributed 23 | import torchvision.transforms as transforms 24 | import torchvision.datasets as datasets 25 | import torchvision.models as models 26 | 27 | # TODO: use a config file to specify the training data location and where to save models; 28 | # also add an option to just test the accuracy of saved models 29 | # (we can probably test this locally) 30 | 31 | from stage1_models import BasicBlock, ResNet84 32 | 33 | rl_model_names = ['resnet6_32channel', 'resnet10_32channel', 'resnet18_32channel', 34 | 'resnet6_64channel', 'resnet10_64channel', 'resnet18_64channel',] 35 | model_names = sorted(name for name in models.__dict__ 36 | if name.islower() and not name.startswith("__") 37 | and callable(models.__dict__[name])) + rl_model_names 38 | 39 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 40 | parser.add_argument('data', metavar='DIR', 41 | help='path to dataset') 42 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet10_32channel', 43 | choices=model_names, 44 | help='model architecture: ' + 45 | ' | '.join(model_names) + 46 | ' (default: resnet10_32channel)') 47 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 48 | help='number of data loading workers (default: 4)') 49 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 50 | help='number of total epochs to run') 51 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 52 | help='manual epoch number (useful on restarts)') 53 | parser.add_argument('-b', '--batch-size', default=256, type=int, 54 | metavar='N', 55 | help='mini-batch size (default: 256), this is the total ' 56 | 'batch size of all GPUs on the current node when ' 57 | 'using Data Parallel or Distributed Data Parallel') 58 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 59 | metavar='LR', help='initial learning rate', dest='lr') 60 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 61 | help='momentum') 62 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 63 | metavar='W', help='weight decay (default: 1e-4)', 64 | dest='weight_decay') 65 | parser.add_argument('-p', '--print-freq', default=10, type=int, 66 | metavar='N', help='print frequency (default: 10)') 67 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 68 | help='path to latest checkpoint (default: none)') 69 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 70 | help='evaluate model on validation set') 71 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 72 | help='use pre-trained model') 73 | parser.add_argument('--world-size', default=-1, type=int, 74 | help='number of nodes for distributed training') 75 | parser.add_argument('--rank', default=-1, type=int, 76 | help='node rank for distributed training') 77 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 78 | help='url used to set up
distributed training') 79 | parser.add_argument('--dist-backend', default='nccl', type=str, 80 | help='distributed backend') 81 | parser.add_argument('--seed', default=None, type=int, 82 | help='seed for initializing training. ') 83 | parser.add_argument('--gpu', default=None, type=int, 84 | help='GPU id to use.') 85 | parser.add_argument('--multiprocessing-distributed', action='store_true', 86 | help='Use multi-processing distributed training to launch ' 87 | 'N processes per node, which has N GPUs. This is the ' 88 | 'fastest way to use PyTorch for either single node or ' 89 | 'multi node data parallel training') 90 | parser.add_argument('--debug', default=0, type=int, 91 | help='1 for debug mode, 2 for super fast debug mode') 92 | 93 | best_acc1 = 0 94 | INPUT_SIZE = 84 95 | VAL_RESIZE = 100 96 | 97 | def main(): 98 | print(model_names) 99 | 100 | args = parser.parse_args() 101 | 102 | if args.seed is not None: 103 | random.seed(args.seed) 104 | torch.manual_seed(args.seed) 105 | cudnn.deterministic = True 106 | warnings.warn('You have chosen to seed training. ' 107 | 'This will turn on the CUDNN deterministic setting, ' 108 | 'which can slow down your training considerably! ' 109 | 'You may see unexpected behavior when restarting ' 110 | 'from checkpoints.') 111 | 112 | if args.gpu is not None: 113 | warnings.warn('You have chosen a specific GPU. This will completely ' 114 | 'disable data parallelism.') 115 | 116 | if args.dist_url == "env://" and args.world_size == -1: 117 | args.world_size = int(os.environ["WORLD_SIZE"]) 118 | 119 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 120 | 121 | ngpus_per_node = torch.cuda.device_count() 122 | if args.multiprocessing_distributed: 123 | # Since we have ngpus_per_node processes per node, the total world_size 124 | # needs to be adjusted accordingly 125 | args.world_size = ngpus_per_node * args.world_size 126 | # Use torch.multiprocessing.spawn to launch distributed processes: the 127 | # main_worker process function 128 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 129 | else: 130 | # Simply call main_worker function 131 | main_worker(args.gpu, ngpus_per_node, args) 132 | 133 | 134 | def main_worker(gpu, ngpus_per_node, args): 135 | global best_acc1 136 | args.gpu = gpu 137 | 138 | if args.gpu is not None: 139 | print("Use GPU: {} for training".format(args.gpu)) 140 | 141 | if args.distributed: 142 | if args.dist_url == "env://" and args.rank == -1: 143 | args.rank = int(os.environ["RANK"]) 144 | if args.multiprocessing_distributed: 145 | # For multiprocessing distributed training, rank needs to be the 146 | # global rank among all the processes 147 | args.rank = args.rank * ngpus_per_node + gpu 148 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 149 | world_size=args.world_size, rank=args.rank) 150 | # create model 151 | if args.debug > 0: 152 | print("=> creating model for debug 2") 153 | # model = ResNet84(BasicBlock, [1, 1, 1, 1], num_classes=5) # 1, 1, 1, 1 will make a resnet10 154 | model = ResNet84(BasicBlock, [0, 0, 0, 0], num_classes=5) # 0, 0, 0, 0 make a convnet6 (5 conv layers in total lol) 155 | x = torch.rand((1, 3, 84, 84)).float() 156 | out = model(x) 157 | print(model) 158 | quit() 159 | 160 | # model = ResNetTest2(BasicBlock, [2, 2, 2, 2]) 161 | #model = Drq4Encoder((3, 84, 84), n_channel, 200) 162 | else: 163 | if args.pretrained: 164 | print("=> using pre-trained model '{}'".format(args.arch)) 165 | model = 
models.__dict__[args.arch](pretrained=True) 166 | else: 167 | print("=> creating model '{}'".format(args.arch)) 168 | if args.arch in rl_model_names: 169 | if args.arch == 'resnet18_32channel': 170 | model = ResNet84(BasicBlock, [2, 2, 2, 2], start_num_channel=32) # 2, 2, 2, 2 makes a resnet18 171 | elif args.arch == 'resnet10_32channel': 172 | model = ResNet84(BasicBlock, [1, 1, 1, 1], start_num_channel=32) # 1, 1, 1, 1 makes a resnet10 173 | elif args.arch == 'resnet6_32channel': 174 | model = ResNet84(BasicBlock, [0, 0, 0, 0], start_num_channel=32) # resnet6 (not truly a resnet: 0-block layers have no skip connections) 175 | elif args.arch == 'resnet18_64channel': 176 | model = ResNet84(BasicBlock, [2, 2, 2, 2], start_num_channel=64) 177 | elif args.arch == 'resnet10_64channel': 178 | model = ResNet84(BasicBlock, [1, 1, 1, 1], start_num_channel=64) 179 | elif args.arch == 'resnet6_64channel': 180 | model = ResNet84(BasicBlock, [0, 0, 0, 0], start_num_channel=64) 181 | else: 182 | print("specialized model not yet implemented") 183 | quit() 184 | else: 185 | model = models.__dict__[args.arch]() 186 | 187 | if not torch.cuda.is_available(): 188 | print('using CPU, this will be slow') 189 | elif args.distributed: 190 | print("distributed") 191 | # For multiprocessing distributed, DistributedDataParallel constructor 192 | # should always set the single device scope, otherwise, 193 | # DistributedDataParallel will use all available devices. 194 | if args.gpu is not None: 195 | torch.cuda.set_device(args.gpu) 196 | model.cuda(args.gpu) 197 | # When using a single GPU per process and per 198 | # DistributedDataParallel, we need to divide the batch size 199 | # ourselves based on the total number of GPUs we have 200 | args.batch_size = int(args.batch_size / ngpus_per_node) 201 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 202 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 203 | else: 204 | model.cuda() 205 | # DistributedDataParallel will divide and allocate batch_size to all 206 | # available GPUs if device_ids are not set 207 | model = torch.nn.parallel.DistributedDataParallel(model) 208 | elif args.gpu is not None: 209 | print("use gpu:", args.gpu) 210 | torch.cuda.set_device(args.gpu) 211 | model = model.cuda(args.gpu) 212 | else: 213 | print("data parallel") 214 | # DataParallel will divide and allocate batch_size to all available GPUs 215 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 216 | model.features = torch.nn.DataParallel(model.features) 217 | model.cuda() 218 | else: 219 | model = torch.nn.DataParallel(model).cuda() 220 | 221 | # define loss function (criterion) and optimizer 222 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 223 | 224 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 225 | momentum=args.momentum, 226 | weight_decay=args.weight_decay) 227 | 228 | # optionally resume from a checkpoint 229 | if args.resume: 230 | if os.path.isfile(args.resume): 231 | print("=> loading checkpoint '{}'".format(args.resume)) 232 | if args.gpu is None: 233 | checkpoint = torch.load(args.resume) 234 | else: 235 | # Map model to be loaded to specified single gpu.
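# (torch.load's map_location remaps tensors that were saved on a different device onto args.gpu at load time)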
236 | loc = 'cuda:{}'.format(args.gpu) 237 | checkpoint = torch.load(args.resume, map_location=loc) 238 | args.start_epoch = checkpoint['epoch'] 239 | best_acc1 = checkpoint['best_acc1'] 240 | if args.gpu is not None: 241 | # best_acc1 may be from a checkpoint from a different GPU 242 | best_acc1 = best_acc1.to(args.gpu) 243 | model.load_state_dict(checkpoint['state_dict']) 244 | optimizer.load_state_dict(checkpoint['optimizer']) 245 | print("=> loaded checkpoint '{}' (epoch {})" 246 | .format(args.resume, checkpoint['epoch'])) 247 | else: 248 | print("=> no checkpoint found at '{}'".format(args.resume)) 249 | 250 | cudnn.benchmark = True 251 | 252 | # Data loading code 253 | traindir = os.path.join(args.data, 'train') 254 | valdir = os.path.join(args.data, 'val') 255 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 256 | std=[0.229, 0.224, 0.225]) 257 | 258 | print("train directory is:", traindir) 259 | print("val directory is:", valdir) 260 | 261 | train_dataset = datasets.ImageFolder( 262 | traindir, 263 | transforms.Compose([ 264 | transforms.RandomResizedCrop(INPUT_SIZE), 265 | transforms.RandomHorizontalFlip(), 266 | transforms.ToTensor(), 267 | normalize, 268 | ])) 269 | 270 | print("data set ready") 271 | 272 | if args.distributed: 273 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 274 | else: 275 | train_sampler = None 276 | 277 | train_loader = torch.utils.data.DataLoader( 278 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 279 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 280 | 281 | val_loader = torch.utils.data.DataLoader( 282 | datasets.ImageFolder(valdir, transforms.Compose([ 283 | # transforms.Resize(256), 284 | transforms.Resize(VAL_RESIZE), 285 | transforms.CenterCrop(INPUT_SIZE), 286 | transforms.ToTensor(), 287 | normalize, 288 | ])), 289 | batch_size=args.batch_size, shuffle=False, 290 | num_workers=args.workers, pin_memory=True) 291 | 292 | if args.evaluate: 293 | validate(val_loader, model, criterion, args) 294 | return 295 | 296 | for epoch in range(args.start_epoch, args.epochs): 297 | print(epoch) 298 | epoch_start_time = time.time() 299 | 300 | if args.distributed: 301 | train_sampler.set_epoch(epoch) 302 | adjust_learning_rate(optimizer, epoch, args) 303 | 304 | # train for one epoch 305 | train(train_loader, model, criterion, optimizer, epoch, args) 306 | 307 | # evaluate on validation set 308 | acc1 = validate(val_loader, model, criterion, args) 309 | 310 | # remember best acc@1 and save checkpoint 311 | is_best = acc1 > best_acc1 312 | best_acc1 = max(acc1, best_acc1) 313 | 314 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 315 | and args.rank % ngpus_per_node == 0): 316 | save_checkpoint({ 317 | 'epoch': epoch + 1, 318 | 'arch': args.arch, 319 | 'state_dict': model.state_dict(), 320 | 'best_acc1': best_acc1, 321 | 'optimizer' : optimizer.state_dict(), 322 | }, is_best, 323 | save_name_prefix=args.arch) 324 | 325 | epoch_end_time = time.time() - epoch_start_time 326 | print("epoch finished in %.3f hour" % (epoch_end_time/3600)) 327 | 328 | def train(train_loader, model, criterion, optimizer, epoch, args): 329 | batch_time = AverageMeter('Time', ':6.3f') 330 | data_time = AverageMeter('Data', ':6.3f') 331 | losses = AverageMeter('Loss', ':.4e') 332 | top1 = AverageMeter('Acc@1', ':6.2f') 333 | top5 = AverageMeter('Acc@5', ':6.2f') 334 | progress = ProgressMeter( 335 | len(train_loader), 336 | [batch_time, data_time, losses, top1, 
top5], 337 | prefix="Epoch: [{}]".format(epoch)) 338 | 339 | # switch to train mode 340 | model.train() 341 | 342 | end = time.time() 343 | for i, (images, target) in enumerate(train_loader): 344 | # measure data loading time 345 | data_time.update(time.time() - end) 346 | 347 | if args.gpu is not None: 348 | images = images.cuda(args.gpu, non_blocking=True) 349 | if torch.cuda.is_available(): 350 | target = target.cuda(args.gpu, non_blocking=True) 351 | 352 | # compute output 353 | output = model(images) 354 | loss = criterion(output, target) 355 | 356 | # measure accuracy and record loss 357 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 358 | losses.update(loss.item(), images.size(0)) 359 | top1.update(acc1[0], images.size(0)) 360 | top5.update(acc5[0], images.size(0)) 361 | 362 | # compute gradient and do SGD step 363 | optimizer.zero_grad() 364 | loss.backward() 365 | optimizer.step() 366 | 367 | # measure elapsed time 368 | batch_time.update(time.time() - end) 369 | end = time.time() 370 | 371 | if i % args.print_freq == 0: 372 | progress.display(i) 373 | 374 | 375 | def validate(val_loader, model, criterion, args): 376 | batch_time = AverageMeter('Time', ':6.3f') 377 | losses = AverageMeter('Loss', ':.4e') 378 | top1 = AverageMeter('Acc@1', ':6.2f') 379 | top5 = AverageMeter('Acc@5', ':6.2f') 380 | progress = ProgressMeter( 381 | len(val_loader), 382 | [batch_time, losses, top1, top5], 383 | prefix='Test: ') 384 | 385 | # switch to evaluate mode 386 | model.eval() 387 | 388 | with torch.no_grad(): 389 | end = time.time() 390 | for i, (images, target) in enumerate(val_loader): 391 | if args.gpu is not None: 392 | images = images.cuda(args.gpu, non_blocking=True) 393 | if torch.cuda.is_available(): 394 | target = target.cuda(args.gpu, non_blocking=True) 395 | 396 | # compute output 397 | output = model(images) 398 | loss = criterion(output, target) 399 | 400 | # measure accuracy and record loss 401 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 402 | losses.update(loss.item(), images.size(0)) 403 | top1.update(acc1[0], images.size(0)) 404 | top5.update(acc5[0], images.size(0)) 405 | 406 | # measure elapsed time 407 | batch_time.update(time.time() - end) 408 | end = time.time() 409 | 410 | if i % args.print_freq == 0: 411 | progress.display(i) 412 | 413 | # TODO: this should also be done with the ProgressMeter 414 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 415 | .format(top1=top1, top5=top5)) 416 | 417 | return top1.avg 418 | 419 | 420 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', save_name_prefix = ''): 421 | save_name = save_name_prefix + '_' + filename 422 | torch.save(state, save_name) 423 | if is_best: 424 | best_model_save_name = save_name_prefix + '_' + 'model_best.pth.tar' 425 | shutil.copyfile(save_name, best_model_save_name) 426 | 427 | class AverageMeter(object): 428 | """Computes and stores the average and current value""" 429 | def __init__(self, name, fmt=':f'): 430 | self.name = name 431 | self.fmt = fmt 432 | self.reset() 433 | 434 | def reset(self): 435 | self.val = 0 436 | self.avg = 0 437 | self.sum = 0 438 | self.count = 0 439 | 440 | def update(self, val, n=1): 441 | self.val = val 442 | self.sum += val * n 443 | self.count += n 444 | self.avg = self.sum / self.count 445 | 446 | def __str__(self): 447 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 448 | return fmtstr.format(**self.__dict__) 449 | 450 | 451 | class ProgressMeter(object): 452 | def __init__(self, num_batches, meters, prefix=""): 453 | 
self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 454 | self.meters = meters 455 | self.prefix = prefix 456 | 457 | def display(self, batch): 458 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 459 | entries += [str(meter) for meter in self.meters] 460 | print('\t'.join(entries)) 461 | 462 | def _get_batch_fmtstr(self, num_batches): 463 | num_digits = len(str(num_batches // 1)) 464 | fmt = '{:' + str(num_digits) + 'd}' 465 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 466 | 467 | 468 | def adjust_learning_rate(optimizer, epoch, args): 469 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 470 | lr = args.lr * (0.1 ** (epoch // 30)) 471 | for param_group in optimizer.param_groups: 472 | param_group['lr'] = lr 473 | 474 | 475 | def accuracy(output, target, topk=(1,)): 476 | """Computes the accuracy over the k top predictions for the specified values of k""" 477 | with torch.no_grad(): 478 | maxk = max(topk) 479 | batch_size = target.size(0) 480 | 481 | _, pred = output.topk(maxk, 1, True, True) 482 | pred = pred.t() 483 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 484 | 485 | res = [] 486 | for k in topk: 487 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 488 | res.append(correct_k.mul_(100.0 / batch_size)) 489 | return res 490 | 491 | 492 | if __name__ == '__main__': 493 | main() -------------------------------------------------------------------------------- /src/transfer_util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from torchvision import datasets, models, transforms 5 | 6 | def set_parameter_requires_grad(model, n_resblock_finetune): 7 | assert n_resblock_finetune in (0, 1, 2, 3, 4, 5) 8 | for param in model.parameters(): 9 | param.requires_grad = False 10 | 11 | for name, param in model.named_parameters(): 12 | condition = (n_resblock_finetune >= 1 and 'layer4' in name) or (n_resblock_finetune >= 2 and 'layer3' in name) or \ 13 | (n_resblock_finetune >= 3 and 'layer2' in name) or (n_resblock_finetune >= 4 and 'layer1' in name) or \ 14 | (n_resblock_finetune >= 5) 15 | 16 | if condition: 17 | param.requires_grad = True 18 | 19 | for name, param in model.named_parameters(): 20 | if 'bn' in name: 21 | param.requires_grad = False 22 | 23 | def initialize_model(model_name, n_resblock_finetune, use_pretrained=True): 24 | # Initialize these variables which will be set in this if statement. Each of these 25 | # variables is model specific. 26 | model_ft = None 27 | input_size = 0 28 | 29 | if model_name == "resnet18": 30 | """ Resnet18 31 | """ 32 | model_ft = models.resnet18(pretrained=use_pretrained) 33 | set_parameter_requires_grad(model_ft, n_resblock_finetune) 34 | feature_size = model_ft.fc.in_features 35 | input_size = 224 36 | elif model_name == 'resnet34': 37 | model_ft = models.resnet34(pretrained=use_pretrained) 38 | set_parameter_requires_grad(model_ft, n_resblock_finetune) 39 | feature_size = model_ft.fc.in_features 40 | input_size = 224 41 | else: 42 | print("Invalid model name, exiting...") 43 | exit() 44 | 45 | return model_ft, input_size, feature_size 46 | 47 | 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | import random 6 | import re 7 | import time 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from omegaconf import OmegaConf 14 | from torch import distributions as pyd 15 | from torch.distributions.utils import _standard_normal 16 | 17 | 18 | class eval_mode: 19 | def __init__(self, *models): 20 | self.models = models 21 | 22 | def __enter__(self): 23 | self.prev_states = [] 24 | for model in self.models: 25 | self.prev_states.append(model.training) 26 | model.train(False) 27 | 28 | def __exit__(self, *args): 29 | for model, state in zip(self.models, self.prev_states): 30 | model.train(state) 31 | return False 32 | 33 | 34 | def set_seed_everywhere(seed): 35 | torch.manual_seed(seed) 36 | if torch.cuda.is_available(): 37 | torch.cuda.manual_seed_all(seed) 38 | np.random.seed(seed) 39 | random.seed(seed) 40 | 41 | 42 | def soft_update_params(net, target_net, tau): 43 | for param, target_param in zip(net.parameters(), target_net.parameters()): 44 | target_param.data.copy_(tau * param.data + 45 | (1 - tau) * target_param.data) 46 | 47 | 48 | def to_torch(xs, device): 49 | return tuple(torch.as_tensor(x, device=device) for x in xs) 50 | 51 | 52 | def weight_init(m): 53 | if isinstance(m, nn.Linear): 54 | nn.init.orthogonal_(m.weight.data) 55 | if hasattr(m.bias, 'data'): 56 | m.bias.data.fill_(0.0) 57 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 58 | gain = nn.init.calculate_gain('relu') 59 | nn.init.orthogonal_(m.weight.data, gain) 60 | if hasattr(m.bias, 'data'): 61 | m.bias.data.fill_(0.0) 62 | 63 | 64 | class Until: 65 | def __init__(self, until, action_repeat=1): 66 | self._until = until 67 | self._action_repeat = action_repeat 68 | 69 | def __call__(self, step): 70 | if self._until is None: 71 | return True 72 | until = self._until // self._action_repeat 73 | return step < until 74 | 75 | 76 | class Every: 77 | def __init__(self, every, action_repeat=1): 78 | self._every = every 79 | self._action_repeat = action_repeat 80 | 81 | def __call__(self, step): 82 | if self._every is None: 83 | return False 84 | every = self._every // self._action_repeat 85 | if step % every == 0: 86 | return True 87 | return False 88 | 89 | 90 | class Timer: 91 | def __init__(self): 92 | self._start_time = time.time() 93 | self._last_time = time.time() 94 | 95 | def reset(self): 96 | elapsed_time = time.time() - self._last_time 97 | self._last_time = time.time() 98 | total_time = time.time() - self._start_time 99 | return elapsed_time, total_time 100 | 101 | def total_time(self): 102 | return time.time() - self._start_time 103 | 104 | 105 | class TruncatedNormal(pyd.Normal): 106 | def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6): 107 | super().__init__(loc, scale, validate_args=False) 108 | self.low = low 109 | self.high = high 110 | self.eps = eps 111 | 112 | def _clamp(self, x): 113 | clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps) 114 | x = x - x.detach() + clamped_x.detach() 115 | return x 116 | 117 | def sample(self, clip=None, sample_shape=torch.Size()): 118 | shape = self._extended_shape(sample_shape) 119 | eps = _standard_normal(shape, 120 | dtype=self.loc.dtype, 121 | device=self.loc.device) 122 | eps *= self.scale 123 | if clip is not None: 124 | eps = torch.clamp(eps, -clip, clip) 125 | x = self.loc + eps 126 | return self._clamp(x) 127 
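# --- hedged usage sketch, not part of the original utils.py ---
# TruncatedNormal above is the squashed exploration head used by the agent's
# actor: dist.mean gives the deterministic action for evaluation, while
# sample(clip=c) clamps the Gaussian noise to [-c, c] before the result is
# pushed back inside [low, high]. Shapes below are assumed for illustration.
def _truncated_normal_demo():
    mu = torch.zeros(1, 6)                         # hypothetical 6-dim action mean
    dist = TruncatedNormal(mu, torch.ones_like(mu) * 0.2)
    a_eval = dist.mean                             # deterministic evaluation action
    a_expl = dist.sample(clip=0.3)                 # exploration noise clamped to [-0.3, 0.3]
    return a_eval, a_expl
# --- end of hedged usage sketch ---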
| 128 | 129 | def schedule(schdl, step): 130 | try: 131 | return float(schdl) 132 | except ValueError: 133 | match = re.match(r'linear\((.+),(.+),(.+)\)', schdl) 134 | if match: 135 | init, final, duration = [float(g) for g in match.groups()] 136 | mix = np.clip(step / duration, 0.0, 1.0) 137 | return (1.0 - mix) * init + mix * final 138 | match = re.match(r'step_linear\((.+),(.+),(.+),(.+),(.+)\)', schdl) 139 | if match: 140 | init, final1, duration1, final2, duration2 = [ 141 | float(g) for g in match.groups() 142 | ] 143 | if step <= duration1: 144 | mix = np.clip(step / duration1, 0.0, 1.0) 145 | return (1.0 - mix) * init + mix * final1 146 | else: 147 | mix = np.clip((step - duration1) / duration2, 0.0, 1.0) 148 | return (1.0 - mix) * final1 + mix * final2 149 | raise NotImplementedError(schdl) 150 | -------------------------------------------------------------------------------- /src/video.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | import cv2 6 | import imageio 7 | import numpy as np 8 | 9 | class VideoRecorder: 10 | def __init__(self, root_dir, render_size=256, fps=20): 11 | if root_dir is not None: 12 | self.save_dir = root_dir / 'eval_video' 13 | self.save_dir.mkdir(exist_ok=True) 14 | else: 15 | self.save_dir = None 16 | 17 | self.render_size = render_size 18 | self.fps = fps 19 | self.frames = [] 20 | 21 | def init(self, env, enabled=True): 22 | self.frames = [] 23 | self.enabled = self.save_dir is not None and enabled 24 | self.record(env) 25 | 26 | def record(self, env): 27 | if self.enabled: 28 | if hasattr(env, 'physics'): 29 | frame = env.physics.render(height=self.render_size, 30 | width=self.render_size, 31 | camera_id=0) 32 | else: 33 | frame = env.render() 34 | self.frames.append(frame) 35 | 36 | def save(self, file_name): 37 | if self.enabled: 38 | path = self.save_dir / file_name 39 | imageio.mimsave(str(path), self.frames, fps=self.fps) 40 | 41 | 42 | class TrainVideoRecorder: 43 | def __init__(self, root_dir, render_size=256, fps=20): 44 | if root_dir is not None: 45 | self.save_dir = root_dir / 'train_video' 46 | self.save_dir.mkdir(exist_ok=True) 47 | else: 48 | self.save_dir = None 49 | 50 | self.render_size = render_size 51 | self.fps = fps 52 | self.frames = [] 53 | 54 | def init(self, obs, enabled=True): 55 | self.frames = [] 56 | self.enabled = self.save_dir is not None and enabled 57 | self.record(obs) 58 | 59 | def record(self, obs): 60 | if self.enabled: 61 | frame = cv2.resize(obs[-3:].transpose(1, 2, 0), 62 | dsize=(self.render_size, self.render_size), 63 | interpolation=cv2.INTER_CUBIC) 64 | self.frames.append(frame) 65 | 66 | def save(self, file_name): 67 | if self.enabled: 68 | path = self.save_dir / file_name 69 | imageio.mimsave(str(path), self.frames, fps=self.fps) 70 | -------------------------------------------------------------------------------- /src/vrl3_agent.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
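# This module contains the stage 2/3 RL components of VRL3: random-shift image
# augmentation (RandomShiftsAug), encoder wrappers (RLEncoder over a stage 1
# pretrained ResNet, plus shallow convolutional encoders), the Actor and Critic
# networks, and the VRL3Agent class that runs the offline (stage 2) and online
# (stage 3) updates.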
3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torchvision import datasets, models, transforms 9 | from transfer_util import initialize_model 10 | from stage1_models import BasicBlock, ResNet84 11 | import os 12 | import copy 13 | from PIL import Image 14 | import platform 15 | from numbers import Number 16 | import utils 17 | 18 | class RandomShiftsAug(nn.Module): 19 | def __init__(self, pad): 20 | super().__init__() 21 | self.pad = pad 22 | 23 | def forward(self, x): 24 | n, c, h, w = x.size() 25 | assert h == w 26 | padding = tuple([self.pad] * 4) 27 | x = F.pad(x, padding, 'replicate') 28 | eps = 1.0 / (h + 2 * self.pad) 29 | arange = torch.linspace(-1.0 + eps, 30 | 1.0 - eps, 31 | h + 2 * self.pad, 32 | device=x.device, 33 | dtype=x.dtype)[:h] 34 | arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2) 35 | base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2) 36 | base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1) 37 | 38 | shift = torch.randint(0, 39 | 2 * self.pad + 1, 40 | size=(n, 1, 1, 2), 41 | device=x.device, 42 | dtype=x.dtype) 43 | shift *= 2.0 / (h + 2 * self.pad) 44 | 45 | grid = base_grid + shift 46 | return F.grid_sample(x, 47 | grid, 48 | padding_mode='zeros', 49 | align_corners=False) 50 | 51 | class Identity(nn.Module): 52 | def __init__(self, input_placeholder=None): 53 | super(Identity, self).__init__() 54 | 55 | def forward(self, x): 56 | return x 57 | 58 | class RLEncoder(nn.Module): 59 | def __init__(self, obs_shape, model_name, device): 60 | super().__init__() 61 | # a wrapper over a non-RL encoder model 62 | self.device = device 63 | assert len(obs_shape) == 3 64 | self.n_input_channel = obs_shape[0] 65 | assert self.n_input_channel % 3 == 0 66 | self.n_images = self.n_input_channel // 3 67 | self.model = self.init_model(model_name) 68 | self.model.fc = Identity() 69 | self.repr_dim = self.model.get_feature_size() 70 | 71 | self.normalize_op = transforms.Normalize((0.485, 0.456, 0.406), 72 | (0.229, 0.224, 0.225)) 73 | self.channel_mismatch = True 74 | 75 | def init_model(self, model_name): 76 | # model name is e.g. 
resnet6_32channel 77 | n_layer_string, n_channel_string = model_name.split('_') 78 | layer_string_to_layer_list = { 79 | 'resnet6': [0, 0, 0, 0], 80 | 'resnet10': [1, 1, 1, 1], 81 | 'resnet18': [2, 2, 2, 2], 82 | } 83 | channel_string_to_n_channel = { 84 | '32channel': 32, 85 | '64channel': 64, 86 | } 87 | layer_list = layer_string_to_layer_list[n_layer_string] 88 | start_num_channel = channel_string_to_n_channel[n_channel_string] 89 | return ResNet84(BasicBlock, layer_list, start_num_channel=start_num_channel).to(self.device) 90 | 91 | def expand_first_layer(self): 92 | # convolutional channel expansion to deal with the input channel mismatch 93 | multiplier = self.n_images 94 | self.model.conv1.weight.data = self.model.conv1.weight.data.repeat(1, multiplier, 1, 1) / multiplier 95 | means = (0.485, 0.456, 0.406) * multiplier 96 | stds = (0.229, 0.224, 0.225) * multiplier 97 | self.normalize_op = transforms.Normalize(means, stds) 98 | self.channel_mismatch = False 99 | 100 | def freeze_bn(self): 101 | # freeze batch norm layers (the VRL3 ablation shows that modifying how 102 | # batch norm is trained does not affect performance) 103 | for module in self.model.modules(): 104 | if isinstance(module, nn.BatchNorm2d): 105 | if hasattr(module, 'weight'): 106 | module.weight.requires_grad_(False) 107 | if hasattr(module, 'bias'): 108 | module.bias.requires_grad_(False) 109 | module.eval() 110 | 111 | def get_parameters_that_require_grad(self): 112 | params = [] 113 | for name, param in self.named_parameters(): 114 | if param.requires_grad: 115 | params.append(param) 116 | return params 117 | 118 | def transform_obs_tensor_batch(self, obs): 119 | # transform the obs batch before it is fed into the pretrained resnet 120 | new_obs = self.normalize_op(obs.float() / 255) 121 | return new_obs 122 | 123 | def _forward_impl(self, x): 124 | x = self.model.get_features(x) 125 | return x 126 | 127 | def forward(self, obs): 128 | o = self.transform_obs_tensor_batch(obs) 129 | h = self._forward_impl(o) 130 | return h 131 | 132 | class Stage3ShallowEncoder(nn.Module): 133 | def __init__(self, obs_shape, n_channel): 134 | super().__init__() 135 | 136 | assert len(obs_shape) == 3 137 | self.repr_dim = n_channel * 35 * 35 138 | 139 | self.n_input_channel = obs_shape[0] 140 | self.conv1 = nn.Conv2d(obs_shape[0], n_channel, 3, stride=2) 141 | self.conv2 = nn.Conv2d(n_channel, n_channel, 3, stride=1) 142 | self.conv3 = nn.Conv2d(n_channel, n_channel, 3, stride=1) 143 | self.conv4 = nn.Conv2d(n_channel, n_channel, 3, stride=1) 144 | self.relu = nn.ReLU(inplace=True) 145 | 146 | # TODO here add prediction head so we can do contrastive learning...
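# note: the compress and pred_layer modules created just below already give an
# asymmetric projection/prediction pair, similar to BYOL/SimSiam-style setups:
# get_positive_output() maps conv features through `compress` only, while
# get_anchor_output() additionally applies `pred_layer`, so an unsupervised
# loss can compare the two branches.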
147 | 148 | self.apply(utils.weight_init) 149 | self.normalize_op = transforms.Normalize((0.485, 0.456, 0.406, 0.485, 0.456, 0.406, 0.485, 0.456, 0.406), 150 | (0.229, 0.224, 0.225, 0.229, 0.224, 0.225, 0.229, 0.224, 0.225)) 151 | 152 | self.compress = nn.Sequential(nn.Linear(self.repr_dim, 50), nn.LayerNorm(50), nn.Tanh()) 153 | self.pred_layer = nn.Linear(50, 50, bias=False) 154 | 155 | def transform_obs_tensor_batch(self, obs): 156 | # transform the obs batch before it is fed into the encoder 157 | # the correct order might be: first augment, then resize, then normalize 158 | # obs = F.interpolate(obs, size=self.pretrained_model_input_size) 159 | new_obs = obs / 255.0 - 0.5 160 | # new_obs = self.normalize_op(new_obs) 161 | return new_obs 162 | 163 | def _forward_impl(self, x): 164 | x = self.relu(self.conv1(x)) 165 | x = self.relu(self.conv2(x)) 166 | x = self.relu(self.conv3(x)) 167 | x = self.relu(self.conv4(x)) 168 | return x 169 | 170 | def forward(self, obs): 171 | o = self.transform_obs_tensor_batch(obs) 172 | h = self._forward_impl(o) 173 | h = h.view(h.shape[0], -1) 174 | return h 175 | 176 | def get_anchor_output(self, obs, actions=None): 177 | # typically goes through conv, then the compression layer, then an mlp head 178 | # used for the UL update 179 | conv_out = self.forward(obs) 180 | compressed = self.compress(conv_out) 181 | pred = self.pred_layer(compressed) 182 | return pred, conv_out 183 | 184 | def get_positive_output(self, obs): 185 | # typically goes through conv, then compression 186 | # used for the UL update 187 | conv_out = self.forward(obs) 188 | compressed = self.compress(conv_out) 189 | return compressed 190 | 191 | class Encoder(nn.Module): 192 | def __init__(self, obs_shape, n_channel): 193 | super().__init__() 194 | 195 | assert len(obs_shape) == 3 196 | self.repr_dim = n_channel * 35 * 35 197 | 198 | self.convnet = nn.Sequential(nn.Conv2d(obs_shape[0], n_channel, 3, stride=2), 199 | nn.ReLU(), nn.Conv2d(n_channel, n_channel, 3, stride=1), 200 | nn.ReLU(), nn.Conv2d(n_channel, n_channel, 3, stride=1), 201 | nn.ReLU(), nn.Conv2d(n_channel, n_channel, 3, stride=1), 202 | nn.ReLU()) 203 | 204 | self.apply(utils.weight_init) 205 | 206 | def forward(self, obs): 207 | obs = obs / 255.0 - 0.5 208 | h = self.convnet(obs) 209 | h = h.view(h.shape[0], -1) 210 | return h 211 | 212 | class IdentityEncoder(nn.Module): 213 | def __init__(self, obs_shape): 214 | super().__init__() 215 | 216 | assert len(obs_shape) == 1 217 | self.repr_dim = obs_shape[0] 218 | 219 | def forward(self, obs): 220 | return obs 221 | 222 | class Actor(nn.Module): 223 | def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim): 224 | super().__init__() 225 | 226 | self.trunk = nn.Sequential(nn.Linear(repr_dim, feature_dim), 227 | nn.LayerNorm(feature_dim), nn.Tanh()) 228 | 229 | self.policy = nn.Sequential(nn.Linear(feature_dim, hidden_dim), 230 | nn.ReLU(inplace=True), 231 | nn.Linear(hidden_dim, hidden_dim), 232 | nn.ReLU(inplace=True), 233 | nn.Linear(hidden_dim, action_shape[0])) 234 | 235 | self.action_shift = 0 236 | self.action_scale = 1 237 | self.apply(utils.weight_init) 238 | 239 | def forward(self, obs, std): 240 | h = self.trunk(obs) 241 | 242 | mu = self.policy(h) 243 | mu = torch.tanh(mu) 244 | mu = mu * self.action_scale + self.action_shift 245 | std = torch.ones_like(mu) * std 246 | 247 | dist = utils.TruncatedNormal(mu, std) 248 | return dist 249 | 250 | def forward_with_pretanh(self, obs, std): 251 | h = self.trunk(obs) 252 | 253 | mu = self.policy(h) 254 | pretanh = mu 255 | mu = 
torch.tanh(mu) 256 | mu = mu * self.action_scale + self.action_shift 257 | std = torch.ones_like(mu) * std 258 | 259 | dist = utils.TruncatedNormal(mu, std) 260 | return dist, pretanh 261 | 262 | class Critic(nn.Module): 263 | def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim): 264 | super().__init__() 265 | 266 | self.trunk = nn.Sequential(nn.Linear(repr_dim, feature_dim), 267 | nn.LayerNorm(feature_dim), nn.Tanh()) 268 | 269 | self.Q1 = nn.Sequential( 270 | nn.Linear(feature_dim + action_shape[0], hidden_dim), 271 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, hidden_dim), 272 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, 1)) 273 | 274 | self.Q2 = nn.Sequential( 275 | nn.Linear(feature_dim + action_shape[0], hidden_dim), 276 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, hidden_dim), 277 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, 1)) 278 | 279 | self.apply(utils.weight_init) 280 | 281 | def forward(self, obs, action): 282 | h = self.trunk(obs) 283 | h_action = torch.cat([h, action], dim=-1) 284 | q1 = self.Q1(h_action) 285 | q2 = self.Q2(h_action) 286 | 287 | return q1, q2 288 | 289 | class VRL3Agent: 290 | def __init__(self, obs_shape, action_shape, device, use_sensor, lr, feature_dim, 291 | hidden_dim, critic_target_tau, num_expl_steps, 292 | update_every_steps, stddev_clip, use_tb, use_data_aug, encoder_lr_scale, 293 | stage1_model_name, safe_q_target_factor, safe_q_threshold, pretanh_penalty, pretanh_threshold, 294 | stage2_update_encoder, cql_weight, cql_temp, cql_n_random, stage2_std, stage2_bc_weight, 295 | stage3_update_encoder, std0, std1, std_n_decay, 296 | stage3_bc_lam0, stage3_bc_lam1): 297 | self.device = device 298 | self.critic_target_tau = critic_target_tau 299 | self.update_every_steps = update_every_steps 300 | self.use_tb = use_tb 301 | self.num_expl_steps = num_expl_steps 302 | 303 | self.stage2_std = stage2_std 304 | self.stage2_update_encoder = stage2_update_encoder 305 | 306 | if std1 > std0: 307 | std1 = std0 308 | self.stddev_schedule = "linear(%s,%s,%s)" % (str(std0), str(std1), str(std_n_decay)) 309 | 310 | self.stddev_clip = stddev_clip 311 | self.use_data_aug = use_data_aug 312 | self.safe_q_target_factor = safe_q_target_factor 313 | self.q_threshold = safe_q_threshold 314 | self.pretanh_penalty = pretanh_penalty 315 | 316 | self.cql_temp = cql_temp 317 | self.cql_weight = cql_weight 318 | self.cql_n_random = cql_n_random 319 | 320 | self.pretanh_threshold = pretanh_threshold 321 | 322 | self.stage2_bc_weight = stage2_bc_weight 323 | self.stage3_bc_lam0 = stage3_bc_lam0 324 | self.stage3_bc_lam1 = stage3_bc_lam1 325 | 326 | if stage3_update_encoder and encoder_lr_scale > 0 and len(obs_shape) > 1: 327 | self.stage3_update_encoder = True 328 | else: 329 | self.stage3_update_encoder = False 330 | 331 | self.encoder = RLEncoder(obs_shape, stage1_model_name, device).to(device) 332 | 333 | self.act_dim = action_shape[0] 334 | 335 | if use_sensor: 336 | downstream_input_dim = self.encoder.repr_dim + 24 337 | else: 338 | downstream_input_dim = self.encoder.repr_dim 339 | 340 | self.actor = Actor(downstream_input_dim, action_shape, feature_dim, 341 | hidden_dim).to(device) 342 | self.critic = Critic(downstream_input_dim, action_shape, feature_dim, 343 | hidden_dim).to(device) 344 | self.critic_target = Critic(downstream_input_dim, action_shape, 345 | feature_dim, hidden_dim).to(device) 346 | self.critic_target.load_state_dict(self.critic.state_dict()) 347 | 348 | # optimizers 349 | self.actor_opt = torch.optim.Adam(self.actor.parameters(), 
lr=lr) 350 | self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=lr) 351 | 352 | encoder_lr = lr * encoder_lr_scale 353 | """ set up encoder optimizer """ 354 | self.encoder_opt = torch.optim.Adam(self.encoder.parameters(), lr=encoder_lr) 355 | # data augmentation 356 | self.aug = RandomShiftsAug(pad=4) 357 | self.train() 358 | self.critic_target.train() 359 | 360 | def load_pretrained_encoder(self, model_path, verbose=True): 361 | if verbose: 362 | print("Trying to load pretrained model from:", model_path) 363 | checkpoint = torch.load(model_path, map_location=torch.device(self.device)) 364 | state_dict = checkpoint['state_dict'] 365 | 366 | pretrained_dict = {} 367 | # remove the `module.` prefix if the model was pretrained in distributed mode 368 | for k, v in state_dict.items(): 369 | if 'module.' in k: 370 | name = k[7:] 371 | else: 372 | name = k 373 | pretrained_dict[name] = v 374 | self.encoder.model.load_state_dict(pretrained_dict, strict=False) 375 | if verbose: 376 | print("Pretrained model loaded!") 377 | 378 | def switch_to_RL_stages(self, verbose=True): 379 | # run convolutional channel expansion to match the RL input shape 380 | self.encoder.expand_first_layer() 381 | if verbose: 382 | print("Convolutional channel expansion finished: now can take in %d images as input." % self.encoder.n_images) 383 | 384 | def train(self, training=True): 385 | self.training = training 386 | self.encoder.train(training) 387 | self.actor.train(training) 388 | self.critic.train(training) 389 | 390 | def act(self, obs, step, eval_mode, obs_sensor=None, is_tensor_input=False, force_action_std=None): 391 | # eval_mode should be False when taking an exploration action in stage 3 392 | # eval_mode should be True when evaluating agent performance 393 | if force_action_std is None: 394 | stddev = utils.schedule(self.stddev_schedule, step) 395 | if step < self.num_expl_steps and not eval_mode: 396 | action = np.random.uniform(0, 1, (self.act_dim,)).astype(np.float32) 397 | return action 398 | else: 399 | stddev = force_action_std 400 | 401 | if is_tensor_input: 402 | obs = self.encoder(obs) 403 | else: 404 | obs = torch.as_tensor(obs, device=self.device) 405 | obs = self.encoder(obs.unsqueeze(0)) 406 | 407 | if obs_sensor is not None: 408 | obs_sensor = torch.as_tensor(obs_sensor, device=self.device) 409 | obs_sensor = obs_sensor.unsqueeze(0) 410 | obs_combined = torch.cat([obs, obs_sensor], dim=1) 411 | else: 412 | obs_combined = obs 413 | 414 | dist = self.actor(obs_combined, stddev) 415 | if eval_mode: 416 | action = dist.mean 417 | else: 418 | action = dist.sample(clip=None) 419 | if step < self.num_expl_steps: 420 | action.uniform_(-1.0, 1.0) 421 | return action.cpu().numpy()[0] 422 | 423 | def update(self, replay_iter, step, stage, use_sensor): 424 | # for stages 2 and 3, we use the same functions but with different hyperparameters 425 | assert stage in (2, 3) 426 | metrics = dict() 427 | 428 | if stage == 2: 429 | update_encoder = self.stage2_update_encoder 430 | stddev = self.stage2_std 431 | conservative_loss_weight = self.cql_weight 432 | bc_weight = self.stage2_bc_weight 433 | 434 | elif stage == 3: 435 | if step % self.update_every_steps != 0: 436 | return metrics 437 | update_encoder = self.stage3_update_encoder 438 | 439 | stddev = utils.schedule(self.stddev_schedule, step) 440 | conservative_loss_weight = 0 441 | 442 | # compute the stage 3 BC weight 443 | bc_data_per_iter = 40000 444 | i_iter = step // bc_data_per_iter 445 | bc_weight = self.stage3_bc_lam0 * self.stage3_bc_lam1 ** i_iter 446 | 447 | # 
batch data 448 | batch = next(replay_iter) 449 | if use_sensor: # TODO might want to...? 450 | obs, action, reward, discount, next_obs, obs_sensor, obs_sensor_next = utils.to_torch(batch, self.device) 451 | else: 452 | obs, action, reward, discount, next_obs = utils.to_torch(batch, self.device) 453 | obs_sensor, obs_sensor_next = None, None 454 | 455 | # augment 456 | if self.use_data_aug: 457 | obs = self.aug(obs.float()) 458 | next_obs = self.aug(next_obs.float()) 459 | else: 460 | obs = obs.float() 461 | next_obs = next_obs.float() 462 | 463 | # encode 464 | if update_encoder: 465 | obs = self.encoder(obs) 466 | else: 467 | with torch.no_grad(): 468 | obs = self.encoder(obs) 469 | 470 | with torch.no_grad(): 471 | next_obs = self.encoder(next_obs) 472 | 473 | # concatenate obs with additional sensor observation if needed 474 | obs_combined = torch.cat([obs, obs_sensor], dim=1) if obs_sensor is not None else obs 475 | obs_next_combined = torch.cat([next_obs, obs_sensor_next], dim=1) if obs_sensor_next is not None else next_obs 476 | 477 | # update critic 478 | metrics.update(self.update_critic_vrl3(obs_combined, action, reward, discount, obs_next_combined, 479 | stddev, update_encoder, conservative_loss_weight)) 480 | 481 | # update actor, following previous works, we do not use actor gradient for encoder update 482 | metrics.update(self.update_actor_vrl3(obs_combined.detach(), action, stddev, bc_weight, 483 | self.pretanh_penalty, self.pretanh_threshold)) 484 | 485 | metrics['batch_reward'] = reward.mean().item() 486 | 487 | # update critic target networks 488 | utils.soft_update_params(self.critic, self.critic_target, self.critic_target_tau) 489 | return metrics 490 | 491 | def update_critic_vrl3(self, obs, action, reward, discount, next_obs, stddev, update_encoder, conservative_loss_weight): 492 | metrics = dict() 493 | batch_size = obs.shape[0] 494 | 495 | """ 496 | STANDARD Q LOSS COMPUTATION: 497 | - get standard Q loss first, this is the same as in any other online RL methods 498 | - except for the safe Q technique, which controls how large the Q value can be 499 | """ 500 | with torch.no_grad(): 501 | dist = self.actor(next_obs, stddev) 502 | next_action = dist.sample(clip=self.stddev_clip) 503 | target_Q1, target_Q2 = self.critic_target(next_obs, next_action) 504 | target_V = torch.min(target_Q1, target_Q2) 505 | target_Q = reward + (discount * target_V) 506 | 507 | if self.safe_q_target_factor < 1: 508 | target_Q[target_Q > (self.q_threshold + 1)] = self.q_threshold + (target_Q[target_Q > (self.q_threshold+1)] - self.q_threshold) ** self.safe_q_target_factor 509 | 510 | Q1, Q2 = self.critic(obs, action) 511 | critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q) 512 | 513 | """ 514 | CONSERVATIVE Q LOSS COMPUTATION: 515 | - sample random actions, actions from policy and next actions from policy, as done in CQL authors' code 516 | (though this detail is not really discussed in the CQL paper) 517 | - only compute this loss when conservative loss weight > 0 518 | """ 519 | if conservative_loss_weight > 0: 520 | random_actions = (torch.rand((batch_size * self.cql_n_random, self.act_dim), device=self.device) - 0.5) * 2 521 | 522 | dist = self.actor(obs, stddev) 523 | current_actions = dist.sample(clip=self.stddev_clip) 524 | 525 | dist = self.actor(next_obs, stddev) 526 | next_current_actions = dist.sample(clip=self.stddev_clip) 527 | 528 | # now get Q values for all these actions (for both Q networks) 529 | obs_repeat = obs.unsqueeze(1).repeat(1, self.cql_n_random, 
1).view(obs.shape[0] * self.cql_n_random, 530 | obs.shape[1]) 531 | 532 | Q1_rand, Q2_rand = self.critic(obs_repeat, 533 | random_actions) # TODO might want to double check the logic here see if the repeat is correct 534 | Q1_rand = Q1_rand.view(obs.shape[0], self.cql_n_random) 535 | Q2_rand = Q2_rand.view(obs.shape[0], self.cql_n_random) 536 | 537 | Q1_curr, Q2_curr = self.critic(obs, current_actions) 538 | Q1_curr_next, Q2_curr_next = self.critic(obs, next_current_actions) 539 | 540 | # now concat all these Q values together 541 | Q1_cat = torch.cat([Q1_rand, Q1, Q1_curr, Q1_curr_next], 1) 542 | Q2_cat = torch.cat([Q2_rand, Q2, Q2_curr, Q2_curr_next], 1) 543 | 544 | cql_min_q1_loss = torch.logsumexp(Q1_cat / self.cql_temp, 545 | dim=1, ).mean() * conservative_loss_weight * self.cql_temp 546 | cql_min_q2_loss = torch.logsumexp(Q2_cat / self.cql_temp, 547 | dim=1, ).mean() * conservative_loss_weight * self.cql_temp 548 | 549 | """Subtract the log likelihood of data""" 550 | conservative_q_loss = cql_min_q1_loss + cql_min_q2_loss - (Q1.mean() + Q2.mean()) * conservative_loss_weight 551 | critic_loss_combined = critic_loss + conservative_q_loss 552 | else: 553 | critic_loss_combined = critic_loss 554 | 555 | # logging 556 | metrics['critic_target_q'] = target_Q.mean().item() 557 | metrics['critic_q1'] = Q1.mean().item() 558 | metrics['critic_q2'] = Q2.mean().item() 559 | metrics['critic_loss'] = critic_loss.item() 560 | 561 | # if needed, also update encoder with critic loss 562 | if update_encoder: 563 | self.encoder_opt.zero_grad(set_to_none=True) 564 | self.critic_opt.zero_grad(set_to_none=True) 565 | critic_loss_combined.backward() 566 | self.critic_opt.step() 567 | if update_encoder: 568 | self.encoder_opt.step() 569 | 570 | return metrics 571 | 572 | def update_actor_vrl3(self, obs, action, stddev, bc_weight, pretanh_penalty, pretanh_threshold): 573 | metrics = dict() 574 | 575 | """ 576 | get standard actor loss 577 | """ 578 | dist, pretanh = self.actor.forward_with_pretanh(obs, stddev) 579 | current_action = dist.sample(clip=self.stddev_clip) 580 | log_prob = dist.log_prob(current_action).sum(-1, keepdim=True) 581 | Q1, Q2 = self.critic(obs, current_action) 582 | Q = torch.min(Q1, Q2) 583 | actor_loss = -Q.mean() 584 | 585 | """ 586 | add BC loss 587 | """ 588 | if bc_weight > 0: 589 | # get mean action with no action noise (though this might not be necessary) 590 | stddev_bc = 0 591 | dist_bc = self.actor(obs, stddev_bc) 592 | current_mean_action = dist_bc.sample(clip=self.stddev_clip) 593 | actor_loss_bc = F.mse_loss(current_mean_action, action) * bc_weight 594 | else: 595 | actor_loss_bc = torch.FloatTensor([0]).to(self.device) 596 | 597 | """ 598 | add pretanh penalty (might not be necessary for Adroit) 599 | """ 600 | pretanh_loss = 0 601 | if pretanh_penalty > 0: 602 | pretanh_loss = pretanh.abs() - pretanh_threshold 603 | pretanh_loss[pretanh_loss < 0] = 0 604 | pretanh_loss = (pretanh_loss ** 2).mean() * pretanh_penalty 605 | 606 | """ 607 | combine actor losses and optimize 608 | """ 609 | actor_loss_combined = actor_loss + actor_loss_bc + pretanh_loss 610 | 611 | self.actor_opt.zero_grad(set_to_none=True) 612 | actor_loss_combined.backward() 613 | self.actor_opt.step() 614 | 615 | metrics['actor_loss'] = actor_loss.item() 616 | metrics['actor_loss_bc'] = actor_loss_bc.item() 617 | metrics['actor_logprob'] = log_prob.mean().item() 618 | metrics['actor_ent'] = dist.entropy().sum(dim=-1).mean().item() 619 | metrics['abs_pretanh'] = pretanh.abs().mean().item() 620 | 
metrics['max_abs_pretanh'] = pretanh.abs().max().item() 621 | 622 | return metrics 623 | --------------------------------------------------------------------------------
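A minimal usage sketch of the stage handoff implemented in src/vrl3_agent.py, assuming `agent` is a constructed VRL3Agent (presumably built from src/cfgs_adroit/config.yaml by the training scripts); the checkpoint path, iterator names, and loop bounds below are illustrative placeholders, not the repo's defaults:

    # stage 1 -> stage 2/3 handoff (hedged sketch; names are hypothetical)
    agent.load_pretrained_encoder('stage1_model_best.pth.tar')   # stage 1 weights
    agent.switch_to_RL_stages()            # conv channel expansion for stacked frames

    for step in range(30000):              # stage 2: offline updates (CQL + BC terms)
        metrics = agent.update(demo_replay_iter, step, stage=2, use_sensor=True)

    for step in range(1000000):            # stage 3: online RL with decaying BC weight
        action = agent.act(obs, step, eval_mode=False, obs_sensor=obs_sensor)
        # ... step the environment and store the transition in the replay buffer ...
        metrics = agent.update(replay_iter, step, stage=3, use_sensor=True)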